# translation_worker.py
import os
import time
import logging
import contextlib
import re
from typing import List, Optional
import psycopg2
import psycopg2.extras
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langdetect import detect, DetectorFactory
DetectorFactory.seed = 0  # reproducible results
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
LOG = logging.getLogger(__name__)
# ---------- Config DB ----------
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "localhost"),
    "port": int(os.environ.get("DB_PORT", 5432)),
    "dbname": os.environ.get("DB_NAME", "rss"),
    "user": os.environ.get("DB_USER", "rss"),
    "password": os.environ.get("DB_PASS", "x"),
}
# ---------- ENV helpers (with backward compatibility) ----------
def _env_list(name: str, *fallbacks: str, default: str = "es") -> List[str]:
    raw = None
    for key in (name, *fallbacks):
        raw = os.environ.get(key)
        if raw:
            break
    raw = raw if raw is not None else default
    return [s.strip() for s in raw.split(",") if s and s.strip()]
def _env_int(name: str, *fallbacks: str, default: int = 8) -> int:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            try:
                return int(val)
            except ValueError:
                pass
    return default
def _env_float(name: str, *fallbacks: str, default: float = 5.0) -> float:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            try:
                return float(val)
            except ValueError:
                pass
    return default
def _env_str(name: str, *fallbacks: str, default: Optional[str] = None) -> Optional[str]:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            return val
    return default
def _env_bool(name: str, default: bool = False) -> bool:
    val = os.environ.get(name)
    if val is None:
        return default
    return str(val).strip().lower() in ("1", "true", "yes", "y", "on")
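# Illustrative example (not part of the worker's logic): the helpers walk the
# primary name and its legacy fallbacks in order and return the first value set.
# With TRANSLATOR_BATCH=16 in the environment and BATCH/TRANSLATE_BATCH unset,
# _env_int("BATCH", "TRANSLATOR_BATCH", "TRANSLATE_BATCH", default=8) resolves
# to 16; if none are set, or the value is not a valid integer, it returns 8.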
TARGET_LANGS = _env_list("TARGET_LANGS", "TRANSLATE_TO", default="es")
BATCH_SIZE = _env_int("BATCH", "TRANSLATOR_BATCH", "TRANSLATE_BATCH", default=8)
ENQUEUE_MAX = _env_int("ENQUEUE", "TRANSLATOR_ENQUEUE", "TRANSLATE_ENQUEUE", default=200)
SLEEP_IDLE = _env_float("SLEEP_IDLE", "TRANSLATOR_SLEEP_IDLE", "TRANSLATE_SLEEP_IDLE", default=5.0)
DEVICE_CFG = (_env_str("DEVICE", default="auto") or "auto").lower() # 'cpu' | 'cuda' | 'auto'
# Token limits (tune these if you hit OOM)
MAX_SRC_TOKENS = _env_int("MAX_SRC_TOKENS", default=512)
MAX_NEW_TOKENS = _env_int("MAX_NEW_TOKENS", default=256)
# ---- Beams: default 2 for titles and 1 for bodies; honour NUM_BEAMS if only that one is set ----
def _beams_from_env():
    nb_global = os.environ.get("NUM_BEAMS")
    has_title = os.environ.get("NUM_BEAMS_TITLE") is not None
    has_body = os.environ.get("NUM_BEAMS_BODY") is not None
    if nb_global and not has_title and not has_body:
        try:
            v = max(1, int(nb_global))
            return v, v
        except ValueError:
            pass
    # defaults: 2 (title), 1 (body)
    return _env_int("NUM_BEAMS_TITLE", default=2), _env_int("NUM_BEAMS_BODY", default=1)
NUM_BEAMS_TITLE, NUM_BEAMS_BODY = _beams_from_env()
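# How the beam settings resolve (illustrative, traced from _beams_from_env above):
#   NUM_BEAMS=4 with no per-field overrides  -> title=4, body=4
#   NUM_BEAMS_TITLE=3 only                   -> title=3, body=1
#   NUM_BEAMS=4 plus NUM_BEAMS_TITLE=3       -> title=3, body=1 (per-field wins)
#   nothing set                              -> title=2, body=1 (defaults)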
# Default model: NLLB 600M (switch to facebook/nllb-200-1.3B if you want the 1.3B variant)
UNIVERSAL_MODEL = _env_str("UNIVERSAL_MODEL", default="facebook/nllb-200-distilled-600M")
# ---------- Sentence-based chunking (for long articles) ----------
# Enabled by default so we never feed sequences longer than the model limit
CHUNK_BY_SENTENCES = _env_bool("CHUNK_BY_SENTENCES", default=True)
CHUNK_MAX_TOKENS = _env_int("CHUNK_MAX_TOKENS", default=900)  # <= model limit minus a margin
CHUNK_OVERLAP_SENTS = _env_int("CHUNK_OVERLAP_SENTS", default=0)  # 0 or 1
# Common abbreviations and a temporary placeholder marker
_ABBR = ("Sr", "Sra", "Dr", "Dra", "Ing", "Lic", "pág", "etc")
_ABBR_MARK = "§"  # should never appear in normal text
def _protect_abbrev(text: str) -> str:
    # Single-letter initials: "E.", "A."
    t = re.sub(r"\b([A-ZÁÉÍÓÚÑÄÖÜ])\.", r"\1" + _ABBR_MARK, text)
    # Abbreviations from the list above (case-insensitive)
    pat = r"\b(?:" + "|".join(map(re.escape, _ABBR)) + r")\."
    t = re.sub(pat, lambda m: m.group(0)[:-1] + _ABBR_MARK, t, flags=re.IGNORECASE)
    return t
def _restore_abbrev(text: str) -> str:
    return text.replace(_ABBR_MARK, ".")
# Split regex WITHOUT variable-length look-behind:
# - Split after [.!?…] when whitespace follows and the next sentence starts
#   (capital letter, quote, bracket or digit)
# - Or on a double line break
_SENT_SPLIT_RE = re.compile(
    r'(?<=[\.!\?…])\s+(?=["\(\[A-ZÁÉÍÓÚÑÄÖÜ0-9])|(?:\n{2,})'
)
def split_into_sentences(text: str) -> List[str]:
    text = (text or "").strip()
    if not text:
        return []
    protected = _protect_abbrev(text)
    parts = [p.strip() for p in _SENT_SPLIT_RE.split(protected) if p and p.strip()]
    parts = [_restore_abbrev(p) for p in parts]
    # Merge very short pieces into the previous one for better coherence
    merged: List[str] = []
    for p in parts:
        if merged and len(p) < 40:
            merged[-1] = merged[-1] + " " + p
        else:
            merged.append(p)
    return merged
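# Illustrative behaviour (inputs made up, traced by hand rather than tested):
#   split_into_sentences(
#       "Dr. Smith arrived early in the morning and greeted everyone warmly. "
#       "The meeting started at nine o'clock sharp."
#   )
# should return two sentences: the "Dr." abbreviation is protected before
# splitting, and the second sentence is long enough (>= 40 chars) not to be
# merged back into the first.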
# ---------- Language-to-NLLB code mapping ----------
NLLB_LANG = {
    # core
    "es": "spa_Latn", "en": "eng_Latn", "fr": "fra_Latn", "de": "deu_Latn", "it": "ita_Latn", "pt": "por_Latn",
    # Dutch / Nordic
    "nl": "nld_Latn", "sv": "swe_Latn", "da": "dan_Latn", "fi": "fin_Latn",
    # Norwegian
    "no": "nob_Latn", "nb": "nob_Latn", "nn": "nno_Latn",
    # Central/Eastern Europe
    "pl": "pol_Latn", "cs": "ces_Latn", "sk": "slk_Latn", "sl": "slv_Latn",
    "hu": "hun_Latn", "ro": "ron_Latn", "bg": "bul_Cyrl", "el": "ell_Grek",
    "ru": "rus_Cyrl", "uk": "ukr_Cyrl", "hr": "hrv_Latn", "sr": "srp_Cyrl", "bs": "bos_Latn",
    # Middle East / Asia
    "tr": "tur_Latn", "ar": "arb_Arab", "fa": "pes_Arab", "he": "heb_Hebr",
    "zh": "zho_Hans", "ja": "jpn_Jpan", "ko": "kor_Hang",
    # Southeast Asia
    "vi": "vie_Latn", "th": "tha_Thai", "id": "ind_Latn", "ms": "zsm_Latn",
    # variants
    "pt-br": "por_Latn", "pt-pt": "por_Latn",
}
def map_to_nllb(code: Optional[str]) -> Optional[str]:
    if not code:
        return None
    code = code.strip().lower()
    if code in NLLB_LANG:
        return NLLB_LANG[code]
    # Best-effort guess for unmapped codes; may not be a code NLLB actually supports
    return f"{code}_Latn"
def normalize_lang(code: Optional[str], default: Optional[str] = None) -> Optional[str]:
    if not code:
        return default
    code = code.strip().lower()
    return code if code else default
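# Examples of the mapping above (illustrative):
#   map_to_nllb("en")     -> "eng_Latn"
#   map_to_nllb("pt-BR")  -> "por_Latn"  (normalised to lowercase first)
#   map_to_nllb("sw")     -> "sw_Latn"   (fallback guess; not guaranteed to be
#                                         a code the NLLB model knows)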
# ---------- DB ----------
def get_conn():
    return psycopg2.connect(**DB_CONFIG)
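# The queries below assume a schema roughly like the following (reconstructed
# from the SQL in this file; column types are guesses, adjust to the real DDL):
#
#   CREATE TABLE noticias (
#       id      SERIAL PRIMARY KEY,
#       titulo  TEXT,
#       resumen TEXT,
#       fecha   TIMESTAMPTZ
#   );
#   CREATE TABLE traducciones (
#       id           SERIAL PRIMARY KEY,
#       noticia_id   INTEGER REFERENCES noticias(id),
#       lang_from    TEXT,
#       lang_to      TEXT,
#       status       TEXT,   -- 'pending' | 'processing' | 'done' | 'error'
#       titulo_trad  TEXT,
#       resumen_trad TEXT,
#       error        TEXT
#   );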
def ensure_indexes(conn):
    with conn.cursor() as cur:
        cur.execute("""
            CREATE INDEX IF NOT EXISTS traducciones_lang_to_status_idx
                ON traducciones (lang_to, status);
            CREATE INDEX IF NOT EXISTS traducciones_status_idx
                ON traducciones (status);
        """)
    conn.commit()
def ensure_pending(conn, lang_to: str, enqueue_limit: int):
    with conn.cursor() as cur:
        cur.execute("""
            INSERT INTO traducciones (noticia_id, lang_from, lang_to, status)
            SELECT sub.id, NULL, %s, 'pending'
            FROM (
                SELECT n.id
                FROM noticias n
                LEFT JOIN traducciones t
                       ON t.noticia_id = n.id AND t.lang_to = %s
                WHERE t.id IS NULL
                ORDER BY n.fecha DESC NULLS LAST, n.id
                LIMIT %s
            ) AS sub;
        """, (lang_to, lang_to, enqueue_limit))
    conn.commit()
def fetch_pending_batch(conn, lang_to: str, batch_size: int):
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        cur.execute("""
            SELECT t.id AS tr_id, t.noticia_id, t.lang_from, t.lang_to,
                   n.titulo, n.resumen
            FROM traducciones t
            JOIN noticias n ON n.id = t.noticia_id
            WHERE t.lang_to = %s AND t.status = 'pending'
            ORDER BY t.id
            LIMIT %s;
        """, (lang_to, batch_size))
        rows = cur.fetchall()
    if rows:
        ids = [r["tr_id"] for r in rows]
        with conn.cursor() as cur2:
            cur2.execute("UPDATE traducciones SET status='processing' WHERE id = ANY(%s)", (ids,))
        conn.commit()
    return rows
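# Note: rows are claimed by flipping status 'pending' -> 'processing' in a
# separate UPDATE. With a single worker this is enough; if several workers were
# run against the same table, the SELECT would need something like
# FOR UPDATE SKIP LOCKED (or claiming via UPDATE ... RETURNING) to keep two
# workers from picking the same rows.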
def mark_done(conn, tr_id: int, title_tr: str, body_tr: str, lang_from: Optional[str]):
    with conn.cursor() as cur:
        cur.execute("""
            UPDATE traducciones
               SET titulo_trad=%s, resumen_trad=%s,
                   lang_from = COALESCE(lang_from, %s),
                   status='done', error=NULL
             WHERE id=%s;
        """, (title_tr, body_tr, lang_from, tr_id))
    conn.commit()
def mark_error(conn, tr_id: int, msg: str):
    with conn.cursor() as cur:
        cur.execute("UPDATE traducciones SET status='error', error=%s WHERE id=%s;", (msg[:1500], tr_id))
    conn.commit()
def detect_lang(text1: str, text2: str) -> Optional[str]:
    txt = (text1 or "").strip() or (text2 or "").strip()
    if not txt:
        return None
    try:
        return detect(txt)
    except Exception:
        return None
# ---------- Single shared model and CUDA handling (NLLB) ----------
_TOKENIZER: Optional[AutoTokenizer] = None
_MODEL: Optional[AutoModelForSeq2SeqLM] = None
_DEVICE: Optional[torch.device] = None
_CUDA_FAILS: int = 0
_CUDA_DISABLED: bool = False
def _resolve_device() -> torch.device:
    global _CUDA_DISABLED
    if _CUDA_DISABLED:
        return torch.device("cpu")
    if DEVICE_CFG == "cpu":
        return torch.device("cpu")
    if DEVICE_CFG == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # auto
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _is_cuda_mem_error(exc: Exception) -> bool:
    s = str(exc)
    return ("CUDA out of memory" in s) or ("CUDACachingAllocator" in s) or ("expandable_segment" in s)
def _free_cuda():
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            pass
def _load_model_on(device: torch.device):
    """Loads (or reloads) the model/tokenizer on the given device."""
    global _TOKENIZER, _MODEL, _DEVICE
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    LOG.info("Loading universal model %s (device=%s, dtype=%s)", UNIVERSAL_MODEL, device, dtype)
    tok = AutoTokenizer.from_pretrained(UNIVERSAL_MODEL)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        UNIVERSAL_MODEL,
        torch_dtype=dtype,
        low_cpu_mem_usage=True
    )
    # use_cache=False lowers VRAM peaks during generation
    try:
        mdl.config.use_cache = False
    except Exception:
        pass
    mdl.to(device)
    mdl.eval()
    _TOKENIZER, _MODEL, _DEVICE = tok, mdl, device
def get_universal_components():
    """Returns (tokenizer, model, device). Loads on GPU when available and stable."""
    global _TOKENIZER, _MODEL, _DEVICE, _CUDA_FAILS, _CUDA_DISABLED
    if _MODEL is not None and _DEVICE is not None:
        return _TOKENIZER, _MODEL, _DEVICE
    dev = _resolve_device()
    try:
        _load_model_on(dev)
        return _TOKENIZER, _MODEL, _DEVICE
    except Exception as e:
        LOG.warning("Failed to load model on %s: %s", dev, e)
        if dev.type == "cuda" and _is_cuda_mem_error(e):
            _CUDA_FAILS += 1
            _CUDA_DISABLED = True
            _free_cuda()
            LOG.warning("Disabling CUDA and retrying on CPU (CUDA failures=%d)", _CUDA_FAILS)
            _load_model_on(torch.device("cpu"))
            return _TOKENIZER, _MODEL, _DEVICE
        _load_model_on(torch.device("cpu"))
        return _TOKENIZER, _MODEL, _DEVICE
# ---------- Tokenization / chunking utilities ----------
def _safe_src_len(tokenizer) -> int:
    model_max = getattr(tokenizer, "model_max_length", 1024) or 1024
    # leave a margin for special tokens / noise
    return min(MAX_SRC_TOKENS, int(model_max) - 16)
def _token_chunks(tokenizer, text: str, max_tokens: int) -> List[str]:
    """Simple token-based chunking (fallback)."""
    if not text:
        return []
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) <= max_tokens:
        return [text]
    chunks = []
    for i in range(0, len(ids), max_tokens):
        sub = ids[i:i + max_tokens]
        piece = tokenizer.decode(sub, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if piece.strip():
            chunks.append(piece.strip())
    return chunks
def _norm(s: str) -> str:
    return re.sub(r"\W+", "", (s or "").lower()).strip()
def _forced_bos_id(tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, tgt_code: str) -> int:
    """
    Robustly resolves the target-language token id for NLLB, even when the
    tokenizer is missing `lang_code_to_id`.
    """
    # 1) tokenizer.lang_code_to_id (if present)
    try:
        mapping = getattr(tokenizer, "lang_code_to_id", None)
        if isinstance(mapping, dict):
            tid = mapping.get(tgt_code)
            if isinstance(tid, int):
                return tid
    except Exception:
        pass
    # 2) model.config.lang_code_to_id (if present)
    try:
        mapping = getattr(getattr(model, "config", None), "lang_code_to_id", None)
        if isinstance(mapping, dict):
            tid = mapping.get(tgt_code)
            if isinstance(tid, int):
                return tid
    except Exception:
        pass
    # 3) convert_tokens_to_ids (some builds register the code as a special token)
    try:
        tid = tokenizer.convert_tokens_to_ids(tgt_code)
        if isinstance(tid, int) and tid not in (-1, getattr(tokenizer, "unk_token_id", -1)):
            return tid
    except Exception:
        pass
    # 4) additional_special_tokens/_ids (look up the code verbatim)
    try:
        ats = getattr(tokenizer, "additional_special_tokens", None)
        ats_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
        if isinstance(ats, list) and isinstance(ats_ids, list) and tgt_code in ats:
            idx = ats.index(tgt_code)
            if 0 <= idx < len(ats_ids) and isinstance(ats_ids[idx], int):
                return ats_ids[idx]
    except Exception:
        pass
    # 5) last resort: use eos/bos so generate() does not break
    LOG.warning("Could not resolve the lang code id for '%s'. Falling back to eos/bos.", tgt_code)
    return getattr(tokenizer, "eos_token_id", None) or getattr(tokenizer, "bos_token_id", None) or 0
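# With NLLB the first decoder token must be the target-language token, which is
# why the id resolved above is later passed to generate() as forced_bos_token_id
# (see translate_text below).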
# ---------- Core translation ----------
@torch.inference_mode()
def translate_text(src_lang: str, tgt_lang: str, text: str, num_beams: int = 1, _tries: int = 0) -> str:
    """
    Translates one text (chunking by tokens if it exceeds MAX_SRC_TOKENS).
    Used for titles and as the core routine for article chunks.
    """
    if not text or not text.strip():
        return ""
    tok, mdl, device = get_universal_components()
    src_code = map_to_nllb(src_lang) or "eng_Latn"
    tgt_code = map_to_nllb(tgt_lang) or "spa_Latn"
    # Set the source language (if the property exists)
    try:
        tok.src_lang = src_code
    except Exception:
        pass
    forced_bos = _forced_bos_id(tok, mdl, tgt_code)
    safe_len = _safe_src_len(tok)
    parts = _token_chunks(tok, text, safe_len)
    outs: List[str] = []
    try:
        autocast_ctx = torch.amp.autocast("cuda", dtype=torch.float16) if device.type == "cuda" else contextlib.nullcontext()
        for p in parts:
            enc = tok(p, return_tensors="pt", truncation=True, max_length=safe_len)
            enc = {k: v.to(device) for k, v in enc.items()}
            gen_kwargs = dict(
                forced_bos_token_id=forced_bos,
                max_new_tokens=MAX_NEW_TOKENS,
                num_beams=max(1, int(num_beams)),
                do_sample=False,
                use_cache=False,  # lower memory
            )
            if int(num_beams) > 1:
                gen_kwargs["early_stopping"] = True
            with autocast_ctx:
                generated = mdl.generate(**enc, **gen_kwargs)
            out = tok.batch_decode(generated, skip_special_tokens=True)[0].strip()
            outs.append(out)
            del enc, generated
            if device.type == "cuda":
                _free_cuda()
        return "\n".join([o for o in outs if o]).strip()
    except Exception as e:
        if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2:
            LOG.warning("CUDA OOM/allocator error: recovery attempt %d. Detail: %s", _tries + 1, e)
            # disable CUDA and retry on CPU
            global _MODEL, _DEVICE, _CUDA_DISABLED
            _CUDA_DISABLED = True
            try:
                if _MODEL is not None:
                    del _MODEL
            except Exception:
                pass
            _free_cuda()
            _MODEL = None
            _DEVICE = None
            time.sleep(1.0)
            return translate_text(src_lang, tgt_lang, text, num_beams=num_beams, _tries=_tries + 1)
        raise
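# Illustrative call (output is model-dependent, shown only for shape):
#   translate_text("en", "es", "Breaking news: markets rallied today.", num_beams=2)
# returns a single Spanish string; texts longer than the safe source-token
# limit are split into token chunks and the per-chunk outputs are joined with
# newlines.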
# ---------- Sentence-based chunking for articles ----------
def _sent_token_len(tokenizer, sent: str) -> int:
    return len(tokenizer(sent, add_special_tokens=False).input_ids)
def _pack_sentences_to_token_chunks(
    tokenizer, sentences: List[str], max_tokens: int, overlap_sents: int = 0
) -> List[List[str]]:
    chunks: List[List[str]] = []
    cur: List[str] = []
    cur_tokens = 0
    for s in sentences:
        slen = _sent_token_len(tokenizer, s)
        if slen > max_tokens:
            # If a single sentence exceeds the limit, cut it by tokens as a last resort
            ids = tokenizer(s, add_special_tokens=False).input_ids
            step = max_tokens
            for i in range(0, len(ids), step):
                sub = tokenizer.decode(ids[i:i + step], skip_special_tokens=True)
                if cur:
                    chunks.append(cur)
                    cur = []
                    cur_tokens = 0
                chunks.append([sub])
            continue
        if cur_tokens + slen <= max_tokens:
            cur.append(s)
            cur_tokens += slen
        else:
            if cur:
                chunks.append(cur)
            if overlap_sents > 0 and len(cur) > 0:
                overlap = cur[-overlap_sents:]
                cur = overlap + [s]
                cur_tokens = sum(_sent_token_len(tokenizer, x) for x in cur)
            else:
                cur = [s]
                cur_tokens = slen
    if cur:
        chunks.append(cur)
    return chunks
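# Worked example (token counts made up for illustration): with max_tokens=10
# and sentences of 6, 5 and 4 tokens, packing yields [[s1], [s2, s3]]: s2 does
# not fit next to s1 (6+5 > 10), so a new chunk is started, and s3 still fits
# alongside s2 (5+4 <= 10).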
def _smart_concatenate(parts: List[str], tail_window: int = 120) -> str:
    """Joins parts while avoiding obvious duplicates at the seams (light heuristic)."""
    if not parts:
        return ""
    out = parts[0]
    for nxt in parts[1:]:
        if not nxt:
            continue
        tail = out[-tail_window:]
        cut = 0
        for k in range(min(len(tail), len(nxt)), 20, -1):
            if nxt.startswith(tail[-k:]):
                cut = k
                break
        # Drop the duplicated prefix when an overlap was found; otherwise append
        # the whole part with a separating space.
        out += nxt[cut:] if cut else (" " + nxt)
    return out
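# Example of the seam heuristic (illustrative): if one translated chunk ends
# with "...signed the agreement in Brussels" and the next chunk starts with
# that same "signed the agreement in Brussels" (an overlap longer than 20
# characters), the duplicated prefix is dropped; chunks with no such overlap
# are simply appended with a space.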
def translate_article_full(
    src_lang: str,
    tgt_lang: str,
    text: str,
    num_beams: int,
) -> str:
    """
    Translates a full article:
    - Splits into sentences (no variable-length look-behind)
    - Packs them into chunks <= the token limit
    - Translates chunk by chunk (using translate_text internally)
    - Joins with a heuristic that avoids duplicates at the seams
    """
    if not text or not text.strip():
        return ""
    if not CHUNK_BY_SENTENCES:
        # Fast path: a single pass with internal truncation
        return translate_text(src_lang, tgt_lang, text, num_beams=num_beams)
    tok, _, _ = get_universal_components()
    safe_len = _safe_src_len(tok)
    max_chunk_tokens = min(CHUNK_MAX_TOKENS, safe_len)
    sents = split_into_sentences(text)
    if not sents:
        return ""
    chunks_sents = _pack_sentences_to_token_chunks(
        tok, sents, max_tokens=max_chunk_tokens, overlap_sents=CHUNK_OVERLAP_SENTS
    )
    translated_parts: List[str] = []
    for group in chunks_sents:
        chunk_text = " ".join(group)
        translated = translate_text(src_lang, tgt_lang, chunk_text, num_beams=num_beams)
        translated_parts.append(translated)
    return _smart_concatenate([p for p in translated_parts if p])
# ---------- Batch processing ----------
def process_batch(conn, rows):
    for r in rows:
        tr_id = r["tr_id"]
        lang_to = normalize_lang(r["lang_to"], "es") or "es"
        lang_from = normalize_lang(r["lang_from"]) or detect_lang(r["titulo"] or "", r["resumen"] or "") or "en"
        title = (r["titulo"] or "").strip()
        body = (r["resumen"] or "").strip()
        # If it is already in the target language, copy as-is
        if (map_to_nllb(lang_from) or "eng_Latn") == (map_to_nllb(lang_to) or "spa_Latn"):
            mark_done(conn, tr_id, title, body, lang_from)
            continue
        try:
            # Titles: short, direct translation (raise the beams if you want)
            title_tr = translate_text(lang_from, lang_to, title, num_beams=NUM_BEAMS_TITLE) if title else ""
            # Body/summary: full article with sentence-based chunking
            body_tr = translate_article_full(lang_from, lang_to, body, num_beams=NUM_BEAMS_BODY) if body else ""
            # If the "translation" equals the original, leave it empty
            if _norm(title_tr) == _norm(title):
                title_tr = ""
            if _norm(body_tr) == _norm(body):
                body_tr = ""
            mark_done(conn, tr_id, title_tr, body_tr, lang_from)
        except Exception as e:
            LOG.exception("Error translating row")
            mark_error(conn, tr_id, str(e))
def main():
    LOG.info(
        "Starting translation worker (NLLB). TARGET_LANGS=%s, BATCH=%s, ENQUEUE=%s, DEVICE=%s, "
        "BEAMS(title/body)=%s/%s, CHUNK_BY_SENTENCES=%s, CHUNK_MAX_TOKENS=%s, OVERLAP_SENTS=%s",
        TARGET_LANGS, BATCH_SIZE, ENQUEUE_MAX, DEVICE_CFG, NUM_BEAMS_TITLE, NUM_BEAMS_BODY,
        CHUNK_BY_SENTENCES, CHUNK_MAX_TOKENS, CHUNK_OVERLAP_SENTS
    )
    # Pre-load the model once so memory is reserved cleanly
    get_universal_components()
    while True:
        any_work = False
        # psycopg2's connection context manager only handles the transaction,
        # not closing; contextlib.closing avoids leaking one connection per loop.
        with contextlib.closing(get_conn()) as conn:
            ensure_indexes(conn)
            for lt in TARGET_LANGS:
                lt = normalize_lang(lt, "es") or "es"
                ensure_pending(conn, lt, ENQUEUE_MAX)
                while True:
                    rows = fetch_pending_batch(conn, lt, BATCH_SIZE)
                    if not rows:
                        break
                    any_work = True
                    LOG.info("[%s] Processing %d items…", lt, len(rows))
                    process_batch(conn, rows)
        if not any_work:
            time.sleep(SLEEP_IDLE)
if __name__ == "__main__":
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    main()
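# Example invocation (illustrative values only; all variables are optional and
# fall back to the defaults defined above):
#   DB_HOST=db DB_NAME=rss DB_USER=rss DB_PASS=secret \
#   TARGET_LANGS=es BATCH=8 ENQUEUE=200 DEVICE=auto \
#   NUM_BEAMS_TITLE=2 NUM_BEAMS_BODY=1 \
#   python translation_worker.py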