/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

/**
 * Score normalization technique that divides every score by the L2 norm (Euclidean length) of all
 * scores produced by the same sub-query across every shard result.
 *
 * <p>For sub-query index {@code j}, the norm is {@code sqrt(sum(score^2))} taken over the
 * {@link ScoreDoc}s of the {@code j}-th {@link TopDocs} of every non-null {@link CompoundTopDocs}.
 * Each score is then replaced in place by {@code score / norm}. A norm of zero maps the score to
 * {@code MIN_SCORE} (0).
 */
public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique {
    public static final String TECHNIQUE_NAME = "l2";
    private static final float MIN_SCORE = 0.0f;

    /**
     * Normalizes, in place, the scores of every sub-query result in {@code queryTopDocs} by that
     * sub-query's L2 norm computed across all entries.
     *
     * @param queryTopDocs compound results (one per shard, presumably); {@code null} entries are
     *     skipped
     */
    @Override
    public void normalize(List<CompoundTopDocs> queryTopDocs) {
        List<Float> normsPerSubquery = this.getL2Norm(queryTopDocs);
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); ++j) {
                // Hoisted: the norm is the same for every doc of this sub-query.
                float l2Norm = normsPerSubquery.get(j);
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    scoreDoc.score = this.normalizeSingleScore(scoreDoc.score, l2Norm);
                }
            }
        }
    }

    /**
     * Computes the L2 norm of the scores of every sub-query, aggregated across all entries.
     *
     * @param queryTopDocs compound results; {@code null} entries are skipped
     * @return one norm per sub-query index. The list is empty when no entry contains any
     *     sub-query results.
     */
    private List<Float> getL2Norm(List<CompoundTopDocs> queryTopDocs) {
        // Size by the widest entry instead of an arbitrary one, so an entry that reports more
        // sub-queries cannot index past the accumulator. Defaults to 0 (instead of throwing
        // NoSuchElementException) when every entry is null or has no sub-query results.
        int numOfSubqueries = queryTopDocs.stream()
            .filter(Objects::nonNull)
            .mapToInt(topDocs -> topDocs.getTopDocs().size())
            .max()
            .orElse(0);
        // Accumulate in double: summing squared float scores in float can overflow to Infinity
        // (which would zero every normalized score) and loses precision.
        double[] sumsOfSquares = new double[numOfSubqueries];
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int index = 0; index < topDocsPerSubQuery.size(); ++index) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(index).scoreDocs) {
                    sumsOfSquares[index] += (double) scoreDoc.score * scoreDoc.score;
                }
            }
        }
        List<Float> l2NormList = new ArrayList<>(numOfSubqueries);
        for (double sumOfSquares : sumsOfSquares) {
            l2NormList.add((float) Math.sqrt(sumOfSquares));
        }
        return l2NormList;
    }

    /**
     * Divides a single score by its sub-query's L2 norm.
     *
     * @param score raw score
     * @param l2Norm L2 norm of the sub-query the score belongs to
     * @return {@code score / l2Norm}, or {@code MIN_SCORE} when the norm is zero (avoids
     *     division by zero)
     */
    private float normalizeSingleScore(float score, float l2Norm) {
        return l2Norm == 0.0f ? MIN_SCORE : score / l2Norm;
    }

    @Generated
    @Override
    public String toString() {
        return "L2ScoreNormalizationTechnique(TECHNIQUE_NAME=l2)";
    }
}

