# rss/translation_worker.py
import os
import time
import logging
import contextlib
from typing import List, Optional
import psycopg2
import psycopg2.extras
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langdetect import detect, DetectorFactory
DetectorFactory.seed = 0  # reproducible results
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
LOG = logging.getLogger(__name__)
# ---------- DB config ----------
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "localhost"),
    "port": int(os.environ.get("DB_PORT", 5432)),
    "dbname": os.environ.get("DB_NAME", "rss"),
    "user": os.environ.get("DB_USER", "rss"),
    "password": os.environ.get("DB_PASS", "x"),
}
# ---------- ENV helpers (with backwards-compatible names) ----------
def _env_list(name: str, *fallbacks: str, default: str = "es") -> List[str]:
    raw = None
    for key in (name, *fallbacks):
        raw = os.environ.get(key)
        if raw:
            break
    raw = raw if raw is not None else default
    return [s.strip() for s in raw.split(",") if s and s.strip()]

def _env_int(name: str, *fallbacks: str, default: int = 8) -> int:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            try:
                return int(val)
            except ValueError:
                pass
    return default

def _env_float(name: str, *fallbacks: str, default: float = 5.0) -> float:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            try:
                return float(val)
            except ValueError:
                pass
    return default

def _env_str(name: str, *fallbacks: str, default: Optional[str] = None) -> Optional[str]:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            return val
    return default
TARGET_LANGS = _env_list("TARGET_LANGS", "TRANSLATE_TO", default="es")
BATCH_SIZE = _env_int("BATCH", "TRANSLATOR_BATCH", "TRANSLATE_BATCH", default=8)
ENQUEUE_MAX = _env_int("ENQUEUE", "TRANSLATOR_ENQUEUE", "TRANSLATE_ENQUEUE", default=200)
SLEEP_IDLE = _env_float("SLEEP_IDLE", "TRANSLATOR_SLEEP_IDLE", "TRANSLATE_SLEEP_IDLE", default=5.0)
DEVICE_CFG = (_env_str("DEVICE", default="auto") or "auto").lower()  # 'cpu' | 'cuda' | 'auto'

# Token limits (tune these if you hit OOM)
MAX_SRC_TOKENS = _env_int("MAX_SRC_TOKENS", default=384)
MAX_NEW_TOKENS = _env_int("MAX_NEW_TOKENS", default=192)

# ---- Beams: defaults are 2 for titles and 1 for body; honors NUM_BEAMS if only that one is set ----
def _beams_from_env():
    nb_global = os.environ.get("NUM_BEAMS")
    has_title = os.environ.get("NUM_BEAMS_TITLE") is not None
    has_body = os.environ.get("NUM_BEAMS_BODY") is not None
    if nb_global and not has_title and not has_body:
        try:
            v = max(1, int(nb_global))
            return v, v
        except ValueError:
            pass
    # defaults: 2 (title), 1 (body)
    return _env_int("NUM_BEAMS_TITLE", default=2), _env_int("NUM_BEAMS_BODY", default=1)

NUM_BEAMS_TITLE, NUM_BEAMS_BODY = _beams_from_env()

# Default model: NLLB 600M (switch to facebook/nllb-200-1.3B if you want the 1.3B variant)
UNIVERSAL_MODEL = _env_str("UNIVERSAL_MODEL", default="facebook/nllb-200-distilled-600M")
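# Example environment (illustrative only; the variable names are the ones read above,
# the values are assumptions and not part of this repo):
#
#   TARGET_LANGS=es,en          # comma-separated target languages
#   BATCH=8                     # rows translated per DB batch
#   ENQUEUE=200                 # max rows enqueued per target language per pass
#   DEVICE=auto                 # 'cpu' | 'cuda' | 'auto'
#   MAX_SRC_TOKENS=384
#   NUM_BEAMS_TITLE=2
#   NUM_BEAMS_BODY=1
#   UNIVERSAL_MODEL=facebook/nllb-200-distilled-600M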
# ---------- Language-to-NLLB-code mapping ----------
NLLB_LANG = {
    # core
    "es": "spa_Latn", "en": "eng_Latn", "fr": "fra_Latn", "de": "deu_Latn", "it": "ita_Latn", "pt": "por_Latn",
    # Dutch & Nordic
    "nl": "nld_Latn", "sv": "swe_Latn", "da": "dan_Latn", "fi": "fin_Latn",
    # Norwegian
    "no": "nob_Latn", "nb": "nob_Latn", "nn": "nno_Latn",
    # Central & Eastern Europe
    "pl": "pol_Latn", "cs": "ces_Latn", "sk": "slk_Latn", "sl": "slv_Latn",
    "hu": "hun_Latn", "ro": "ron_Latn", "bg": "bul_Cyrl", "el": "ell_Grek",
    "ru": "rus_Cyrl", "uk": "ukr_Cyrl", "hr": "hrv_Latn", "sr": "srp_Cyrl", "bs": "bos_Latn",
    # Middle East / Asia
    "tr": "tur_Latn", "ar": "arb_Arab", "fa": "pes_Arab", "he": "heb_Hebr",
    "zh": "zho_Hans", "ja": "jpn_Jpan", "ko": "kor_Hang",
    # Southeast Asia
    "vi": "vie_Latn", "th": "tha_Thai", "id": "ind_Latn", "ms": "zsm_Latn",
    # variants
    "pt-br": "por_Latn", "pt-pt": "por_Latn",
}

def map_to_nllb(code: Optional[str]) -> Optional[str]:
    if not code:
        return None
    code = code.strip().lower()
    if code in NLLB_LANG:
        return NLLB_LANG[code]
    # Best-effort fallback: assume a Latin script for unmapped codes
    return f"{code}_Latn"

def normalize_lang(code: Optional[str], default: Optional[str] = None) -> Optional[str]:
    if not code:
        return default
    code = code.strip().lower()
    return code if code else default
# ---------- DB ----------
def get_conn():
    return psycopg2.connect(**DB_CONFIG)

def ensure_indexes(conn):
    with conn.cursor() as cur:
        cur.execute("""
            CREATE INDEX IF NOT EXISTS traducciones_lang_to_status_idx
                ON traducciones (lang_to, status);
            CREATE INDEX IF NOT EXISTS traducciones_status_idx
                ON traducciones (status);
        """)
    conn.commit()

def ensure_pending(conn, lang_to: str, enqueue_limit: int):
    with conn.cursor() as cur:
        cur.execute("""
            INSERT INTO traducciones (noticia_id, lang_from, lang_to, status)
            SELECT sub.id, NULL, %s, 'pending'
            FROM (
                SELECT n.id
                FROM noticias n
                LEFT JOIN traducciones t
                       ON t.noticia_id = n.id AND t.lang_to = %s
                WHERE t.id IS NULL
                ORDER BY n.fecha DESC NULLS LAST, n.id
                LIMIT %s
            ) AS sub;
        """, (lang_to, lang_to, enqueue_limit))
    conn.commit()

def fetch_pending_batch(conn, lang_to: str, batch_size: int):
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        cur.execute("""
            SELECT t.id AS tr_id, t.noticia_id, t.lang_from, t.lang_to,
                   n.titulo, n.resumen
            FROM traducciones t
            JOIN noticias n ON n.id = t.noticia_id
            WHERE t.lang_to = %s AND t.status = 'pending'
            ORDER BY t.id
            LIMIT %s;
        """, (lang_to, batch_size))
        rows = cur.fetchall()
    if rows:
        ids = [r["tr_id"] for r in rows]
        with conn.cursor() as cur:
            cur.execute("UPDATE traducciones SET status='processing' WHERE id = ANY(%s)", (ids,))
        conn.commit()
    return rows

def mark_done(conn, tr_id: int, title_tr: str, body_tr: str, lang_from: Optional[str]):
    with conn.cursor() as cur:
        cur.execute("""
            UPDATE traducciones
               SET titulo_trad=%s, resumen_trad=%s,
                   lang_from = COALESCE(lang_from, %s),
                   status='done', error=NULL
             WHERE id=%s;
        """, (title_tr, body_tr, lang_from, tr_id))
    conn.commit()

def mark_error(conn, tr_id: int, msg: str):
    with conn.cursor() as cur:
        cur.execute("UPDATE traducciones SET status='error', error=%s WHERE id=%s;", (msg[:1500], tr_id))
    conn.commit()
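# For reference, a minimal schema sketch inferred from the queries above (an assumption —
# the real tables are created elsewhere and may have more columns/constraints):
#
#   CREATE TABLE noticias (
#       id      bigserial PRIMARY KEY,
#       titulo  text,
#       resumen text,
#       fecha   timestamptz
#   );
#   CREATE TABLE traducciones (
#       id           bigserial PRIMARY KEY,
#       noticia_id   bigint REFERENCES noticias(id),
#       lang_from    text,
#       lang_to      text,
#       titulo_trad  text,
#       resumen_trad text,
#       status       text,   -- 'pending' | 'processing' | 'done' | 'error'
#       error        text
#   );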
def detect_lang(text1: str, text2: str) -> Optional[str]:
    txt = (text1 or "").strip() or (text2 or "").strip()
    if not txt:
        return None
    try:
        return detect(txt)
    except Exception:
        return None
# ---------- Single universal model and CUDA handling (NLLB) ----------
_TOKENIZER: Optional[AutoTokenizer] = None
_MODEL: Optional[AutoModelForSeq2SeqLM] = None
_DEVICE: Optional[torch.device] = None
_CUDA_FAILS: int = 0
_CUDA_DISABLED: bool = False
def _resolve_device() -> torch.device:
    global _CUDA_DISABLED
    if _CUDA_DISABLED:
        return torch.device("cpu")
    if DEVICE_CFG == "cpu":
        return torch.device("cpu")
    if DEVICE_CFG == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # auto
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _is_cuda_mem_error(exc: Exception) -> bool:
    s = str(exc)
    return ("CUDA out of memory" in s) or ("CUDACachingAllocator" in s) or ("expandable_segment" in s)

def _free_cuda():
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            pass

def _load_model_on(device: torch.device):
    """Load (or reload) the model/tokenizer on the given device."""
    global _TOKENIZER, _MODEL, _DEVICE
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    LOG.info("Loading universal model %s (device=%s, dtype=%s)", UNIVERSAL_MODEL, device, dtype)
    tok = AutoTokenizer.from_pretrained(UNIVERSAL_MODEL)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        UNIVERSAL_MODEL,
        torch_dtype=dtype,
        low_cpu_mem_usage=True
    )
    # use_cache=False reduces VRAM spikes during generation
    try:
        mdl.config.use_cache = False
    except Exception:
        pass
    mdl.to(device)
    mdl.eval()
    _TOKENIZER, _MODEL, _DEVICE = tok, mdl, device
def get_universal_components():
    """Return (tokenizer, model, device). Loads on GPU when it is available and stable."""
    global _TOKENIZER, _MODEL, _DEVICE, _CUDA_FAILS, _CUDA_DISABLED
    if _MODEL is not None and _DEVICE is not None:
        return _TOKENIZER, _MODEL, _DEVICE
    dev = _resolve_device()
    try:
        _load_model_on(dev)
        return _TOKENIZER, _MODEL, _DEVICE
    except Exception as e:
        LOG.warning("Failed to load model on %s: %s", dev, e)
        if dev.type == "cuda" and _is_cuda_mem_error(e):
            _CUDA_FAILS += 1
            _CUDA_DISABLED = True
            _free_cuda()
            LOG.warning("Disabling CUDA and retrying on CPU (CUDA failures=%d)", _CUDA_FAILS)
            _load_model_on(torch.device("cpu"))
            return _TOKENIZER, _MODEL, _DEVICE
        _load_model_on(torch.device("cpu"))
        return _TOKENIZER, _MODEL, _DEVICE
# ---------- Utilities ----------
def _token_chunks(tokenizer, text: str, max_tokens: int) -> List[str]:
    if not text:
        return []
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) <= max_tokens:
        return [text]
    chunks = []
    for i in range(0, len(ids), max_tokens):
        sub = ids[i:i+max_tokens]
        piece = tokenizer.decode(sub, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if piece.strip():
            chunks.append(piece.strip())
    return chunks
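# Rough illustration (assumed numbers): with MAX_SRC_TOKENS=384, a 1,000-token article body
# is split into three chunks (384 + 384 + 232 tokens) that are translated independently and
# re-joined with newlines in translate_text() below.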
def _norm(s: str) -> str:
    import re
    return re.sub(r"\W+", "", (s or "").lower()).strip()
def _forced_bos_id(tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, tgt_code: str) -> int:
    """
    Robustly resolve the target-language token id for NLLB,
    working even when the tokenizer lacks `lang_code_to_id`.
    """
    # 1) tokenizer.lang_code_to_id (if present)
    try:
        mapping = getattr(tokenizer, "lang_code_to_id", None)
        if isinstance(mapping, dict):
            tid = mapping.get(tgt_code)
            if isinstance(tid, int):
                return tid
    except Exception:
        pass
    # 2) model.config.lang_code_to_id (if present)
    try:
        mapping = getattr(getattr(model, "config", None), "lang_code_to_id", None)
        if isinstance(mapping, dict):
            tid = mapping.get(tgt_code)
            if isinstance(tid, int):
                return tid
    except Exception:
        pass
    # 3) convert_tokens_to_ids (some builds register the code as a special token)
    try:
        tid = tokenizer.convert_tokens_to_ids(tgt_code)
        if isinstance(tid, int) and tid not in (-1, getattr(tokenizer, "unk_token_id", -1)):
            return tid
    except Exception:
        pass
    # 4) additional_special_tokens/_ids (look up the code verbatim)
    try:
        ats = getattr(tokenizer, "additional_special_tokens", None)
        ats_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
        if isinstance(ats, list) and isinstance(ats_ids, list) and tgt_code in ats:
            idx = ats.index(tgt_code)
            if 0 <= idx < len(ats_ids) and isinstance(ats_ids[idx], int):
                return ats_ids[idx]
    except Exception:
        pass
    # 5) last resort: use eos/bos so generate() does not break
    LOG.warning("Could not resolve the lang code id for '%s'. Using fallback (eos/bos).", tgt_code)
    return getattr(tokenizer, "eos_token_id", None) or getattr(tokenizer, "bos_token_id", None) or 0
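# With the stock NLLB tokenizer the lookup normally succeeds early, e.g. (illustrative,
# actual ids vary by tokenizer build): tokenizer.convert_tokens_to_ids("spa_Latn") yields the
# id that generate() is then forced to emit as the first decoder token.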
@torch.inference_mode()
def translate_text(src_lang: str, tgt_lang: str, text: str, num_beams: int = 1, _tries: int = 0) -> str:
    if not text or not text.strip():
        return ""
    tok, mdl, device = get_universal_components()
    src_code = map_to_nllb(src_lang) or "eng_Latn"
    tgt_code = map_to_nllb(tgt_lang) or "spa_Latn"
    # Set the source language (if the tokenizer exposes the property)
    try:
        tok.src_lang = src_code
    except Exception:
        pass
    forced_bos = _forced_bos_id(tok, mdl, tgt_code)
    parts = _token_chunks(tok, text, MAX_SRC_TOKENS)
    outs: List[str] = []
    try:
        autocast_ctx = torch.amp.autocast("cuda", dtype=torch.float16) if device.type == "cuda" else contextlib.nullcontext()
        for p in parts:
            enc = tok(p, return_tensors="pt", truncation=True, max_length=MAX_SRC_TOKENS)
            enc = {k: v.to(device) for k, v in enc.items()}
            gen_kwargs = dict(
                forced_bos_token_id=forced_bos,
                max_new_tokens=MAX_NEW_TOKENS,
                num_beams=max(1, int(num_beams)),
                do_sample=False,
                use_cache=False,  # lower memory
            )
            # Avoid the warning when num_beams = 1
            if int(num_beams) > 1:
                gen_kwargs["early_stopping"] = True
            with autocast_ctx:
                generated = mdl.generate(**enc, **gen_kwargs)
            out = tok.batch_decode(generated, skip_special_tokens=True)[0].strip()
            outs.append(out)
            del enc, generated
        if device.type == "cuda":
            _free_cuda()
        return "\n".join([o for o in outs if o]).strip()
    except Exception as e:
        if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2:
            LOG.warning("CUDA OOM/allocator error: recovery attempt %d. Details: %s", _tries + 1, e)
            # disable CUDA and retry on CPU
            global _MODEL, _DEVICE, _CUDA_DISABLED
            _CUDA_DISABLED = True
            try:
                if _MODEL is not None:
                    del _MODEL
            except Exception:
                pass
            _free_cuda()
            _MODEL = None
            _DEVICE = None
            time.sleep(1.0)
            return translate_text(src_lang, tgt_lang, text, num_beams=num_beams, _tries=_tries + 1)
        raise
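# Minimal usage sketch (illustrative values; `long_english_body` is a hypothetical variable,
# and the first call downloads/loads the NLLB weights — no database is needed here):
#
#   title_es = translate_text("en", "es", "Breaking: markets rally", num_beams=NUM_BEAMS_TITLE)
#   body_es = translate_text("en", "es", long_english_body, num_beams=NUM_BEAMS_BODY)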
def process_batch(conn, rows):
    for r in rows:
        tr_id = r["tr_id"]
        lang_to = normalize_lang(r["lang_to"], "es") or "es"
        lang_from = normalize_lang(r["lang_from"]) or detect_lang(r["titulo"] or "", r["resumen"] or "") or "en"
        title = (r["titulo"] or "").strip()
        body = (r["resumen"] or "").strip()
        # If the source is already in the target language, copy it as-is
        if (map_to_nllb(lang_from) or "eng_Latn") == (map_to_nllb(lang_to) or "spa_Latn"):
            mark_done(conn, tr_id, title, body, lang_from)
            continue
        try:
            # Different beam counts: better quality for titles at a controlled VRAM cost
            title_tr = translate_text(lang_from, lang_to, title, num_beams=NUM_BEAMS_TITLE) if title else ""
            body_tr = translate_text(lang_from, lang_to, body, num_beams=NUM_BEAMS_BODY) if body else ""
            # If the "translation" equals the original, leave it empty
            if _norm(title_tr) == _norm(title):
                title_tr = ""
            if _norm(body_tr) == _norm(body):
                body_tr = ""
            mark_done(conn, tr_id, title_tr, body_tr, lang_from)
        except Exception as e:
            LOG.exception("Error translating row")
            mark_error(conn, tr_id, str(e))
def main():
    LOG.info(
        "Starting translation worker (NLLB). TARGET_LANGS=%s, BATCH=%s, ENQUEUE=%s, DEVICE=%s, BEAMS(title/body)=%s/%s",
        TARGET_LANGS, BATCH_SIZE, ENQUEUE_MAX, DEVICE_CFG, NUM_BEAMS_TITLE, NUM_BEAMS_BODY
    )
    # Pre-load the model once so memory is reserved cleanly
    get_universal_components()
    while True:
        any_work = False
        # contextlib.closing ensures the connection opened each pass is closed again
        # (the bare psycopg2 connection context manager only ends the transaction)
        with contextlib.closing(get_conn()) as conn:
            ensure_indexes(conn)
            for lt in TARGET_LANGS:
                lt = normalize_lang(lt, "es") or "es"
                ensure_pending(conn, lt, ENQUEUE_MAX)
                while True:
                    rows = fetch_pending_batch(conn, lt, BATCH_SIZE)
                    if not rows:
                        break
                    any_work = True
                    LOG.info("[%s] Processing %d items…", lt, len(rows))
                    process_batch(conn, rows)
        if not any_work:
            time.sleep(SLEEP_IDLE)
if __name__ == "__main__":
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    main()
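# Example invocation (illustrative; the actual service/compose wiring for this repo may differ):
#
#   DB_HOST=localhost DB_NAME=rss DB_USER=rss DB_PASS=... \
#   TARGET_LANGS=es DEVICE=auto python rss/translation_worker.py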