package main import ( "context" "log" "os" "os/signal" "strconv" "syscall" "time" "github.com/jackc/pgx/v5/pgxpool" "github.com/rss2/backend/internal/workers" ) var ( logger *log.Logger dbPool *pgxpool.Pool sleepSec = 10 topK = 10 batchSz = 200 minScore = 0.0 ) func init() { logger = log.New(os.Stdout, "[RELATED] ", log.LstdFlags) } func loadConfig() { sleepSec = getEnvInt("RELATED_SLEEP", 10) topK = getEnvInt("RELATED_TOPK", 10) batchSz = getEnvInt("RELATED_BATCH", 200) minScore = getEnvFloat("RELATED_MIN_SCORE", 0.0) } func getEnvInt(key string, defaultValue int) int { if value := os.Getenv(key); value != "" { if intVal, err := strconv.Atoi(value); err == nil { return intVal } } return defaultValue } func getEnvFloat(key string, defaultValue float64) float64 { if value := os.Getenv(key); value != "" { if floatVal, err := strconv.ParseFloat(value, 64); err == nil { return floatVal } } return defaultValue } type Translation struct { ID int64 Titulo string Resumen string Embedding []float64 } func ensureSchema(ctx context.Context) error { _, err := dbPool.Exec(ctx, ` CREATE TABLE IF NOT EXISTS related_noticias ( traduccion_id INTEGER REFERENCES traducciones(id) ON DELETE CASCADE, related_traduccion_id INTEGER REFERENCES traducciones(id) ON DELETE CASCADE, score FLOAT NOT NULL DEFAULT 0, created_at TIMESTAMP DEFAULT NOW(), PRIMARY KEY (traduccion_id, related_traduccion_id) ); `) if err != nil { return err } // Ensure traduccion_embeddings table exists _, err = dbPool.Exec(ctx, ` CREATE TABLE IF NOT EXISTS traduccion_embeddings ( id SERIAL PRIMARY KEY, traduccion_id INTEGER NOT NULL REFERENCES traducciones(id) ON DELETE CASCADE, model TEXT NOT NULL, dim INTEGER NOT NULL, embedding DOUBLE PRECISION[] NOT NULL, created_at TIMESTAMP DEFAULT NOW(), UNIQUE (traduccion_id, model) ); `) if err != nil { return err } _, err = dbPool.Exec(ctx, ` CREATE INDEX IF NOT EXISTS idx_tr_emb_model ON traduccion_embeddings(model); `) if err != nil { return err } _, err = dbPool.Exec(ctx, ` CREATE INDEX IF NOT EXISTS idx_tr_emb_traduccion_id ON traduccion_embeddings(traduccion_id); `) return err } func fetchAllEmbeddings(ctx context.Context, model string) ([]Translation, error) { rows, err := dbPool.Query(ctx, ` SELECT e.traduccion_id, COALESCE(NULLIF(t.titulo_trad,''), ''), COALESCE(NULLIF(t.resumen_trad,''), ''), e.embedding FROM traduccion_embeddings e JOIN traducciones t ON t.id = e.traduccion_id WHERE e.model = $1 AND t.status = 'done' AND t.lang_to = 'es' `, model) if err != nil { return nil, err } defer rows.Close() var translations []Translation for rows.Next() { var t Translation if err := rows.Scan(&t.ID, &t.Titulo, &t.Resumen, &t.Embedding); err != nil { continue } translations = append(translations, t) } return translations, nil } func fetchPendingIDs(ctx context.Context, model string, limit int) ([]int64, error) { rows, err := dbPool.Query(ctx, ` SELECT t.id FROM traducciones t JOIN traduccion_embeddings e ON e.traduccion_id = t.id AND e.model = $1 LEFT JOIN related_noticias r ON r.traduccion_id = t.id WHERE t.lang_to = 'es' AND t.status = 'done' GROUP BY t.id HAVING COUNT(r.related_traduccion_id) = 0 ORDER BY t.id DESC LIMIT $2 `, model, limit) if err != nil { return nil, err } defer rows.Close() var ids []int64 for rows.Next() { var id int64 if err := rows.Scan(&id); err != nil { continue } ids = append(ids, id) } return ids, nil } func cosineSimilarity(a, b []float64) float64 { if len(a) != len(b) || len(a) == 0 { return 0 } var dotProduct, normA, normB float64 for i := range a { dotProduct += a[i] * b[i] normA += a[i] * a[i] normB += b[i] * b[i] } normA = sqrt(normA) normB = sqrt(normB) if normA == 0 || normB == 0 { return 0 } return dotProduct / (normA * normB) } func sqrt(x float64) float64 { if x <= 0 { return 0 } // Simple Newton-Raphson z := x for i := 0; i < 20; i++ { z = (z + x/z) / 2 } return z } func findTopK(query Embedding, candidates []Translation, k int, minScore float64) []struct { ID int64 Score float64 } { type sim struct { id int64 score float64 } var similarities []sim for _, c := range candidates { if int64(c.ID) == query.ID { continue } score := cosineSimilarity(query.Embedding, c.Embedding) if score <= minScore { continue } similarities = append(similarities, sim{int64(c.ID), score}) } // Sort by score descending for i := 0; i < len(similarities)-1; i++ { for j := i + 1; j < len(similarities); j++ { if similarities[j].score > similarities[i].score { similarities[i], similarities[j] = similarities[j], similarities[i] } } } if len(similarities) > k { similarities = similarities[:k] } result := make([]struct { ID int64 Score float64 }, len(similarities)) for i, s := range similarities { result[i] = struct { ID int64 Score float64 }{s.id, s.score} } return result } type Embedding struct { ID int64 Embedding []float64 } func findEmbeddingByID(embeddings []Embedding, id int64) *Embedding { for i := range embeddings { if embeddings[i].ID == id { return &embeddings[i] } } return nil } func insertRelated(ctx context.Context, traduccionID int64, related []struct { ID int64 Score float64 }) error { if len(related) == 0 { return nil } for _, r := range related { if r.Score <= 0 { continue } _, err := dbPool.Exec(ctx, ` INSERT INTO related_noticias (traduccion_id, related_traduccion_id, score) VALUES ($1, $2, $3) ON CONFLICT (traduccion_id, related_traduccion_id) DO UPDATE SET score = EXCLUDED.score `, traduccionID, r.ID, r.Score) if err != nil { logger.Printf("Error inserting related: %v", err) } } return nil } func processBatch(ctx context.Context, model string) (int, error) { // Fetch all embeddings once allTranslations, err := fetchAllEmbeddings(ctx, model) if err != nil { return 0, err } if len(allTranslations) == 0 { return 0, nil } // Convert to Embedding format for easier lookup var allEmbeddings []Embedding for _, t := range allTranslations { if t.Embedding != nil { allEmbeddings = append(allEmbeddings, Embedding{ID: t.ID, Embedding: t.Embedding}) } } // Get pending IDs pendingIDs, err := fetchPendingIDs(ctx, model, batchSz) if err != nil { return 0, err } if len(pendingIDs) == 0 { return 0, nil } processed := 0 for _, tradID := range pendingIDs { emb := findEmbeddingByID(allEmbeddings, tradID) if emb == nil { continue } topRelated := findTopK(*emb, allTranslations, topK, minScore) if err := insertRelated(ctx, tradID, topRelated); err != nil { logger.Printf("Error inserting related for %d: %v", tradID, err) continue } processed++ } return processed, nil } func main() { loadConfig() logger.Println("Starting Related News Worker") cfg := workers.LoadDBConfig() if err := workers.Connect(cfg); err != nil { logger.Fatalf("Failed to connect to database: %v", err) } dbPool = workers.GetPool() defer workers.Close() ctx := context.Background() // Ensure schema if err := ensureSchema(ctx); err != nil { logger.Printf("Error ensuring schema: %v", err) } model := os.Getenv("EMB_MODEL") if model == "" { model = "mxbai-embed-large" } sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigChan logger.Println("Shutting down...") os.Exit(0) }() logger.Printf("Config: sleep=%ds, topK=%d, batch=%d, model=%s", sleepSec, topK, batchSz, model) for { select { case <-time.After(time.Duration(sleepSec) * time.Second): count, err := processBatch(ctx, model) if err != nil { logger.Printf("Error processing batch: %v", err) continue } if count > 0 { logger.Printf("Generated related news for %d translations", count) } } } }