import os import time import logging import contextlib import re from typing import List, Optional import psycopg2 import psycopg2.extras from psycopg2.extras import execute_values import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM from langdetect import detect, DetectorFactory DetectorFactory.seed = 0 logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s") LOG = logging.getLogger(__name__) DB_CONFIG = { "host": os.environ.get("DB_HOST", "localhost"), "port": int(os.environ.get("DB_PORT", 5432)), "dbname": os.environ.get("DB_NAME", "rss"), "user": os.environ.get("DB_USER", "rss"), "password": os.environ.get("DB_PASS", "x"), } def _env_list(name: str, *fallbacks: str, default: str = "es") -> List[str]: raw = None for key in (name, *fallbacks): raw = os.environ.get(key) if raw: break raw = raw if raw is not None else default return [s.strip() for s in raw.split(",") if s and s.strip()] def _env_int(name: str, *fallbacks: str, default: int = 8) -> int: for key in (name, *fallbacks): val = os.environ.get(key) if val: try: return int(val) except ValueError: pass return default def _env_float(name: str, *fallbacks: str, default: float = 5.0) -> float: for key in (name, *fallbacks): val = os.environ.get(key) if val: try: return float(val) except ValueError: pass return default def _env_str(name: str, *fallbacks: str, default: Optional[str] = None) -> Optional[str]: for key in (name, *fallbacks): val = os.environ.get(key) if val: return val return default def _env_bool(name: str, default: bool = False) -> bool: val = os.environ.get(name) if val is None: return default return str(val).strip().lower() in ("1", "true", "yes", "y", "on") TARGET_LANGS = _env_list("TARGET_LANGS", "TRANSLATE_TO", default="es") BATCH_SIZE = _env_int("BATCH", "TRANSLATOR_BATCH", "TRANSLATE_BATCH", default=8) ENQUEUE_MAX = _env_int("ENQUEUE", "TRANSLATOR_ENQUEUE", "TRANSLATE_ENQUEUE", default=200) SLEEP_IDLE = _env_float("SLEEP_IDLE", "TRANSLATOR_SLEEP_IDLE", "TRANSLATE_SLEEP_IDLE", default=5.0) DEVICE_CFG = (_env_str("DEVICE", default="auto") or "auto").lower() MAX_SRC_TOKENS = _env_int("MAX_SRC_TOKENS", default=512) MAX_NEW_TOKENS = _env_int("MAX_NEW_TOKENS", default=256) def _beams_from_env(): nb_global = os.environ.get("NUM_BEAMS") has_title = os.environ.get("NUM_BEAMS_TITLE") is not None has_body = os.environ.get("NUM_BEAMS_BODY") is not None if nb_global and not has_title and not has_body: try: v = max(1, int(nb_global)) return v, v except ValueError: pass return _env_int("NUM_BEAMS_TITLE", default=2), _env_int("NUM_BEAMS_BODY", default=1) NUM_BEAMS_TITLE, NUM_BEAMS_BODY = _beams_from_env() UNIVERSAL_MODEL = _env_str("UNIVERSAL_MODEL", default="facebook/nllb-200-distilled-600M") CHUNK_BY_SENTENCES = _env_bool("CHUNK_BY_SENTENCES", default=True) CHUNK_MAX_TOKENS = _env_int("CHUNK_MAX_TOKENS", default=900) CHUNK_OVERLAP_SENTS = _env_int("CHUNK_OVERLAP_SENTS", default=0) _ABBR = ("Sr", "Sra", "Dr", "Dra", "Ing", "Lic", "pág", "etc") _ABBR_MARK = "§" _SENT_SPLIT_RE = re.compile( r'(?<=[\.!\?…])\s+(?=["“\(\[A-ZÁÉÍÓÚÑÄÖÜ0-9])|(?:\n{2,})' ) NLLB_LANG = { "es": "spa_Latn", "en": "eng_Latn", "fr": "fra_Latn", "de": "deu_Latn", "it": "ita_Latn", "pt": "por_Latn", "nl": "nld_Latn", "sv": "swe_Latn", "da": "dan_Latn", "fi": "fin_Latn", "no": "nob_Latn", "nb": "nob_Latn", "nn": "nno_Latn", "pl": "pol_Latn", "cs": "ces_Latn", "sk": "slk_Latn", "sl": "slv_Latn", "hu": "hun_Latn", "ro": "ron_Latn", "bg": "bul_Cyrl", "el": "ell_Grek", "ru": "rus_Cyrl", "uk": "ukr_Cyrl", "hr": "hrv_Latn", "sr": "srp_Cyrl", "bs": "bos_Latn", "tr": "tur_Latn", "ar": "arb_Arab", "fa": "pes_Arab", "he": "heb_Hebr", "zh": "zho_Hans", "ja": "jpn_Jpan", "ko": "kor_Hang", "vi": "vie_Latn", "th": "tha_Thai", "id": "ind_Latn", "ms": "zsm_Latn", "pt-br": "por_Latn", "pt-pt": "por_Latn", } def _protect_abbrev(text: str) -> str: t = re.sub(r"\b([A-ZÁÉÍÓÚÑÄÖÜ])\.", r"\1" + _ABBR_MARK, text) pat = r"\b(?:" + "|".join(map(re.escape, _ABBR)) + r")\." t = re.sub(pat, lambda m: m.group(0)[:-1] + _ABBR_MARK, t, flags=re.IGNORECASE) return t def _restore_abbrev(text: str) -> str: return text.replace(_ABBR_MARK, ".") def split_into_sentences(text: str) -> List[str]: text = (text or "").strip() if not text: return [] protected = _protect_abbrev(text) parts = [p.strip() for p in _SENT_SPLIT_RE.split(protected) if p and p.strip()] parts = [_restore_abbrev(p) for p in parts] merged: List[str] = [] for p in parts: if merged and len(p) < 40: merged[-1] = merged[-1] + " " + p else: merged.append(p) return merged def map_to_nllb(code: Optional[str]) -> Optional[str]: if not code: return None code = code.strip().lower() if code in NLLB_LANG: return NLLB_LANG[code] return f"{code}_Latn" def normalize_lang(code: Optional[str], default: Optional[str] = None) -> Optional[str]: if not code: return default code = code.strip().lower() return code if code else default def get_conn(): return psycopg2.connect(**DB_CONFIG) def ensure_indexes(conn): with conn.cursor() as cur: cur.execute( """ CREATE INDEX IF NOT EXISTS traducciones_lang_to_status_idx ON traducciones (lang_to, status); CREATE INDEX IF NOT EXISTS traducciones_status_idx ON traducciones (status); """ ) conn.commit() def ensure_pending(conn, lang_to: str, enqueue_limit: int): with conn.cursor() as cur: cur.execute( """ INSERT INTO traducciones (noticia_id, lang_from, lang_to, status) SELECT sub.id, NULL, %s, 'pending' FROM ( SELECT n.id FROM noticias n LEFT JOIN traducciones t ON t.noticia_id = n.id AND t.lang_to = %s WHERE t.id IS NULL ORDER BY n.fecha DESC NULLS LAST, n.id LIMIT %s ) AS sub; """, (lang_to, lang_to, enqueue_limit), ) conn.commit() def fetch_pending_batch(conn, lang_to: str, batch_size: int): with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute( """ SELECT t.id AS tr_id, t.noticia_id, t.lang_from, t.lang_to, n.titulo, n.resumen FROM traducciones t JOIN noticias n ON n.id = t.noticia_id WHERE t.lang_to = %s AND t.status = 'pending' ORDER BY t.id LIMIT %s; """, (lang_to, batch_size), ) rows = cur.fetchall() if rows: ids = [r["tr_id"] for r in rows] with conn.cursor() as cur2: cur2.execute("UPDATE traducciones SET status='processing' WHERE id = ANY(%s)", (ids,)) conn.commit() return rows def detect_lang(text1: str, text2: str) -> Optional[str]: txt = (text1 or "").strip() or (text2 or "").strip() if not txt: return None try: return detect(txt) except Exception: return None _TOKENIZER: Optional[AutoTokenizer] = None _MODEL: Optional[AutoModelForSeq2SeqLM] = None _DEVICE: Optional[torch.device] = None _CUDA_FAILS: int = 0 _CUDA_DISABLED: bool = False def _resolve_device() -> torch.device: global _CUDA_DISABLED if _CUDA_DISABLED: return torch.device("cpu") if DEVICE_CFG == "cpu": return torch.device("cpu") if DEVICE_CFG == "cuda": return torch.device("cuda" if torch.cuda.is_available() else "cpu") return torch.device("cuda" if torch.cuda.is_available() else "cpu") def _is_cuda_mem_error(exc: Exception) -> bool: s = str(exc) return ("CUDA out of memory" in s) or ("CUDACachingAllocator" in s) or ("expandable_segment" in s) def _free_cuda(): if torch.cuda.is_available(): try: torch.cuda.empty_cache() torch.cuda.ipc_collect() except Exception: pass def _load_model_on(device: torch.device): global _TOKENIZER, _MODEL, _DEVICE dtype = torch.float16 if device.type == "cuda" else torch.float32 LOG.info("Cargando modelo universal %s (device=%s, dtype=%s)", UNIVERSAL_MODEL, device, dtype) tok = AutoTokenizer.from_pretrained(UNIVERSAL_MODEL) mdl = AutoModelForSeq2SeqLM.from_pretrained( UNIVERSAL_MODEL, torch_dtype=dtype, low_cpu_mem_usage=True, ) try: mdl.config.use_cache = False except Exception: pass mdl.to(device) mdl.eval() _TOKENIZER, _MODEL, _DEVICE = tok, mdl, device def get_universal_components(): global _TOKENIZER, _MODEL, _DEVICE, _CUDA_FAILS, _CUDA_DISABLED if _MODEL is not None and _DEVICE is not None: return _TOKENIZER, _MODEL, _DEVICE dev = _resolve_device() try: _load_model_on(dev) return _TOKENIZER, _MODEL, _DEVICE except Exception as e: LOG.warning("Fallo cargando modelo en %s: %s", dev, e) if dev.type == "cuda" and _is_cuda_mem_error(e): _CUDA_FAILS += 1 _CUDA_DISABLED = True _free_cuda() LOG.warning("Deshabilitando CUDA y reintentando en CPU (fallos CUDA=%d)", _CUDA_FAILS) _load_model_on(torch.device("cpu")) return _TOKENIZER, _MODEL, _DEVICE _load_model_on(torch.device("cpu")) return _TOKENIZER, _MODEL, _DEVICE def _safe_src_len(tokenizer) -> int: model_max = getattr(tokenizer, "model_max_length", 1024) or 1024 return min(MAX_SRC_TOKENS, int(model_max) - 16) def _token_chunks(tokenizer, text: str, max_tokens: int) -> List[str]: if not text: return [] ids = tokenizer.encode(text, add_special_tokens=False) if len(ids) <= max_tokens: return [text] chunks = [] for i in range(0, len(ids), max_tokens): sub = ids[i: i + max_tokens] piece = tokenizer.decode(sub, skip_special_tokens=True, clean_up_tokenization_spaces=True) if piece.strip(): chunks.append(piece.strip()) return chunks def _norm(s: str) -> str: import re as _re return _re.sub(r"\W+", "", (s or "").lower()).strip() def _forced_bos_id(tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, tgt_code: str) -> int: try: mapping = getattr(tokenizer, "lang_code_to_id", None) if isinstance(mapping, dict): tid = mapping.get(tgt_code) if isinstance(tid, int): return tid except Exception: pass try: mapping = getattr(getattr(model, "config", None), "lang_code_to_id", None) if isinstance(mapping, dict): tid = mapping.get(tgt_code) if isinstance(tid, int): return tid except Exception: pass try: tid = tokenizer.convert_tokens_to_ids(tgt_code) if isinstance(tid, int) and tid not in (-1, getattr(tokenizer, "unk_token_id", -1)): return tid except Exception: pass try: ats = getattr(tokenizer, "additional_special_tokens", None) ats_ids = getattr(tokenizer, "additional_special_tokens_ids", None) if isinstance(ats, list) and isinstance(ats_ids, list) and tgt_code in ats: idx = ats.index(tgt_code) if 0 <= idx < len(ats_ids) and isinstance(ats_ids[idx], int): return ats_ids[idx] except Exception: pass LOG.warning("No pude resolver lang code id para '%s'. Uso fallback (eos/bos).", tgt_code) return getattr(tokenizer, "eos_token_id", None) or getattr(tokenizer, "bos_token_id", None) or 0 @torch.inference_mode() def _translate_texts_simple( src_lang: str, tgt_lang: str, texts: List[str], num_beams: int = 1, _tries: int = 0, ) -> List[str]: if not texts: return [] cleaned = [(t or "").strip() for t in texts] if all(not t for t in cleaned): return ["" for _ in cleaned] tok, mdl, device = get_universal_components() src_code = map_to_nllb(src_lang) or "eng_Latn" tgt_code = map_to_nllb(tgt_lang) or "spa_Latn" try: tok.src_lang = src_code except Exception: pass forced_bos = _forced_bos_id(tok, mdl, tgt_code) safe_len = _safe_src_len(tok) try: autocast_ctx = ( torch.amp.autocast("cuda", dtype=torch.float16) if device.type == "cuda" else contextlib.nullcontext() ) enc = tok( cleaned, return_tensors="pt", padding=True, truncation=True, max_length=safe_len, ) enc = {k: v.to(device) for k, v in enc.items()} gen_kwargs = dict( forced_bos_token_id=forced_bos, max_new_tokens=MAX_NEW_TOKENS, num_beams=max(1, int(num_beams)), do_sample=False, use_cache=False, ) if int(num_beams) > 1: gen_kwargs["early_stopping"] = True with autocast_ctx: generated = mdl.generate(**enc, **gen_kwargs) outs = tok.batch_decode(generated, skip_special_tokens=True) outs = [o.strip() for o in outs] del enc, generated if device.type == "cuda": _free_cuda() return outs except Exception as e: if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2: LOG.warning("CUDA OOM/allocator (batch): intento de recuperación %d. Detalle: %s", _tries + 1, e) global _MODEL, _DEVICE, _CUDA_DISABLED _CUDA_DISABLED = True try: if _MODEL is not None: del _MODEL except Exception: pass _free_cuda() _MODEL = None _DEVICE = None time.sleep(1.0) return _translate_texts_simple(src_lang, tgt_lang, texts, num_beams=num_beams, _tries=_tries + 1) raise @torch.inference_mode() def translate_text(src_lang: str, tgt_lang: str, text: str, num_beams: int = 1, _tries: int = 0) -> str: if not text or not text.strip(): return "" tok, mdl, device = get_universal_components() src_code = map_to_nllb(src_lang) or "eng_Latn" tgt_code = map_to_nllb(tgt_lang) or "spa_Latn" try: tok.src_lang = src_code except Exception: pass forced_bos = _forced_bos_id(tok, mdl, tgt_code) safe_len = _safe_src_len(tok) parts = _token_chunks(tok, text, safe_len) outs: List[str] = [] try: autocast_ctx = torch.amp.autocast("cuda", dtype=torch.float16) if device.type == "cuda" else contextlib.nullcontext() for p in parts: enc = tok(p, return_tensors="pt", truncation=True, max_length=safe_len) enc = {k: v.to(device) for k, v in enc.items()} gen_kwargs = dict( forced_bos_token_id=forced_bos, max_new_tokens=MAX_NEW_TOKENS, num_beams=max(1, int(num_beams)), do_sample=False, use_cache=False, ) if int(num_beams) > 1: gen_kwargs["early_stopping"] = True with autocast_ctx: generated = mdl.generate(**enc, **gen_kwargs) out = tok.batch_decode(generated, skip_special_tokens=True)[0].strip() outs.append(out) del enc, generated if device.type == "cuda": _free_cuda() return "\n".join([o for o in outs if o]).strip() except Exception as e: if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2: LOG.warning("CUDA OOM/allocator: intento de recuperación %d. Detalle: %s", _tries + 1, e) global _MODEL, _DEVICE, _CUDA_DISABLED _CUDA_DISABLED = True try: if _MODEL is not None: del _MODEL except Exception: pass _free_cuda() _MODEL = None _DEVICE = None time.sleep(1.0) return translate_text(src_lang, tgt_lang, text, num_beams=num_beams, _tries=_tries + 1) raise def _sent_token_len(tokenizer, sent: str) -> int: return len(tokenizer(sent, add_special_tokens=False).input_ids) def _pack_sentences_to_token_chunks( tokenizer, sentences: List[str], max_tokens: int, overlap_sents: int = 0 ) -> List[List[str]]: chunks: List[List[str]] = [] cur: List[str] = [] cur_tokens = 0 for s in sentences: slen = _sent_token_len(tokenizer, s) if slen > max_tokens: ids = tokenizer(s, add_special_tokens=False).input_ids step = max_tokens for i in range(0, len(ids), step): sub = tokenizer.decode(ids[i: i + step], skip_special_tokens=True) if cur: chunks.append(cur) cur = [] cur_tokens = 0 chunks.append([sub]) continue if cur_tokens + slen <= max_tokens: cur.append(s) cur_tokens += slen else: if cur: chunks.append(cur) if overlap_sents > 0 and len(cur) > 0: overlap = cur[-overlap_sents:] cur = overlap + [s] cur_tokens = sum(_sent_token_len(tokenizer, x) for x in cur) else: cur = [s] cur_tokens = slen if cur: chunks.append(cur) return chunks def _smart_concatenate(parts: List[str], tail_window: int = 120) -> str: if not parts: return "" out = parts[0] for nxt in parts[1:]: tail = out[-tail_window:] cut = 0 for k in range(min(len(tail), len(nxt)), 20, -1): if nxt.startswith(tail[-k:]): cut = k break out += ("" if cut == 0 else nxt[cut:]) if nxt else "" return out def translate_articles_full_batch( src_lang: str, tgt_lang: str, texts: List[str], num_beams: int, ) -> List[str]: if not texts: return [] if not CHUNK_BY_SENTENCES: return _translate_texts_simple(src_lang, tgt_lang, texts, num_beams=num_beams) tok, _, _ = get_universal_components() safe_len = _safe_src_len(tok) max_chunk_tokens = min(CHUNK_MAX_TOKENS, safe_len) all_chunk_texts: List[str] = [] per_article_chunk_ids: List[List[int]] = [] for text in texts: text = (text or "").strip() if not text: per_article_chunk_ids.append([]) continue sents = split_into_sentences(text) if not sents: per_article_chunk_ids.append([]) continue chunks_sents = _pack_sentences_to_token_chunks( tok, sents, max_tokens=max_chunk_tokens, overlap_sents=CHUNK_OVERLAP_SENTS, ) ids_for_this_article: List[int] = [] for group in chunks_sents: chunk_text = " ".join(group).strip() if not chunk_text: continue idx = len(all_chunk_texts) all_chunk_texts.append(chunk_text) ids_for_this_article.append(idx) per_article_chunk_ids.append(ids_for_this_article) if not all_chunk_texts: return ["" for _ in texts] translated_chunks = _translate_texts_simple( src_lang, tgt_lang, all_chunk_texts, num_beams=num_beams, ) outs: List[str] = [] for chunk_ids in per_article_chunk_ids: if not chunk_ids: outs.append("") continue parts = [translated_chunks[i] for i in chunk_ids] outs.append(_smart_concatenate([p for p in parts if p])) return outs def translate_article_full( src_lang: str, tgt_lang: str, text: str, num_beams: int, ) -> str: if not text or not text.strip(): return "" if not CHUNK_BY_SENTENCES: return translate_text(src_lang, tgt_lang, text, num_beams=num_beams) tok, _, _ = get_universal_components() safe_len = _safe_src_len(tok) max_chunk_tokens = min(CHUNK_MAX_TOKENS, safe_len) sents = split_into_sentences(text) if not sents: return "" chunks_sents = _pack_sentences_to_token_chunks( tok, sents, max_tokens=max_chunk_tokens, overlap_sents=CHUNK_OVERLAP_SENTS ) translated_parts: List[str] = [] for group in chunks_sents: chunk_text = " ".join(group) translated = translate_text(src_lang, tgt_lang, chunk_text, num_beams=num_beams) translated_parts.append(translated) return _smart_concatenate([p for p in translated_parts if p]) def process_batch(conn, rows): batch_size = len(rows) LOG.info("Iniciando traducción de batch con %d filas…", batch_size) t0 = time.time() done_rows = [] error_rows = [] enriched_rows = [] for r in rows: tr_id = r["tr_id"] lang_to = normalize_lang(r["lang_to"], "es") or "es" lang_from = normalize_lang(r["lang_from"]) or detect_lang(r["titulo"] or "", r["resumen"] or "") or "en" title = (r["titulo"] or "").strip() body = (r["resumen"] or "").strip() src_code = map_to_nllb(lang_from) or "eng_Latn" tgt_code = map_to_nllb(lang_to) or "spa_Latn" if src_code == tgt_code: done_rows.append((title, body, lang_from, tr_id)) continue enriched_rows.append( { "tr_id": tr_id, "lang_from": lang_from, "lang_to": lang_to, "title": title, "body": body, } ) from collections import defaultdict groups = defaultdict(list) for er in enriched_rows: key = (er["lang_from"], er["lang_to"]) groups[key].append(er) for (lang_from, lang_to), items in groups.items(): titles = [it["title"] for it in items] bodies = [it["body"] for it in items] try: titles_tr = _translate_texts_simple(lang_from, lang_to, titles, num_beams=NUM_BEAMS_TITLE) bodies_tr = translate_articles_full_batch(lang_from, lang_to, bodies, num_beams=NUM_BEAMS_BODY) for it, t_tr, b_tr in zip(items, titles_tr, bodies_tr): title_orig = it["title"] body_orig = it["body"] if _norm(t_tr) == _norm(title_orig): t_tr = "" if _norm(b_tr) == _norm(body_orig): b_tr = "" done_rows.append((t_tr, b_tr, lang_from, it["tr_id"])) except Exception as e: LOG.exception("Error traduciendo lote %s -> %s", lang_from, lang_to) err_msg = str(e)[:1500] for it in items: error_rows.append((err_msg, it["tr_id"])) with conn.cursor() as cur: if done_rows: execute_values( cur, """ UPDATE traducciones AS t SET titulo_trad = v.titulo_trad, resumen_trad = v.resumen_trad, lang_from = COALESCE(t.lang_from, v.lang_from), status = 'done', error = NULL FROM (VALUES %s) AS v(titulo_trad, resumen_trad, lang_from, id) WHERE t.id = v.id; """, done_rows, ) if error_rows: execute_values( cur, """ UPDATE traducciones AS t SET status = 'error', error = v.error FROM (VALUES %s) AS v(error, id) WHERE t.id = v.id; """, error_rows, ) conn.commit() dt = time.time() - t0 try: _, _, device = get_universal_components() dev_label = device.type.upper() if device is not None else "UNK" except Exception: dev_label = "UNK" if batch_size > 0: LOG.info( "[%s] Batch de %d filas traducido en %.2f s (%.2f s/noticia)", dev_label, batch_size, dt, dt / batch_size, ) else: LOG.info( "[%s] Batch vacío, nada que traducir (%.2f s)", dev_label, dt, ) def main(): LOG.info( "Arrancando worker de traducción (NLLB). TARGET_LANGS=%s, BATCH=%s, ENQUEUE=%s, DEVICE=%s, " "BEAMS(title/body)=%s/%s, CHUNK_BY_SENTENCES=%s, CHUNK_MAX_TOKENS=%s, OVERLAP_SENTS=%s", TARGET_LANGS, BATCH_SIZE, ENQUEUE_MAX, DEVICE_CFG, NUM_BEAMS_TITLE, NUM_BEAMS_BODY, CHUNK_BY_SENTENCES, CHUNK_MAX_TOKENS, CHUNK_OVERLAP_SENTS, ) get_universal_components() while True: any_work = False with get_conn() as conn: ensure_indexes(conn) for lt in TARGET_LANGS: lt = normalize_lang(lt, "es") or "es" ensure_pending(conn, lt, ENQUEUE_MAX) while True: rows = fetch_pending_batch(conn, lt, BATCH_SIZE) if not rows: break any_work = True LOG.info("[%s] Procesando %d elementos…", lt, len(rows)) process_batch(conn, rows) if not any_work: time.sleep(SLEEP_IDLE) if __name__ == "__main__": os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") main()