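// Qdrant vectorization worker: polls PostgreSQL for finished translations
// that have not been vectorized yet, requests an embedding for each one from
// Ollama, and stores the resulting vectors in a Qdrant collection.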
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"strconv"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/rss2/backend/internal/workers"
|
|
)
|
|
|
|
var (
|
|
logger *log.Logger
|
|
dbPool *pgxpool.Pool
|
|
qdrantURL string
|
|
ollamaURL string
|
|
collection = "news_vectors"
|
|
sleepSec = 30
|
|
batchSize = 100
|
|
)
|
|
|
|
func init() {
|
|
logger = log.New(os.Stdout, "[QDRANT] ", log.LstdFlags)
|
|
}
|
|
|
|
func loadConfig() {
|
|
sleepSec = getEnvInt("QDRANT_SLEEP", 30)
|
|
batchSize = getEnvInt("QDRANT_BATCH", 100)
|
|
qdrantHost := getEnv("QDRANT_HOST", "localhost")
|
|
qdrantPort := getEnvInt("QDRANT_PORT", 6333)
|
|
qdrantURL = fmt.Sprintf("http://%s:%d", qdrantHost, qdrantPort)
|
|
ollamaURL = getEnv("OLLAMA_URL", "http://ollama:11434")
|
|
collection = getEnv("QDRANT_COLLECTION", "news_vectors")
|
|
}
|
|
|
|
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
|
|
|
|
// getEnvInt returns the environment variable key parsed as a base-10 int.
// defaultValue is returned when the variable is unset, empty, or cannot be
// parsed by strconv.Atoi.
func getEnvInt(key string, defaultValue int) int {
	raw := os.Getenv(key)
	if raw == "" {
		return defaultValue
	}

	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}
|
|
|
|
// Translation is one row of the pending-translations query: a translated
// news item (table traducciones) joined with metadata of its source article
// (table noticias). Fecha, CategoriaID and PaisID are pointers and may be
// nil; code that uses them checks for nil first.
type Translation struct {
	// Columns taken from traducciones.
	ID        int64  // id
	NoticiaID int64  // noticia_id
	Lang      string // lang_to, trimmed
	Titulo    string // titulo_trad
	Resumen   string // resumen_trad

	// Columns taken from the joined noticias row.
	URL          string
	Fecha        *time.Time
	FuenteNombre string
	CategoriaID  *int64
	PaisID       *int64
}
|
|
|
|
func getPendingTranslations(ctx context.Context) ([]Translation, error) {
|
|
rows, err := dbPool.Query(ctx, `
|
|
SELECT
|
|
t.id as traduccion_id,
|
|
t.noticia_id,
|
|
TRIM(t.lang_to) as lang,
|
|
t.titulo_trad as titulo,
|
|
t.resumen_trad as resumen,
|
|
n.url,
|
|
n.fecha,
|
|
n.fuente_nombre,
|
|
n.categoria_id,
|
|
n.pais_id
|
|
FROM traducciones t
|
|
INNER JOIN noticias n ON t.noticia_id = n.id
|
|
WHERE t.vectorized = FALSE
|
|
AND t.status = 'done'
|
|
ORDER BY t.created_at ASC
|
|
LIMIT $1
|
|
`, batchSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var translations []Translation
|
|
for rows.Next() {
|
|
var t Translation
|
|
if err := rows.Scan(
|
|
&t.ID, &t.NoticiaID, &t.Lang, &t.Titulo, &t.Resumen,
|
|
&t.URL, &t.Fecha, &t.FuenteNombre, &t.CategoriaID, &t.PaisID,
|
|
); err != nil {
|
|
continue
|
|
}
|
|
translations = append(translations, t)
|
|
}
|
|
return translations, nil
|
|
}
|
|
|
|
// EmbeddingRequest is the JSON body of an embedding request. It serializes
// to an object with the keys "model" and "input".
type EmbeddingRequest struct {
	// Model is the name of the embedding model to run.
	Model string `json:"model"`
	// Input is the text to embed.
	Input string `json:"input"`
}
|
|
|
|
// EmbeddingResponse decodes a JSON object that carries a single vector
// under the key "embedding".
type EmbeddingResponse struct {
	// Embedding holds the vector components in the order they were sent.
	Embedding []float64 `json:"embedding"`
}
|
|
|
|
func generateEmbedding(text string) ([]float64, error) {
|
|
reqBody := EmbeddingRequest{
|
|
Model: "mxbai-embed-large",
|
|
Input: text,
|
|
}
|
|
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
client := &http.Client{Timeout: 60 * time.Second}
|
|
resp, err := client.Post(ollamaURL+"/api/embeddings", "application/json", bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("Ollama returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
var result EmbeddingResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result.Embedding, nil
|
|
}
|
|
|
|
// QdrantPoint is a single point of an upsert request: an identifier, the
// vector itself, and free-form payload stored next to it. It serializes to
// an object with the keys "id", "vector" and "payload".
type QdrantPoint struct {
	// ID is left untyped so any JSON-encodable identifier can be supplied.
	ID any `json:"id"`
	// Vector holds the embedding components.
	Vector []float64 `json:"vector"`
	// Payload is arbitrary metadata attached to the point.
	Payload map[string]any `json:"payload"`
}
|
|
|
|
type QdrantUpsertRequest struct {
|
|
Points []QdrantPoint `json:"points"`
|
|
}
|
|
|
|
func ensureCollection() error {
|
|
req, err := http.NewRequest("GET", qdrantURL+"/collections/"+collection, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == 200 {
|
|
logger.Printf("Collection %s already exists", collection)
|
|
return nil
|
|
}
|
|
|
|
// Get embedding dimension
|
|
emb, err := generateEmbedding("test")
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get embedding dimension: %w", err)
|
|
}
|
|
dimension := len(emb)
|
|
|
|
// Create collection
|
|
createReq := map[string]interface{}{
|
|
"name": collection,
|
|
"vectors": map[string]interface{}{
|
|
"size": dimension,
|
|
"distance": "Cosine",
|
|
},
|
|
}
|
|
|
|
body, _ := json.Marshal(createReq)
|
|
resp2, err := http.Post(qdrantURL+"/collections", "application/json", bytes.NewReader(body))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp2.Body.Close()
|
|
|
|
logger.Printf("Created collection %s with dimension %d", collection, dimension)
|
|
return nil
|
|
}
|
|
|
|
func uploadToQdrant(translations []Translation, embeddings [][]float64) error {
|
|
points := make([]QdrantPoint, 0, len(translations))
|
|
|
|
for i, t := range translations {
|
|
if embeddings[i] == nil {
|
|
continue
|
|
}
|
|
|
|
pointID := uuid.New().String()
|
|
|
|
payload := map[string]interface{}{
|
|
"news_id": t.NoticiaID,
|
|
"traduccion_id": t.ID,
|
|
"titulo": t.Titulo,
|
|
"resumen": t.Resumen,
|
|
"url": t.URL,
|
|
"fuente_nombre": t.FuenteNombre,
|
|
"lang": t.Lang,
|
|
}
|
|
|
|
if t.Fecha != nil {
|
|
payload["fecha"] = t.Fecha.Format(time.RFC3339)
|
|
}
|
|
if t.CategoriaID != nil {
|
|
payload["categoria_id"] = *t.CategoriaID
|
|
}
|
|
if t.PaisID != nil {
|
|
payload["pais_id"] = *t.PaisID
|
|
}
|
|
|
|
points = append(points, QdrantPoint{
|
|
ID: pointID,
|
|
Vector: embeddings[i],
|
|
Payload: payload,
|
|
})
|
|
}
|
|
|
|
if len(points) == 0 {
|
|
return nil
|
|
}
|
|
|
|
reqBody := QdrantUpsertRequest{Points: points}
|
|
body, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
url := fmt.Sprintf("%s/collections/%s/points", qdrantURL, collection)
|
|
resp, err := http.Post(url, "application/json", bytes.NewReader(body))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 && resp.StatusCode != 202 {
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
return fmt.Errorf("Qdrant returned status %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func updateTranslationStatus(ctx context.Context, translations []Translation, pointIDs []string) error {
|
|
for i, t := range translations {
|
|
if i >= len(pointIDs) || pointIDs[i] == "" {
|
|
continue
|
|
}
|
|
|
|
_, err := dbPool.Exec(ctx, `
|
|
UPDATE traducciones
|
|
SET
|
|
vectorized = TRUE,
|
|
vectorization_date = NOW(),
|
|
qdrant_point_id = $1
|
|
WHERE id = $2
|
|
`, pointIDs[i], t.ID)
|
|
|
|
if err != nil {
|
|
logger.Printf("Error updating translation %d: %v", t.ID, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getStats(ctx context.Context) (total, vectorized, pending int, err error) {
|
|
err = dbPool.QueryRow(ctx, `
|
|
SELECT
|
|
COUNT(*) as total,
|
|
COUNT(*) FILTER (WHERE vectorized = TRUE) as vectorized,
|
|
COUNT(*) FILTER (WHERE vectorized = FALSE AND status = 'done') as pending
|
|
FROM traducciones
|
|
WHERE lang_to = 'es'
|
|
`).Scan(&total, &vectorized, &pending)
|
|
|
|
return total, vectorized, pending, err
|
|
}
|
|
|
|
func main() {
|
|
loadConfig()
|
|
logger.Println("Starting Qdrant Vectorization Worker")
|
|
|
|
cfg := workers.LoadDBConfig()
|
|
if err := workers.Connect(cfg); err != nil {
|
|
logger.Fatalf("Failed to connect to database: %v", err)
|
|
}
|
|
dbPool = workers.GetPool()
|
|
defer workers.Close()
|
|
|
|
logger.Println("Connected to PostgreSQL")
|
|
|
|
ctx := context.Background()
|
|
|
|
if err := ensureCollection(); err != nil {
|
|
logger.Printf("Warning: Could not ensure collection: %v", err)
|
|
}
|
|
|
|
sigChan := make(chan os.Signal, 1)
|
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
|
|
go func() {
|
|
<-sigChan
|
|
logger.Println("Shutting down...")
|
|
os.Exit(0)
|
|
}()
|
|
|
|
logger.Printf("Config: qdrant=%s, ollama=%s, collection=%s, sleep=%ds, batch=%d",
|
|
qdrantURL, ollamaURL, collection, sleepSec, batchSize)
|
|
|
|
totalProcessed := 0
|
|
|
|
for {
|
|
select {
|
|
case <-time.After(time.Duration(sleepSec) * time.Second):
|
|
translations, err := getPendingTranslations(ctx)
|
|
if err != nil {
|
|
logger.Printf("Error fetching pending translations: %v", err)
|
|
continue
|
|
}
|
|
|
|
if len(translations) == 0 {
|
|
logger.Println("No pending translations to process")
|
|
continue
|
|
}
|
|
|
|
logger.Printf("Processing %d translations...", len(translations))
|
|
|
|
// Generate embeddings
|
|
embeddings := make([][]float64, len(translations))
|
|
for i, t := range translations {
|
|
text := fmt.Sprintf("%s %s", t.Titulo, t.Resumen)
|
|
emb, err := generateEmbedding(text)
|
|
if err != nil {
|
|
logger.Printf("Error generating embedding for %d: %v", t.ID, err)
|
|
continue
|
|
}
|
|
embeddings[i] = emb
|
|
}
|
|
|
|
// Upload to Qdrant
|
|
if err := uploadToQdrant(translations, embeddings); err != nil {
|
|
logger.Printf("Error uploading to Qdrant: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Update DB status
|
|
pointIDs := make([]string, len(translations))
|
|
for i := range translations {
|
|
pointIDs[i] = uuid.New().String()
|
|
}
|
|
|
|
if err := updateTranslationStatus(ctx, translations, pointIDs); err != nil {
|
|
logger.Printf("Error updating status: %v", err)
|
|
}
|
|
|
|
totalProcessed += len(translations)
|
|
logger.Printf("Processed %d translations (total: %d)", len(translations), totalProcessed)
|
|
|
|
total, vectorized, pending, err := getStats(ctx)
|
|
if err == nil {
|
|
logger.Printf("Stats: total=%d, vectorized=%d, pending=%d", total, vectorized, pending)
|
|
}
|
|
}
|
|
}
|
|
}
|