/*
 * Decompiled with CFR 0.152.
 */
package schemacrawler.tools.command.chatgpt.embeddings;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.linear.RealVector;
import schemacrawler.tools.command.chatgpt.embeddings.EmbeddedTable;
import schemacrawler.tools.command.chatgpt.embeddings.EmbeddingService;
import schemacrawler.tools.command.chatgpt.embeddings.TableSimilarity;
import schemacrawler.tools.command.chatgpt.embeddings.TextEmbedding;
import us.fatehi.utility.Utility;
import us.fatehi.utility.string.StringFormat;

/**
 * Ranks registered {@link EmbeddedTable}s against a text prompt by cosine similarity of their
 * embedding vectors, and returns the best matches that fit within a token budget.
 *
 * <p>Tables are registered with {@link #addTable(EmbeddedTable)} and held in an internal list; no
 * synchronization is performed on that list.
 */
public final class TableSimilarityService {
    private static final Logger LOGGER = Logger.getLogger(TableSimilarityService.class.getCanonicalName());
    private final EmbeddingService service;
    private final Collection<EmbeddedTable> allTables;

    /**
     * Computes the cosine similarity of two vectors.
     *
     * @param v1 first vector, must not be null
     * @param v2 second vector, must not be null
     * @return {@code v1 . v2 / (|v1| * |v2|)}, or {@code 0.0} if either vector has zero norm
     * @throws NullPointerException if either vector is null
     */
    private static double cosineSimilarity(final RealVector v1, final RealVector v2) {
        Objects.requireNonNull(v1, "No vector provided");
        Objects.requireNonNull(v2, "No vector provided");
        final double normProduct = v1.getNorm() * v2.getNorm();
        if (normProduct == 0.0) {
            // A zero vector has no direction; report "no similarity" instead of producing NaN,
            // which would otherwise make the subsequent sort order unpredictable.
            return 0.0;
        }
        return v1.dotProduct(v2) / normProduct;
    }

    /**
     * Creates a service that embeds prompts with the given embedding service.
     *
     * @param service embedding service used to embed query prompts, must not be null
     * @throws NullPointerException if {@code service} is null
     */
    public TableSimilarityService(final EmbeddingService service) {
        this.service = Objects.requireNonNull(service, "No embedding service provided");
        this.allTables = new ArrayList<>();
    }

    /**
     * Registers a table as a candidate for future queries. Null tables are silently ignored.
     *
     * @param table table to register; may be null
     */
    public void addTable(final EmbeddedTable table) {
        if (table != null) {
            this.allTables.add(table);
        }
    }

    /**
     * Finds the registered tables most relevant to a prompt.
     *
     * <p>The prompt is embedded with the configured {@link EmbeddingService}. Each registered table
     * that has an embedding is scored by cosine similarity. The scores are sorted by the natural
     * ordering of {@link TableSimilarity}, presumably most-similar first (confirm against
     * {@code TableSimilarity.compareTo}). The result is then truncated so the summed table token
     * counts stay strictly below {@code maxTokens}.
     *
     * @param prompt text to match tables against, must not be blank
     * @param maxTokens exclusive upper bound on the summed token count of the returned tables
     * @return a new mutable list of matched tables in ranked order; empty if none fit the budget
     */
    public Collection<EmbeddedTable> query(final String prompt, final long maxTokens) {
        Utility.requireNotBlank(prompt, "No prompt provided");
        final TextEmbedding promptEmbedding = this.service.embed(prompt);

        final List<TableSimilarity> similarities = new ArrayList<>();
        for (final EmbeddedTable embeddedTable : this.allTables) {
            if (!embeddedTable.hasEmbedding()) {
                continue;
            }
            final TextEmbedding tableEmbedding = embeddedTable.getEmbedding();
            final double similarity =
                    cosineSimilarity(
                            promptEmbedding.getEmbeddingVector(), tableEmbedding.getEmbeddingVector());
            similarities.add(new TableSimilarity(embeddedTable, similarity));
        }
        Collections.sort(similarities);

        final List<TableSimilarity> tableSimilarities = this.pruneToMaxTokens(similarities, maxTokens);
        final List<EmbeddedTable> matchedTables = new ArrayList<>(tableSimilarities.size());
        for (final TableSimilarity tableSimilarity : tableSimilarities) {
            matchedTables.add(tableSimilarity.getTable());
        }
        return matchedTables;
    }

    /**
     * Keeps the leading prefix of {@code similarities} whose summed token counts stay strictly
     * below {@code maxTokens}.
     *
     * <p>Null entries and entries whose table has no embedding are skipped and never returned.
     * Accumulation stops at the first entry that would bring the running total to
     * {@code maxTokens} or beyond.
     *
     * @param similarities ranked similarities, must not be null
     * @param maxTokens exclusive token budget
     * @return a new list containing exactly the accepted entries, in their original order
     * @throws NullPointerException if {@code similarities} is null
     */
    private List<TableSimilarity> pruneToMaxTokens(
            final List<TableSimilarity> similarities, final long maxTokens) {
        Objects.requireNonNull(similarities, "No similarities provided");

        // Collect accepted entries explicitly. Deriving a subList from a counter would misalign
        // whenever a null / embedding-less entry is skipped, because the counter then no longer
        // matches list positions.
        final List<TableSimilarity> selected = new ArrayList<>();
        long tokenSum = 0L;
        for (final TableSimilarity tableSimilarity : similarities) {
            if (tableSimilarity == null) {
                continue;
            }
            final TextEmbedding embedding = tableSimilarity.getTable().getEmbedding();
            if (embedding == null) {
                continue;
            }
            final long tokenCount = embedding.getTokenCount();
            // Strict budget: stop before the running total would reach maxTokens.
            if (tokenSum + tokenCount >= maxTokens) {
                break;
            }
            tokenSum += tokenCount;
            selected.add(tableSimilarity);
        }

        if (!selected.isEmpty()) {
            LOGGER.log(
                    Level.CONFIG,
                    (Supplier<String>)
                            new StringFormat(
                                    "Limiting to %d tables, with %d tokens",
                                    new Object[] {selected.size(), tokenSum}));
        }
        return selected;
    }
}

