/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QueryPhase;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.TopDocsCollectorContext;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

public class HybridQueryPhaseSearcher
extends QueryPhase.DefaultQueryPhaseSearcher {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridQueryPhaseSearcher.class);

    public boolean searchWith(SearchContext searchContext, ContextIndexSearcher searcher, Query query, LinkedList<QueryCollectorContext> collectors, boolean hasFilterCollector, boolean hasTimeout) throws IOException {
        if (query instanceof HybridQuery) {
            return this.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
        }
        return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
    }

    @VisibleForTesting
    protected boolean searchWithCollector(SearchContext searchContext, ContextIndexSearcher searcher, Query query, LinkedList<QueryCollectorContext> collectors, boolean hasFilterCollector, boolean hasTimeout) throws IOException {
        boolean shouldRescore;
        log.debug("searching with custom doc collector, shard {}", (Object)searchContext.shardTarget().getShardId());
        TopDocsCollectorContext topDocsFactory = TopDocsCollectorContext.createTopDocsCollectorContext((SearchContext)searchContext, (boolean)hasFilterCollector);
        collectors.addFirst((QueryCollectorContext)topDocsFactory);
        if (searchContext.size() == 0) {
            TotalHitCountCollector collector = new TotalHitCountCollector();
            searcher.search(query, (Collector)collector);
            return false;
        }
        IndexReader reader = searchContext.searcher().getIndexReader();
        int totalNumDocs = Math.max(0, reader.numDocs());
        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
        boolean bl = shouldRescore = !searchContext.rescore().isEmpty();
        if (shouldRescore) {
            for (RescoreContext rescoreContext : searchContext.rescore()) {
                numDocs = Math.max(numDocs, rescoreContext.getWindowSize());
            }
        }
        QuerySearchResult queryResult = searchContext.queryResult();
        HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(numDocs, new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())));
        searcher.search(query, (Collector)collector);
        if (searchContext.terminateAfter() != 0 && queryResult.terminatedEarly() == null) {
            queryResult.terminatedEarly(false);
        }
        this.setTopDocsInQueryResult(queryResult, collector, searchContext);
        return shouldRescore;
    }

    private void setTopDocsInQueryResult(QuerySearchResult queryResult, HybridTopScoreDocCollector collector, SearchContext searchContext) {
        List<TopDocs> topDocs = collector.topDocs();
        float maxScore = this.getMaxScore(topDocs);
        boolean isSingleShard = searchContext.numberOfShards() == 1;
        TopDocs newTopDocs = this.getNewTopDocs(this.getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
        TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
        queryResult.topDocs(topDocsAndMaxScore, this.getSortValueFormats(searchContext.sort()));
    }

    private TopDocs getNewTopDocs(TotalHits totalHits, List<TopDocs> topDocs) {
        ScoreDoc[] scoreDocs = new ScoreDoc[]{};
        if (Objects.nonNull(topDocs)) {
            int delimiterDocId = topDocs.stream().filter(Objects::nonNull).filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)).map(topDoc -> topDoc.scoreDocs).filter(scoreDoc -> ((ScoreDoc[])scoreDoc).length > 0).map(scoreDoc -> scoreDoc[0].doc).findFirst().orElse(-1);
            if (delimiterDocId == -1) {
                return new TopDocs(totalHits, scoreDocs);
            }
            ArrayList<ScoreDoc> result = new ArrayList<ScoreDoc>();
            result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
            for (TopDocs topDoc2 : topDocs) {
                if (Objects.isNull(topDoc2) || Objects.isNull(topDoc2.scoreDocs)) {
                    result.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(delimiterDocId));
                    continue;
                }
                result.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(delimiterDocId));
                result.addAll(Arrays.asList(topDoc2.scoreDocs));
            }
            result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
            scoreDocs = (ScoreDoc[])result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        }
        return new TopDocs(totalHits, scoreDocs);
    }

    private TotalHits getTotalHits(SearchContext searchContext, List<TopDocs> topDocs, boolean isSingleShard) {
        TotalHits.Relation relation;
        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
        TotalHits.Relation relation2 = relation = trackTotalHitsUpTo == -1 ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0L, relation);
        }
        long maxTotalHits = topDocs.get((int)0).totalHits.value;
        int totalSize = 0;
        for (TopDocs topDoc : topDocs) {
            maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
            if (!isSingleShard) continue;
            totalSize = (int)((long)totalSize + (topDoc.totalHits.value + 1L));
        }
        totalSize += 2;
        if (isSingleShard) {
            searchContext.size(totalSize);
        }
        return new TotalHits(maxTotalHits, relation);
    }

    private float getMaxScore(List<TopDocs> topDocs) {
        if (topDocs.isEmpty()) {
            return 0.0f;
        }
        return topDocs.stream().map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]).map(scoreDoc -> Float.valueOf(scoreDoc.score)).max(Float::compare).get().floatValue();
    }

    private DocValueFormat[] getSortValueFormats(SortAndFormats sortAndFormats) {
        return sortAndFormats == null ? null : sortAndFormats.formats;
    }
}

