Tweaks
This commit is contained in:
parent 86ee083b90
commit e3a99d9604
8 changed files with 489 additions and 483 deletions
@@ -360,7 +360,7 @@ def _token_chunks(tokenizer, text: str, max_tokens: int) -> List[str]:
         return [text]
     chunks = []
     for i in range(0, len(ids), max_tokens):
-        sub = ids[i : i + max_tokens]
+        sub = ids[i: i + max_tokens]
         piece = tokenizer.decode(sub, skip_special_tokens=True, clean_up_tokenization_spaces=True)
         if piece.strip():
             chunks.append(piece.strip())
@@ -413,6 +413,90 @@ def _forced_bos_id(tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, tgt_c
     return getattr(tokenizer, "eos_token_id", None) or getattr(tokenizer, "bos_token_id", None) or 0
 
 
+@torch.inference_mode()
+def _translate_texts_simple(
+    src_lang: str,
+    tgt_lang: str,
+    texts: List[str],
+    num_beams: int = 1,
+    _tries: int = 0,
+) -> List[str]:
+    if not texts:
+        return []
+
+    cleaned = [(t or "").strip() for t in texts]
+    if all(not t for t in cleaned):
+        return ["" for _ in cleaned]
+
+    tok, mdl, device = get_universal_components()
+
+    src_code = map_to_nllb(src_lang) or "eng_Latn"
+    tgt_code = map_to_nllb(tgt_lang) or "spa_Latn"
+
+    try:
+        tok.src_lang = src_code
+    except Exception:
+        pass
+
+    forced_bos = _forced_bos_id(tok, mdl, tgt_code)
+    safe_len = _safe_src_len(tok)
+
+    try:
+        autocast_ctx = (
+            torch.amp.autocast("cuda", dtype=torch.float16)
+            if device.type == "cuda"
+            else contextlib.nullcontext()
+        )
+
+        enc = tok(
+            cleaned,
+            return_tensors="pt",
+            padding=True,
+            truncation=True,
+            max_length=safe_len,
+        )
+        enc = {k: v.to(device) for k, v in enc.items()}
+
+        gen_kwargs = dict(
+            forced_bos_token_id=forced_bos,
+            max_new_tokens=MAX_NEW_TOKENS,
+            num_beams=max(1, int(num_beams)),
+            do_sample=False,
+            use_cache=False,
+        )
+        if int(num_beams) > 1:
+            gen_kwargs["early_stopping"] = True
+
+        with autocast_ctx:
+            generated = mdl.generate(**enc, **gen_kwargs)
+
+        outs = tok.batch_decode(generated, skip_special_tokens=True)
+        outs = [o.strip() for o in outs]
+
+        del enc, generated
+        if device.type == "cuda":
+            _free_cuda()
+
+        return outs
+
+    except Exception as e:
+        if device.type == "cuda" and _is_cuda_mem_error(e) and _tries < 2:
+            LOG.warning("CUDA OOM/allocator (batch): intento de recuperación %d. Detalle: %s", _tries + 1, e)
+            global _MODEL, _DEVICE, _CUDA_DISABLED
+            _CUDA_DISABLED = True
+            try:
+                if _MODEL is not None:
+                    del _MODEL
+            except Exception:
+                pass
+            _free_cuda()
+            _MODEL = None
+            _DEVICE = None
+            time.sleep(1.0)
+            return _translate_texts_simple(src_lang, tgt_lang, texts, num_beams=num_beams, _tries=_tries + 1)
+        raise
+
+
 @torch.inference_mode()
 def translate_text(src_lang: str, tgt_lang: str, text: str, num_beams: int = 1, _tries: int = 0) -> str:
     if not text or not text.strip():
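Note on the recovery path in the hunk above: it leans on module-level state (_MODEL, _DEVICE, _CUDA_DISABLED) and on two helpers, _is_cuda_mem_error and _free_cuda, that this commit does not show. A minimal sketch of what they could look like; the names match the diff, the bodies are assumptions:

    import gc

    import torch

    def _is_cuda_mem_error(e: Exception) -> bool:
        # Assumption: treat OOM and allocator failures alike, since both
        # trigger the same retry-then-fall-back-to-CPU path.
        msg = str(e).lower()
        return "out of memory" in msg or ("cuda" in msg and "alloc" in msg)

    def _free_cuda() -> None:
        # Release cached blocks so the retry (or the CPU reload) starts clean.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()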
@@ -495,7 +579,7 @@ def _pack_sentences_to_token_chunks(
         ids = tokenizer(s, add_special_tokens=False).input_ids
         step = max_tokens
         for i in range(0, len(ids), step):
-            sub = tokenizer.decode(ids[i : i + step], skip_special_tokens=True)
+            sub = tokenizer.decode(ids[i: i + step], skip_special_tokens=True)
             if cur:
                 chunks.append(cur)
                 cur = []
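The changed line above sits in the hard-split fallback for sentences longer than the token budget. The same pattern in isolation, as a runnable sketch (the model name and step size are assumptions, not taken from this commit):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
    text = "One very long sentence that would not fit in a single chunk ..."
    ids = tokenizer(text, add_special_tokens=False).input_ids
    step = 64  # stands in for max_tokens
    pieces = [
        tokenizer.decode(ids[i: i + step], skip_special_tokens=True)
        for i in range(0, len(ids), step)
    ]
    # Each piece decodes back to plain text and re-enters the normal chunk flow.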
@@ -536,6 +620,75 @@ def _smart_concatenate(parts: List[str], tail_window: int = 120) -> str:
     return out
 
 
+def translate_articles_full_batch(
+    src_lang: str,
+    tgt_lang: str,
+    texts: List[str],
+    num_beams: int,
+) -> List[str]:
+    if not texts:
+        return []
+
+    if not CHUNK_BY_SENTENCES:
+        return _translate_texts_simple(src_lang, tgt_lang, texts, num_beams=num_beams)
+
+    tok, _, _ = get_universal_components()
+    safe_len = _safe_src_len(tok)
+    max_chunk_tokens = min(CHUNK_MAX_TOKENS, safe_len)
+
+    all_chunk_texts: List[str] = []
+    per_article_chunk_ids: List[List[int]] = []
+
+    for text in texts:
+        text = (text or "").strip()
+        if not text:
+            per_article_chunk_ids.append([])
+            continue
+
+        sents = split_into_sentences(text)
+        if not sents:
+            per_article_chunk_ids.append([])
+            continue
+
+        chunks_sents = _pack_sentences_to_token_chunks(
+            tok,
+            sents,
+            max_tokens=max_chunk_tokens,
+            overlap_sents=CHUNK_OVERLAP_SENTS,
+        )
+
+        ids_for_this_article: List[int] = []
+        for group in chunks_sents:
+            chunk_text = " ".join(group).strip()
+            if not chunk_text:
+                continue
+            idx = len(all_chunk_texts)
+            all_chunk_texts.append(chunk_text)
+            ids_for_this_article.append(idx)
+
+        per_article_chunk_ids.append(ids_for_this_article)
+
+    if not all_chunk_texts:
+        return ["" for _ in texts]
+
+    translated_chunks = _translate_texts_simple(
+        src_lang,
+        tgt_lang,
+        all_chunk_texts,
+        num_beams=num_beams,
+    )
+
+    outs: List[str] = []
+    for chunk_ids in per_article_chunk_ids:
+        if not chunk_ids:
+            outs.append("")
+            continue
+        parts = [translated_chunks[i] for i in chunk_ids]
+        outs.append(_smart_concatenate([p for p in parts if p]))
+
+    return outs
+
+
 def translate_article_full(
     src_lang: str,
     tgt_lang: str,
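A hypothetical call site for the new batch entry point, to show the positional contract it keeps (the module name "translator" is an assumption):

    from translator import translate_articles_full_batch

    articles = ["First article body ...", "", "Second article body ..."]
    outs = translate_articles_full_batch("en", "es", articles, num_beams=4)
    # outs stays aligned with articles: empty or unsplittable inputs come back as "".
    assert len(outs) == len(articles)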
@@ -570,9 +723,15 @@ def translate_article_full(
 
 
 def process_batch(conn, rows):
+    batch_size = len(rows)
+    LOG.info("Iniciando traducción de batch con %d filas…", batch_size)
+    t0 = time.time()
+
     done_rows = []
     error_rows = []
 
+    enriched_rows = []
+
     for r in rows:
         tr_id = r["tr_id"]
         lang_to = normalize_lang(r["lang_to"], "es") or "es"
@@ -581,23 +740,54 @@ def process_batch(conn, rows):
         title = (r["titulo"] or "").strip()
         body = (r["resumen"] or "").strip()
 
-        if (map_to_nllb(lang_from) or "eng_Latn") == (map_to_nllb(lang_to) or "spa_Latn"):
+        src_code = map_to_nllb(lang_from) or "eng_Latn"
+        tgt_code = map_to_nllb(lang_to) or "spa_Latn"
+
+        if src_code == tgt_code:
             done_rows.append((title, body, lang_from, tr_id))
             continue
 
+        enriched_rows.append(
+            {
+                "tr_id": tr_id,
+                "lang_from": lang_from,
+                "lang_to": lang_to,
+                "title": title,
+                "body": body,
+            }
+        )
+
+    from collections import defaultdict
+
+    groups = defaultdict(list)
+    for er in enriched_rows:
+        key = (er["lang_from"], er["lang_to"])
+        groups[key].append(er)
+
+    for (lang_from, lang_to), items in groups.items():
+        titles = [it["title"] for it in items]
+        bodies = [it["body"] for it in items]
+
         try:
-            title_tr = translate_text(lang_from, lang_to, title, num_beams=NUM_BEAMS_TITLE) if title else ""
-            body_tr = translate_article_full(lang_from, lang_to, body, num_beams=NUM_BEAMS_BODY) if body else ""
+            titles_tr = _translate_texts_simple(lang_from, lang_to, titles, num_beams=NUM_BEAMS_TITLE)
+            bodies_tr = translate_articles_full_batch(lang_from, lang_to, bodies, num_beams=NUM_BEAMS_BODY)
 
-            if _norm(title_tr) == _norm(title):
-                title_tr = ""
-            if _norm(body_tr) == _norm(body):
-                body_tr = ""
+            for it, t_tr, b_tr in zip(items, titles_tr, bodies_tr):
+                title_orig = it["title"]
+                body_orig = it["body"]
+
+                if _norm(t_tr) == _norm(title_orig):
+                    t_tr = ""
+                if _norm(b_tr) == _norm(body_orig):
+                    b_tr = ""
+
+                done_rows.append((t_tr, b_tr, lang_from, it["tr_id"]))
 
-            done_rows.append((title_tr, body_tr, lang_from, tr_id))
         except Exception as e:
-            LOG.exception("Error traduciendo fila")
-            error_rows.append((str(e)[:1500], tr_id))
+            LOG.exception("Error traduciendo lote %s -> %s", lang_from, lang_to)
+            err_msg = str(e)[:1500]
+            for it in items:
+                error_rows.append((err_msg, it["tr_id"]))
 
     with conn.cursor() as cur:
         if done_rows:
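A self-contained toy run of the new grouping step, showing how mixed rows collapse into one batch per (lang_from, lang_to) pair:

    from collections import defaultdict

    rows = [
        {"lang_from": "en", "lang_to": "es", "tr_id": 1},
        {"lang_from": "fr", "lang_to": "es", "tr_id": 2},
        {"lang_from": "en", "lang_to": "es", "tr_id": 3},
    ]
    groups = defaultdict(list)
    for er in rows:
        groups[(er["lang_from"], er["lang_to"])].append(er)
    # ("en", "es") -> tr_ids 1 and 3 in one batch; ("fr", "es") -> tr_id 2.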
@@ -630,6 +820,28 @@ def process_batch(conn, rows):
             )
             conn.commit()
 
+    dt = time.time() - t0
+    try:
+        _, _, device = get_universal_components()
+        dev_label = device.type.upper() if device is not None else "UNK"
+    except Exception:
+        dev_label = "UNK"
+
+    if batch_size > 0:
+        LOG.info(
+            "[%s] Batch de %d filas traducido en %.2f s (%.2f s/noticia)",
+            dev_label,
+            batch_size,
+            dt,
+            dt / batch_size,
+        )
+    else:
+        LOG.info(
+            "[%s] Batch vacío, nada que traducir (%.2f s)",
+            dev_label,
+            dt,
+        )
+
 
 def main():
     LOG.info(