# rss/translation_worker.py
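"""Translation worker for the RSS pipeline.

Polls the traducciones table for pending rows, translates the title and
summary of the referenced noticias entries with an NLLB-200 model
(sentence-aware chunking, CUDA with CPU fallback) and writes the results back.
"""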

import os
import time
import logging
import contextlib
import re
from typing import List, Optional
import psycopg2
import psycopg2.extras
from psycopg2.extras import execute_values
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langdetect import detect, DetectorFactory
DetectorFactory.seed = 0
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
LOG = logging.getLogger(__name__)
DB_CONFIG = {
    "host": os.environ.get("DB_HOST", "localhost"),
    "port": int(os.environ.get("DB_PORT", 5432)),
    "dbname": os.environ.get("DB_NAME", "rss"),
    "user": os.environ.get("DB_USER", "rss"),
    "password": os.environ.get("DB_PASS", "x"),
}
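# Environment helpers: return the first non-empty value among a variable name
# and its optional fallback aliases.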
def _env_list(name: str, *fallbacks: str, default: str = "es") -> List[str]:
    raw = None
    for key in (name, *fallbacks):
        raw = os.environ.get(key)
        if raw:
            break
    raw = raw if raw is not None else default
    return [s.strip() for s in raw.split(",") if s and s.strip()]
def _env_int(name: str, *fallbacks: str, default: int = 8) -> int:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            try:
                return int(val)
            except ValueError:
                pass
    return default
def _env_float(name: str, *fallbacks: str, default: float = 5.0) -> float:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            try:
                return float(val)
            except ValueError:
                pass
    return default
def _env_str(name: str, *fallbacks: str, default: Optional[str] = None) -> Optional[str]:
    for key in (name, *fallbacks):
        val = os.environ.get(key)
        if val:
            return val
    return default
def _env_bool(name: str, default: bool = False) -> bool:
    val = os.environ.get(name)
    if val is None:
        return default
    return str(val).strip().lower() in ("1", "true", "yes", "y", "on")
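# Worker configuration, resolved from the environment at import time.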
TARGET_LANGS = _env_list("TARGET_LANGS", "TRANSLATE_TO", default="es")
BATCH_SIZE = _env_int("BATCH", "TRANSLATOR_BATCH", "TRANSLATE_BATCH", default=8)
ENQUEUE_MAX = _env_int("ENQUEUE", "TRANSLATOR_ENQUEUE", "TRANSLATE_ENQUEUE", default=200)
SLEEP_IDLE = _env_float("SLEEP_IDLE", "TRANSLATOR_SLEEP_IDLE", "TRANSLATE_SLEEP_IDLE", default=5.0)
DEVICE_CFG = (_env_str("DEVICE", default="auto") or "auto").lower()
MAX_SRC_TOKENS = _env_int("MAX_SRC_TOKENS", default=512)
MAX_NEW_TOKENS = _env_int("MAX_NEW_TOKENS", default=256)
def _beams_from_env():
    nb_global = os.environ.get("NUM_BEAMS")
    has_title = os.environ.get("NUM_BEAMS_TITLE") is not None
    has_body = os.environ.get("NUM_BEAMS_BODY") is not None
    if nb_global and not has_title and not has_body:
        try:
            v = max(1, int(nb_global))
            return v, v
        except ValueError:
            pass
    return _env_int("NUM_BEAMS_TITLE", default=2), _env_int("NUM_BEAMS_BODY", default=1)
NUM_BEAMS_TITLE, NUM_BEAMS_BODY = _beams_from_env()
UNIVERSAL_MODEL = _env_str("UNIVERSAL_MODEL", default="facebook/nllb-200-distilled-600M")
CHUNK_BY_SENTENCES = _env_bool("CHUNK_BY_SENTENCES", default=True)
CHUNK_MAX_TOKENS = _env_int("CHUNK_MAX_TOKENS", default=900)
CHUNK_OVERLAP_SENTS = _env_int("CHUNK_OVERLAP_SENTS", default=0)
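# Naive sentence splitting: abbreviations get their trailing dot temporarily
# replaced with a marker so it is not mistaken for a sentence boundary.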
_ABBR = ("Sr", "Sra", "Dr", "Dra", "Ing", "Lic", "pág", "etc")
_ABBR_MARK = "§"
_SENT_SPLIT_RE = re.compile(
    r'(?<=[\.!\?…])\s+(?=["\(\[A-ZÁÉÍÓÚÑÄÖÜ0-9])|(?:\n{2,})'
)
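# Short (mostly ISO 639-1) language codes mapped to NLLB-200 codes.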
NLLB_LANG = {
"es": "spa_Latn",
"en": "eng_Latn",
"fr": "fra_Latn",
"de": "deu_Latn",
"it": "ita_Latn",
"pt": "por_Latn",
"nl": "nld_Latn",
"sv": "swe_Latn",
"da": "dan_Latn",
"fi": "fin_Latn",
"no": "nob_Latn",
"nb": "nob_Latn",
"nn": "nno_Latn",
"pl": "pol_Latn",
"cs": "ces_Latn",
"sk": "slk_Latn",
"sl": "slv_Latn",
"hu": "hun_Latn",
"ro": "ron_Latn",
"bg": "bul_Cyrl",
"el": "ell_Grek",
"ru": "rus_Cyrl",
"uk": "ukr_Cyrl",
"hr": "hrv_Latn",
"sr": "srp_Cyrl",
"bs": "bos_Latn",
"tr": "tur_Latn",
"ar": "arb_Arab",
"fa": "pes_Arab",
"he": "heb_Hebr",
"zh": "zho_Hans",
"ja": "jpn_Jpan",
"ko": "kor_Hang",
"vi": "vie_Latn",
"th": "tha_Thai",
"id": "ind_Latn",
"ms": "zsm_Latn",
"pt-br": "por_Latn",
"pt-pt": "por_Latn",
}
def _protect_abbrev(text: str) -> str:
    t = re.sub(r"\b([A-ZÁÉÍÓÚÑÄÖÜ])\.", r"\1" + _ABBR_MARK, text)
    pat = r"\b(?:" + "|".join(map(re.escape, _ABBR)) + r")\."
    t = re.sub(pat, lambda m: m.group(0)[:-1] + _ABBR_MARK, t, flags=re.IGNORECASE)
    return t
def _restore_abbrev(text: str) -> str:
    return text.replace(_ABBR_MARK, ".")
def split_into_sentences(text: str) -> List[str]:
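    """Split text into sentences with a regex heuristic, merging very short fragments into the preceding sentence."""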
    text = (text or "").strip()
    if not text:
        return []
    protected = _protect_abbrev(text)
    parts = [p.strip() for p in _SENT_SPLIT_RE.split(protected) if p and p.strip()]
    parts = [_restore_abbrev(p) for p in parts]
    merged: List[str] = []
    for p in parts:
        if merged and len(p) < 40:
            merged[-1] = merged[-1] + " " + p
        else:
            merged.append(p)
    return merged
def map_to_nllb(code: Optional[str]) -> Optional[str]:
    if not code:
        return None
    code = code.strip().lower()
    if code in NLLB_LANG:
        return NLLB_LANG[code]
    return f"{code}_Latn"
def normalize_lang(code: Optional[str], default: Optional[str] = None) -> Optional[str]:
    if not code:
        return default
    code = code.strip().lower()
    return code if code else default
def get_conn():
    return psycopg2.connect(**DB_CONFIG)
def ensure_indexes(conn):
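    """Create the indexes used by the queue queries (idempotent)."""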
    with conn.cursor() as cur:
        cur.execute(
            """
            CREATE INDEX IF NOT EXISTS traducciones_lang_to_status_idx
                ON traducciones (lang_to, status);
            CREATE INDEX IF NOT EXISTS traducciones_status_idx
                ON traducciones (status);
            """
        )
    conn.commit()
def ensure_pending(conn, lang_to: str, enqueue_limit: int):
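    """Insert 'pending' queue rows for up to enqueue_limit articles that have no translation row for lang_to yet."""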
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO traducciones (noticia_id, lang_from, lang_to, status)
            SELECT sub.id, NULL, %s, 'pending'
            FROM (
                SELECT n.id
                FROM noticias n
                LEFT JOIN traducciones t
                       ON t.noticia_id = n.id AND t.lang_to = %s
                WHERE t.id IS NULL
                ORDER BY n.fecha DESC NULLS LAST, n.id
                LIMIT %s
            ) AS sub;
            """,
            (lang_to, lang_to, enqueue_limit),
        )
    conn.commit()
def fetch_pending_batch(conn, lang_to: str, batch_size: int):
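    """Fetch up to batch_size pending rows for lang_to and mark them as 'processing'."""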
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        cur.execute(
            """
            SELECT t.id AS tr_id, t.noticia_id, t.lang_from, t.lang_to,
                   n.titulo, n.resumen
            FROM traducciones t
            JOIN noticias n ON n.id = t.noticia_id
            WHERE t.lang_to = %s AND t.status = 'pending'
            ORDER BY t.id
            LIMIT %s;
            """,
            (lang_to, batch_size),
        )
        rows = cur.fetchall()
    if rows:
        ids = [r["tr_id"] for r in rows]
        with conn.cursor() as cur2:
            cur2.execute("UPDATE traducciones SET status='processing' WHERE id = ANY(%s)", (ids,))
        conn.commit()
    return rows
def detect_lang(text1: str, text2: str) -> Optional[str]:
    txt = (text1 or "").strip() or (text2 or "").strip()
    if not txt:
        return None
    try:
        return detect(txt)
    except Exception:
        return None
_TOKENIZER: Optional[AutoTokenizer] = None
_MODEL: Optional[AutoModelForSeq2SeqLM] = None
_DEVICE: Optional[torch.device] = None
_CUDA_FAILS: int = 0
_CUDA_DISABLED: bool = False
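# The NLLB tokenizer/model pair is loaded lazily and cached in module globals.
# After a CUDA out-of-memory failure the worker disables CUDA and falls back to CPU.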
def _resolve_device() -> torch.device:
    global _CUDA_DISABLED
    if _CUDA_DISABLED:
        return torch.device("cpu")
    if DEVICE_CFG == "cpu":
        return torch.device("cpu")
    if DEVICE_CFG == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _is_cuda_mem_error(exc: Exception) -> bool:
    s = str(exc)
    return ("CUDA out of memory" in s) or ("CUDACachingAllocator" in s) or ("expandable_segment" in s)
def _free_cuda():
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            pass
def _load_model_on(device: torch.device):
    global _TOKENIZER, _MODEL, _DEVICE
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    LOG.info("Loading universal model %s (device=%s, dtype=%s)", UNIVERSAL_MODEL, device, dtype)
    tok = AutoTokenizer.from_pretrained(UNIVERSAL_MODEL)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(
        UNIVERSAL_MODEL,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    try:
        mdl.config.use_cache = False
    except Exception:
        pass
    mdl.to(device)
    mdl.eval()
    _TOKENIZER, _MODEL, _DEVICE = tok, mdl, device
def get_universal_components():
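    """Return (tokenizer, model, device), loading the model lazily; CUDA memory errors trigger a fallback to CPU."""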
    global _TOKENIZER, _MODEL, _DEVICE, _CUDA_FAILS, _CUDA_DISABLED
    if _MODEL is not None and _DEVICE is not None:
        return _TOKENIZER, _MODEL, _DEVICE
    dev = _resolve_device()
    try:
        _load_model_on(dev)
        return _TOKENIZER, _MODEL, _DEVICE
    except Exception as e:
        LOG.warning("Failed to load model on %s: %s", dev, e)
        if dev.type == "cuda" and _is_cuda_mem_error(e):
            _CUDA_FAILS += 1
            _CUDA_DISABLED = True
            _free_cuda()
            LOG.warning("Disabling CUDA and retrying on CPU (CUDA failures=%d)", _CUDA_FAILS)
            _load_model_on(torch.device("cpu"))
            return _TOKENIZER, _MODEL, _DEVICE
        _load_model_on(torch.device("cpu"))
        return _TOKENIZER, _MODEL, _DEVICE
def _safe_src_len(tokenizer) -> int:
    model_max = getattr(tokenizer, "model_max_length", 1024) or 1024
    return min(MAX_SRC_TOKENS, int(model_max) - 16)
def _token_chunks(tokenizer, text: str, max_tokens: int) -> List[str]:
    if not text:
        return []
    ids = tokenizer.encode(text, add_special_tokens=False)
    if len(ids) <= max_tokens:
        return [text]
    chunks = []
    for i in range(0, len(ids), max_tokens):
        sub = ids[i : i + max_tokens]
        piece = tokenizer.decode(sub, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if piece.strip():
            chunks.append(piece.strip())
    return chunks
def _norm(s: str) -> str:
    return re.sub(r"\W+", "", (s or "").lower()).strip()
def _forced_bos_id(tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, tgt_code: str) -> int:
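    """Resolve the token id used to force generation in the target language, trying several tokenizer/model attributes for compatibility across transformers versions."""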
    try:
        mapping = getattr(tokenizer, "lang_code_to_id", None)
        if isinstance(mapping, dict):
            tid = mapping.get(tgt_code)
            if isinstance(tid, int):
                return tid
    except Exception:
        pass
    try:
        mapping = getattr(getattr(model, "config", None), "lang_code_to_id", None)
        if isinstance(mapping, dict):
            tid = mapping.get(tgt_code)
            if isinstance(tid, int):
                return tid
    except Exception:
        pass
    try:
        tid = tokenizer.convert_tokens_to_ids(tgt_code)
        if isinstance(tid, int) and tid not in (-1, getattr(tokenizer, "unk_token_id", -1)):
            return tid
    except Exception:
        pass
    try:
        ats = getattr(tokenizer, "additional_special_tokens", None)
        ats_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
        if isinstance(ats, list) and isinstance(ats_ids, list) and tgt_code in ats:
            idx = ats.index(tgt_code)
            if 0 <= idx < len(ats_ids) and isinstance(ats_ids[idx], int):
                return ats_ids[idx]
    except Exception:
        pass
    LOG.warning("Could not resolve the language code id for '%s'. Falling back to eos/bos.", tgt_code)
    return getattr(tokenizer, "eos_token_id", None) or getattr(tokenizer, "bos_token_id", None) or 0
@torch.inference_mode()
def translate_text(src_lang: str, tgt_lang: str, text: str, num_beams: int = 1, _tries: int = 0) -> str:
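    """Translate one text with NLLB, chunking it by tokens; on CUDA memory errors the model is reloaded on CPU and the call is retried."""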
    if not text or not text.strip():
        return ""
    tok, mdl, device = get_universal_components()
    src_code = map_to_nllb(src_lang) or "eng_Latn"
    tgt_code = map_to_nllb(tgt_lang) or "spa_Latn"
    try:
        tok.src_lang = src_code
    except Exception:
        pass
    forced_bos = _forced_bos_id(tok, mdl, tgt_code)
    safe_len = _safe_src_len(tok)
    parts = _token_chunks(tok, text, safe_len)
    outs: List[str] = []
    try:
        autocast_ctx = torch.amp.autocast("cuda", dtype=torch.float16) if device.type == "cuda" else contextlib.nullcontext()
        for p in parts:
            enc = tok(p, return_tensors="pt", truncation=True, max_length=safe_len)
            enc = {k: v.to(device) for k, v in enc.items()}
            gen_kwargs = dict(
                forced_bos_token_id=forced_bos,
                max_new_tokens=MAX_NEW_TOKENS,
                num_beams=max(1, int(num_beams)),
                do_sample=False,
                use_cache=False,
            )
            if int(num_beams) > 1:
                gen_kwargs["early_stopping"] = True
            with autocast_ctx:
                generated = mdl.generate(**enc, **gen_kwargs)
            out = tok.batch_decode(generated, skip_special_tokens=True)[0].strip()
            outs.append(out)
            del enc, generated
            if device.type == "cuda":
                _free_cuda()
        return "\n".join([o for o in outs if o]).strip()
    except Exception as e:
        if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2:
            LOG.warning("CUDA OOM/allocator error: recovery attempt %d. Detail: %s", _tries + 1, e)
            global _MODEL, _DEVICE, _CUDA_DISABLED
            _CUDA_DISABLED = True
            try:
                if _MODEL is not None:
                    del _MODEL
            except Exception:
                pass
            _free_cuda()
            _MODEL = None
            _DEVICE = None
            time.sleep(1.0)
            return translate_text(src_lang, tgt_lang, text, num_beams=num_beams, _tries=_tries + 1)
        raise
def _sent_token_len(tokenizer, sent: str) -> int:
    return len(tokenizer(sent, add_special_tokens=False).input_ids)
def _pack_sentences_to_token_chunks(
    tokenizer, sentences: List[str], max_tokens: int, overlap_sents: int = 0
) -> List[List[str]]:
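    """Group sentences into chunks of at most max_tokens tokens; oversized sentences are hard-split by token windows, and optionally the last overlap_sents sentences are repeated at the start of the next chunk."""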
    chunks: List[List[str]] = []
    cur: List[str] = []
    cur_tokens = 0
    for s in sentences:
        slen = _sent_token_len(tokenizer, s)
        if slen > max_tokens:
            ids = tokenizer(s, add_special_tokens=False).input_ids
            step = max_tokens
            for i in range(0, len(ids), step):
                sub = tokenizer.decode(ids[i : i + step], skip_special_tokens=True)
                if cur:
                    chunks.append(cur)
                    cur = []
                    cur_tokens = 0
                chunks.append([sub])
            continue
        if cur_tokens + slen <= max_tokens:
            cur.append(s)
            cur_tokens += slen
        else:
            if cur:
                chunks.append(cur)
            if overlap_sents > 0 and len(cur) > 0:
                overlap = cur[-overlap_sents:]
                cur = overlap + [s]
                cur_tokens = sum(_sent_token_len(tokenizer, x) for x in cur)
            else:
                cur = [s]
                cur_tokens = slen
    if cur:
        chunks.append(cur)
    return chunks
def _smart_concatenate(parts: List[str], tail_window: int = 120) -> str:
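    """Concatenate translated chunks, dropping text duplicated at the boundary between consecutive parts (from sentence overlap)."""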
    if not parts:
        return ""
    out = parts[0]
    for nxt in parts[1:]:
        if not nxt:
            continue
        tail = out[-tail_window:]
        cut = 0
        for k in range(min(len(tail), len(nxt)), 20, -1):
            if nxt.startswith(tail[-k:]):
                cut = k
                break
        # If no overlap was found, append the whole chunk instead of dropping it.
        out += nxt[cut:] if cut else " " + nxt
    return out
def translate_article_full(
    src_lang: str,
    tgt_lang: str,
    text: str,
    num_beams: int,
) -> str:
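    """Translate an article body; when CHUNK_BY_SENTENCES is enabled, sentence groups that fit the model's input window are translated one by one and stitched back together."""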
    if not text or not text.strip():
        return ""
    if not CHUNK_BY_SENTENCES:
        return translate_text(src_lang, tgt_lang, text, num_beams=num_beams)
    tok, _, _ = get_universal_components()
    safe_len = _safe_src_len(tok)
    max_chunk_tokens = min(CHUNK_MAX_TOKENS, safe_len)
    sents = split_into_sentences(text)
    if not sents:
        return ""
    chunks_sents = _pack_sentences_to_token_chunks(
        tok, sents, max_tokens=max_chunk_tokens, overlap_sents=CHUNK_OVERLAP_SENTS
    )
    translated_parts: List[str] = []
    for group in chunks_sents:
        chunk_text = " ".join(group)
        translated = translate_text(src_lang, tgt_lang, chunk_text, num_beams=num_beams)
        translated_parts.append(translated)
    return _smart_concatenate([p for p in translated_parts if p])
def process_batch(conn, rows):
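    """Translate one batch of queued rows; identical source/target languages are copied through, unchanged outputs are blanked, and results or errors are written back in bulk."""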
    done_rows = []
    error_rows = []
    for r in rows:
        tr_id = r["tr_id"]
        lang_to = normalize_lang(r["lang_to"], "es") or "es"
        lang_from = normalize_lang(r["lang_from"]) or detect_lang(r["titulo"] or "", r["resumen"] or "") or "en"
        title = (r["titulo"] or "").strip()
        body = (r["resumen"] or "").strip()
        if (map_to_nllb(lang_from) or "eng_Latn") == (map_to_nllb(lang_to) or "spa_Latn"):
            done_rows.append((title, body, lang_from, tr_id))
            continue
        try:
            title_tr = translate_text(lang_from, lang_to, title, num_beams=NUM_BEAMS_TITLE) if title else ""
            body_tr = translate_article_full(lang_from, lang_to, body, num_beams=NUM_BEAMS_BODY) if body else ""
            if _norm(title_tr) == _norm(title):
                title_tr = ""
            if _norm(body_tr) == _norm(body):
                body_tr = ""
            done_rows.append((title_tr, body_tr, lang_from, tr_id))
        except Exception as e:
            LOG.exception("Error translating row")
            error_rows.append((str(e)[:1500], tr_id))
    with conn.cursor() as cur:
        if done_rows:
            execute_values(
                cur,
                """
                UPDATE traducciones AS t
                SET titulo_trad = v.titulo_trad,
                    resumen_trad = v.resumen_trad,
                    lang_from = COALESCE(t.lang_from, v.lang_from),
                    status = 'done',
                    error = NULL
                FROM (VALUES %s) AS v(titulo_trad, resumen_trad, lang_from, id)
                WHERE t.id = v.id;
                """,
                done_rows,
            )
        if error_rows:
            execute_values(
                cur,
                """
                UPDATE traducciones AS t
                SET status = 'error',
                    error = v.error
                FROM (VALUES %s) AS v(error, id)
                WHERE t.id = v.id;
                """,
                error_rows,
            )
    conn.commit()
def main():
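    """Run the worker loop: ensure indexes exist, enqueue pending articles per target language, process batches until the queue is drained, and sleep when idle."""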
    LOG.info(
        "Starting translation worker (NLLB). TARGET_LANGS=%s, BATCH=%s, ENQUEUE=%s, DEVICE=%s, "
        "BEAMS(title/body)=%s/%s, CHUNK_BY_SENTENCES=%s, CHUNK_MAX_TOKENS=%s, OVERLAP_SENTS=%s",
        TARGET_LANGS,
        BATCH_SIZE,
        ENQUEUE_MAX,
        DEVICE_CFG,
        NUM_BEAMS_TITLE,
        NUM_BEAMS_BODY,
        CHUNK_BY_SENTENCES,
        CHUNK_MAX_TOKENS,
        CHUNK_OVERLAP_SENTS,
    )
    get_universal_components()
    while True:
        any_work = False
        with get_conn() as conn:
            ensure_indexes(conn)
            for lt in TARGET_LANGS:
                lt = normalize_lang(lt, "es") or "es"
                ensure_pending(conn, lt, ENQUEUE_MAX)
                while True:
                    rows = fetch_pending_batch(conn, lt, BATCH_SIZE)
                    if not rows:
                        break
                    any_work = True
                    LOG.info("[%s] Processing %d items…", lt, len(rows))
                    process_batch(conn, rows)
        # psycopg2's connection context manager ends the transaction but does not
        # close the connection, so close it explicitly to avoid leaking one per loop.
        conn.close()
        if not any_work:
            time.sleep(SLEEP_IDLE)
if __name__ == "__main__":
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    main()