jlimolina 2025-11-24 23:06:26 +01:00
parent 86ee083b90
commit e3a99d9604
8 changed files with 489 additions and 483 deletions

@@ -360,7 +360,7 @@ def _token_chunks(tokenizer, text: str, max_tokens: int) -> List[str]:
        return [text]
    chunks = []
    for i in range(0, len(ids), max_tokens):
        sub = ids[i : i + max_tokens]
        sub = ids[i: i + max_tokens]
        piece = tokenizer.decode(sub, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        if piece.strip():
            chunks.append(piece.strip())
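
For context, `_token_chunks` slices the encoded id sequence into fixed-size windows and decodes each window back to text. A minimal standalone sketch of that windowing (the toy `ids` list stands in for real tokenizer output):

from typing import List

def window_chunks(ids: List[int], max_tokens: int) -> List[List[int]]:
    # Consecutive windows of at most max_tokens ids, mirroring the
    # range(0, len(ids), max_tokens) loop in _token_chunks above.
    return [ids[i : i + max_tokens] for i in range(0, len(ids), max_tokens)]

print(window_chunks(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]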
@@ -413,6 +413,90 @@ def _forced_bos_id(tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, tgt_c
    return getattr(tokenizer, "eos_token_id", None) or getattr(tokenizer, "bos_token_id", None) or 0


@torch.inference_mode()
def _translate_texts_simple(
    src_lang: str,
    tgt_lang: str,
    texts: List[str],
    num_beams: int = 1,
    _tries: int = 0,
) -> List[str]:
    if not texts:
        return []
    cleaned = [(t or "").strip() for t in texts]
    if all(not t for t in cleaned):
        return ["" for _ in cleaned]

    tok, mdl, device = get_universal_components()
    src_code = map_to_nllb(src_lang) or "eng_Latn"
    tgt_code = map_to_nllb(tgt_lang) or "spa_Latn"
    try:
        tok.src_lang = src_code
    except Exception:
        pass
    forced_bos = _forced_bos_id(tok, mdl, tgt_code)
    safe_len = _safe_src_len(tok)

    try:
        autocast_ctx = (
            torch.amp.autocast("cuda", dtype=torch.float16)
            if device.type == "cuda"
            else contextlib.nullcontext()
        )
        enc = tok(
            cleaned,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=safe_len,
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        gen_kwargs = dict(
            forced_bos_token_id=forced_bos,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=max(1, int(num_beams)),
            do_sample=False,
            use_cache=False,
        )
        if int(num_beams) > 1:
            gen_kwargs["early_stopping"] = True
        with autocast_ctx:
            generated = mdl.generate(**enc, **gen_kwargs)
        outs = tok.batch_decode(generated, skip_special_tokens=True)
        outs = [o.strip() for o in outs]
        del enc, generated
        if device.type == "cuda":
            _free_cuda()
        return outs
    except Exception as e:
        if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2:
            LOG.warning("CUDA OOM/allocator (batch): intento de recuperación %d. Detalle: %s", _tries + 1, e)
            global _MODEL, _DEVICE, _CUDA_DISABLED
            _CUDA_DISABLED = True
            try:
                if _MODEL is not None:
                    del _MODEL
            except Exception:
                pass
            _free_cuda()
            _MODEL = None
            _DEVICE = None
            time.sleep(1.0)
            return _translate_texts_simple(src_lang, tgt_lang, texts, num_beams=num_beams, _tries=_tries + 1)
        raise


@torch.inference_mode()
def translate_text(src_lang: str, tgt_lang: str, text: str, num_beams: int = 1, _tries: int = 0) -> str:
    if not text or not text.strip():
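
The new `_translate_texts_simple` recovers from CUDA allocator failures by disabling CUDA, dropping the cached model, and retrying, so the next `get_universal_components()` call reloads it on CPU. A simplified sketch of that retry-with-CPU-fallback pattern (`run_generation` and `looks_like_cuda_oom` are illustrative names, not helpers from this repo):

import gc
import time
from typing import Callable, TypeVar

import torch

T = TypeVar("T")

def looks_like_cuda_oom(exc: Exception) -> bool:
    # Heuristic check, similar in spirit to _is_cuda_mem_error (assumed behaviour).
    return isinstance(exc, torch.cuda.OutOfMemoryError) or "out of memory" in str(exc).lower()

def run_with_cpu_fallback(run_generation: Callable[[str], T], max_tries: int = 2) -> T:
    # run_generation(device) does the actual work; on repeated CUDA OOM we free
    # GPU memory, wait briefly and retry on CPU, much like the except branch above
    # (which additionally deletes and lazily reloads the model).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for attempt in range(max_tries + 1):
        try:
            return run_generation(device)
        except Exception as exc:
            if device == "cuda" and looks_like_cuda_oom(exc) and attempt < max_tries:
                gc.collect()
                torch.cuda.empty_cache()
                time.sleep(1.0)
                device = "cpu"
                continue
            raise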
@@ -495,7 +579,7 @@ def _pack_sentences_to_token_chunks(
        ids = tokenizer(s, add_special_tokens=False).input_ids
        step = max_tokens
        for i in range(0, len(ids), step):
            sub = tokenizer.decode(ids[i : i + step], skip_special_tokens=True)
            sub = tokenizer.decode(ids[i: i + step], skip_special_tokens=True)
            if cur:
                chunks.append(cur)
                cur = []
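
`_pack_sentences_to_token_chunks` greedily packs whole sentences into chunks that stay under a token budget, splitting any single over-long sentence as in the loop above. A compact sketch of the greedy packing only, using whitespace word counts as a stand-in for tokenizer lengths and omitting the overlap and long-sentence handling:

from typing import List

def pack_sentences(sents: List[str], max_tokens: int) -> List[List[str]]:
    # Greedily group whole sentences while the running "token" count
    # (here: whitespace words) stays within max_tokens.
    chunks: List[List[str]] = []
    cur: List[str] = []
    cur_len = 0
    for s in sents:
        n = len(s.split())
        if cur and cur_len + n > max_tokens:
            chunks.append(cur)
            cur, cur_len = [], 0
        cur.append(s)
        cur_len += n
    if cur:
        chunks.append(cur)
    return chunks

print(pack_sentences(["One two three.", "Four five.", "Six seven eight nine."], max_tokens=5))
# [['One two three.', 'Four five.'], ['Six seven eight nine.']]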
@@ -536,6 +620,75 @@ def _smart_concatenate(parts: List[str], tail_window: int = 120) -> str:
    return out


def translate_articles_full_batch(
    src_lang: str,
    tgt_lang: str,
    texts: List[str],
    num_beams: int,
) -> List[str]:
    if not texts:
        return []
    if not CHUNK_BY_SENTENCES:
        return _translate_texts_simple(src_lang, tgt_lang, texts, num_beams=num_beams)

    tok, _, _ = get_universal_components()
    safe_len = _safe_src_len(tok)
    max_chunk_tokens = min(CHUNK_MAX_TOKENS, safe_len)

    all_chunk_texts: List[str] = []
    per_article_chunk_ids: List[List[int]] = []
    for text in texts:
        text = (text or "").strip()
        if not text:
            per_article_chunk_ids.append([])
            continue
        sents = split_into_sentences(text)
        if not sents:
            per_article_chunk_ids.append([])
            continue
        chunks_sents = _pack_sentences_to_token_chunks(
            tok,
            sents,
            max_tokens=max_chunk_tokens,
            overlap_sents=CHUNK_OVERLAP_SENTS,
        )
        ids_for_this_article: List[int] = []
        for group in chunks_sents:
            chunk_text = " ".join(group).strip()
            if not chunk_text:
                continue
            idx = len(all_chunk_texts)
            all_chunk_texts.append(chunk_text)
            ids_for_this_article.append(idx)
        per_article_chunk_ids.append(ids_for_this_article)

    if not all_chunk_texts:
        return ["" for _ in texts]

    translated_chunks = _translate_texts_simple(
        src_lang,
        tgt_lang,
        all_chunk_texts,
        num_beams=num_beams,
    )

    outs: List[str] = []
    for chunk_ids in per_article_chunk_ids:
        if not chunk_ids:
            outs.append("")
            continue
        parts = [translated_chunks[i] for i in chunk_ids]
        outs.append(_smart_concatenate([p for p in parts if p]))
    return outs


def translate_article_full(
    src_lang: str,
    tgt_lang: str,
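
The new `translate_articles_full_batch` flattens every article's chunks into a single batch, remembers which indices belong to which article, translates once, and reassembles per article. A self-contained sketch of that index bookkeeping (the upper-casing lambda stands in for the real batched translation call):

from typing import Callable, List

def translate_grouped(
    per_article_chunks: List[List[str]],
    translate_batch: Callable[[List[str]], List[str]],
) -> List[List[str]]:
    # Flatten chunks, record per-article indices, translate in one call,
    # then map translations back to their articles -- the same bookkeeping
    # as all_chunk_texts / per_article_chunk_ids above.
    flat: List[str] = []
    per_article_ids: List[List[int]] = []
    for chunks in per_article_chunks:
        ids = []
        for chunk in chunks:
            ids.append(len(flat))
            flat.append(chunk)
        per_article_ids.append(ids)
    translated = translate_batch(flat)
    return [[translated[i] for i in ids] for ids in per_article_ids]

print(translate_grouped([["hola", "mundo"], [], ["adios"]], lambda xs: [x.upper() for x in xs]))
# [['HOLA', 'MUNDO'], [], ['ADIOS']]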
@@ -570,9 +723,15 @@ def translate_article_full(
def process_batch(conn, rows):
    batch_size = len(rows)
    LOG.info("Iniciando traducción de batch con %d filas…", batch_size)
    t0 = time.time()

    done_rows = []
    error_rows = []
    enriched_rows = []

    for r in rows:
        tr_id = r["tr_id"]
        lang_to = normalize_lang(r["lang_to"], "es") or "es"
@@ -581,23 +740,54 @@ def process_batch(conn, rows):
        title = (r["titulo"] or "").strip()
        body = (r["resumen"] or "").strip()

        if (map_to_nllb(lang_from) or "eng_Latn") == (map_to_nllb(lang_to) or "spa_Latn"):
        src_code = map_to_nllb(lang_from) or "eng_Latn"
        tgt_code = map_to_nllb(lang_to) or "spa_Latn"
        if src_code == tgt_code:
            done_rows.append((title, body, lang_from, tr_id))
            continue

        enriched_rows.append(
            {
                "tr_id": tr_id,
                "lang_from": lang_from,
                "lang_to": lang_to,
                "title": title,
                "body": body,
            }
        )

    from collections import defaultdict

    groups = defaultdict(list)
    for er in enriched_rows:
        key = (er["lang_from"], er["lang_to"])
        groups[key].append(er)

    for (lang_from, lang_to), items in groups.items():
        titles = [it["title"] for it in items]
        bodies = [it["body"] for it in items]
        try:
            title_tr = translate_text(lang_from, lang_to, title, num_beams=NUM_BEAMS_TITLE) if title else ""
            body_tr = translate_article_full(lang_from, lang_to, body, num_beams=NUM_BEAMS_BODY) if body else ""
            titles_tr = _translate_texts_simple(lang_from, lang_to, titles, num_beams=NUM_BEAMS_TITLE)
            bodies_tr = translate_articles_full_batch(lang_from, lang_to, bodies, num_beams=NUM_BEAMS_BODY)

            if _norm(title_tr) == _norm(title):
                title_tr = ""
            if _norm(body_tr) == _norm(body):
                body_tr = ""
            for it, t_tr, b_tr in zip(items, titles_tr, bodies_tr):
                title_orig = it["title"]
                body_orig = it["body"]
                if _norm(t_tr) == _norm(title_orig):
                    t_tr = ""
                if _norm(b_tr) == _norm(body_orig):
                    b_tr = ""
                done_rows.append((t_tr, b_tr, lang_from, it["tr_id"]))

            done_rows.append((title_tr, body_tr, lang_from, tr_id))
        except Exception as e:
            LOG.exception("Error traduciendo fila")
            error_rows.append((str(e)[:1500], tr_id))
            LOG.exception("Error traduciendo lote %s -> %s", lang_from, lang_to)
            err_msg = str(e)[:1500]
            for it in items:
                error_rows.append((err_msg, it["tr_id"]))

    with conn.cursor() as cur:
        if done_rows:
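
`process_batch` now collects the rows that actually need translation into `enriched_rows` and groups them by `(lang_from, lang_to)`, so each language pair is sent to the model as one batched call. A toy run of that grouping step, with made-up row data:

from collections import defaultdict

rows = [
    {"tr_id": 1, "lang_from": "en", "lang_to": "es"},
    {"tr_id": 2, "lang_from": "fr", "lang_to": "es"},
    {"tr_id": 3, "lang_from": "en", "lang_to": "es"},
]

groups = defaultdict(list)
for row in rows:
    groups[(row["lang_from"], row["lang_to"])].append(row)

for (lang_from, lang_to), items in groups.items():
    # One batched translation call per language pair, as in the loop above.
    print(lang_from, "->", lang_to, [it["tr_id"] for it in items])
# en -> es [1, 3]
# fr -> es [2]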
@@ -630,6 +820,28 @@ def process_batch(conn, rows):
            )
    conn.commit()

    dt = time.time() - t0
    try:
        _, _, device = get_universal_components()
        dev_label = device.type.upper() if device is not None else "UNK"
    except Exception:
        dev_label = "UNK"
    if batch_size > 0:
        LOG.info(
            "[%s] Batch de %d filas traducido en %.2f s (%.2f s/noticia)",
            dev_label,
            batch_size,
            dt,
            dt / batch_size,
        )
    else:
        LOG.info(
            "[%s] Batch vacío, nada que traducir (%.2f s)",
            dev_label,
            dt,
        )


def main():
    LOG.info(