/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.plugin.script;

import java.io.IOException;
import java.math.BigInteger;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.plugin.script.KNNScoreScript;
import org.opensearch.knn.plugin.script.KNNScoringSpaceUtil;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.script.ScoreScript;
import org.opensearch.search.lookup.SearchLookup;

public interface KNNScoringSpace {
    /**
     * Creates the {@link ScoreScript} used to score documents against this space's query.
     *
     * @param params   script parameters
     * @param field    name of the field the created script reads values from
     * @param lookup   search lookup handed to the script
     * @param ctx      leaf reader context the script is created for
     * @param searcher index searcher handed to the script
     * @return a new score script
     * @throws IOException if the script cannot be created
     */
    ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException;

    /**
     * Scoring space for {@code long} and binary fields, scored as
     * {@code 1 / (1 + KNNScoringUtil.calculateHammingBit(query, value))}.
     */
    public static class HammingBit
    implements KNNScoringSpace {
        /** Parsed query: a {@link Long} for long fields, a {@link BigInteger} for binary fields. */
        Object processedQuery;
        /**
         * Scoring function matching the runtime type of {@link #processedQuery}:
         * {@code BiFunction<Long, Long, Float>} or {@code BiFunction<BigInteger, BigInteger, Float>}.
         */
        BiFunction<?, ?, Float> scoringMethod;

        /**
         * @param query     raw query value, parsed according to the field type
         * @param fieldType mapped type of the target field; must be a long or binary field type
         * @throws IllegalArgumentException if the field type is neither long nor binary
         */
        public HammingBit(Object query, MappedFieldType fieldType) {
            if (KNNScoringSpaceUtil.isLongFieldType(fieldType)) {
                this.processedQuery = KNNScoringSpaceUtil.parseToLong(query);
                // The function must be explicitly typed: an implicitly typed lambda assigned to a
                // wildcard BiFunction gets Object parameters and no calculateHammingBit overload applies.
                BiFunction<Long, Long, Float> longScoring =
                    (q, v) -> 1.0f / (1.0f + KNNScoringUtil.calculateHammingBit(q, v));
                this.scoringMethod = longScoring;
            } else if (KNNScoringSpaceUtil.isBinaryFieldType(fieldType)) {
                this.processedQuery = KNNScoringSpaceUtil.parseToBigInteger(query);
                BiFunction<BigInteger, BigInteger, Float> bigIntegerScoring =
                    (q, v) -> 1.0f / (1.0f + KNNScoringUtil.calculateHammingBit(q, v));
                this.scoringMethod = bigIntegerScoring;
            } else {
                throw new IllegalArgumentException("Incompatible field_type for hammingbit space. The field type must of type long or binary.");
            }
        }

        /**
         * Builds a long- or BigInteger-typed score script, depending on which kind of query was parsed.
         */
        @Override
        @SuppressWarnings("unchecked") // safe: the constructor always pairs the query type with the matching function type
        public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException {
            if (this.processedQuery instanceof Long) {
                return new KNNScoreScript.LongType(
                    params,
                    (Long) this.processedQuery,
                    field,
                    (BiFunction<Long, Long, Float>) this.scoringMethod,
                    lookup,
                    ctx,
                    searcher
                );
            }
            return new KNNScoreScript.BigIntegerType(
                params,
                (BigInteger) this.processedQuery,
                field,
                (BiFunction<BigInteger, BigInteger, Float>) this.scoringMethod,
                lookup,
                ctx,
                searcher
            );
        }
    }

    /**
     * Scoring space {@code "hamming"}; only fields whose vector data type is {@link VectorDataType#BINARY} are accepted.
     */
    public static class Hamming extends KNNFieldSpace {
        private static final Set<VectorDataType> DATA_TYPES_HAMMING = Set.of(VectorDataType.BINARY);

        public Hamming(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "hamming", DATA_TYPES_HAMMING);
        }

        /**
         * Scores as {@code 1 / (1 + calculateHammingBit(q, v))}, after narrowing both float vectors to bytes.
         */
        @Override
        protected BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery) {
            return (queryVector, docVector) -> {
                byte[] queryBytes = toByte(queryVector);
                byte[] docBytes = toByte(docVector);
                return 1.0f / (1.0f + KNNScoringUtil.calculateHammingBit(queryBytes, docBytes));
            };
        }

        /** Narrows each float element to a byte using a primitive {@code (byte)} cast. */
        private static byte[] toByte(float[] vector) {
            byte[] narrowed = new byte[vector.length];
            int position = 0;
            for (float element : vector) {
                narrowed[position++] = (byte) element;
            }
            return narrowed;
        }
    }

    /** Scoring space {@code "innerproduct"}: the negated inner product passed through {@code KNNWeight.normalizeScore}. */
    public static class InnerProd extends KNNFieldSpace {
        public InnerProd(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "innerproduct");
        }

        @Override
        protected BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery) {
            return (queryVector, docVector) -> {
                float normalized = KNNWeight.normalizeScore(-KNNScoringUtil.innerProduct(queryVector, docVector));
                return normalized;
            };
        }
    }

    /** Scoring space {@code "l-inf"}: {@code 1 / (1 + lInfNorm(q, v))}. */
    public static class LInf extends KNNFieldSpace {
        public LInf(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "l-inf");
        }

        @Override
        protected BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery) {
            return (queryVector, docVector) -> {
                float norm = KNNScoringUtil.lInfNorm(queryVector, docVector);
                return 1.0f / (1.0f + norm);
            };
        }
    }

    /** Scoring space {@code "l1"}: {@code 1 / (1 + l1Norm(q, v))}. */
    public static class L1 extends KNNFieldSpace {
        public L1(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "l1");
        }

        @Override
        protected BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery) {
            return (queryVector, docVector) -> {
                float norm = KNNScoringUtil.l1Norm(queryVector, docVector);
                return 1.0f / (1.0f + norm);
            };
        }
    }

    /** Scoring space {@code "cosine"}: {@code 1 + cosinesimilOptimized(q, v, |query|^2)}. */
    public static class CosineSimilarity extends KNNFieldSpace {
        public CosineSimilarity(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "cosine");
        }

        /**
         * Validates the query against the cosine space type, then precomputes its squared magnitude once
         * so the returned function does not recompute it on every call.
         */
        @Override
        protected BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery) {
            SpaceType.COSINESIMIL.validateVector(processedQuery);
            final float queryMagnitudeSquared = KNNScoringSpaceUtil.getVectorMagnitudeSquared(processedQuery);
            return (queryVector, docVector) -> {
                float similarity = KNNScoringUtil.cosinesimilOptimized(queryVector, docVector, queryMagnitudeSquared);
                return 1.0f + similarity;
            };
        }
    }

    /** Scoring space {@code "l2"}: {@code 1 / (1 + l2Squared(q, v))}. */
    public static class L2 extends KNNFieldSpace {
        public L2(Object query, MappedFieldType fieldType) {
            super(query, fieldType, "l2");
        }

        @Override
        public BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery) {
            return (queryVector, docVector) -> {
                float squaredDistance = KNNScoringUtil.l2Squared(queryVector, docVector);
                return 1.0f / (1.0f + squaredDistance);
            };
        }
    }

    /**
     * Base class for scoring spaces over {@code knn_vector} fields.
     *
     * <p>The constructor validates the field type, parses the query into a {@code float[]} and builds the
     * scoring function once; {@link #getScoreScript} hands both to every script it creates.
     */
    public static abstract class KNNFieldSpace
    implements KNNScoringSpace {
        /** Vector data types accepted when a subclass does not pass its own set. */
        public static final Set<VectorDataType> DATA_TYPES_DEFAULT = Set.of(VectorDataType.FLOAT, VectorDataType.BYTE);
        float[] processedQuery;
        BiFunction<float[], float[], Float> scoringMethod;

        /** Creates a space that accepts {@link #DATA_TYPES_DEFAULT}. */
        public KNNFieldSpace(Object query, MappedFieldType fieldType, String spaceName) {
            this(query, fieldType, spaceName, DATA_TYPES_DEFAULT);
        }

        /**
         * @param query                     raw query vector
         * @param fieldType                 mapped type of the target field; must be a knn_vector field
         * @param spaceName                 space name used in error messages
         * @param supportingVectorDataTypes vector data types this space accepts
         * @throws IllegalArgumentException if the field is not a knn_vector field or its data type is not supported
         */
        public KNNFieldSpace(Object query, MappedFieldType fieldType, String spaceName, Set<VectorDataType> supportingVectorDataTypes) {
            KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = toKNNVectorFieldType(fieldType, spaceName, supportingVectorDataTypes);
            // NOTE: both calls below dispatch to overridable methods before subclass constructors run,
            // so overrides must not rely on subclass instance fields.
            this.processedQuery = this.getProcessedQuery(query, knnVectorFieldType);
            this.scoringMethod = this.getScoringMethod(this.processedQuery);
        }

        @Override
        public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup, LeafReaderContext ctx, IndexSearcher searcher) throws IOException {
            return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, ctx, searcher);
        }

        /**
         * Checks that {@code fieldType} is a knn_vector field whose data type is in {@code supportingVectorDataTypes},
         * treating a missing (null) data type as {@link VectorDataType#FLOAT}.
         */
        private static KNNVectorFieldMapper.KNNVectorFieldType toKNNVectorFieldType(MappedFieldType fieldType, String spaceName, Set<VectorDataType> supportingVectorDataTypes) {
            if (!KNNScoringSpaceUtil.isKNNVectorFieldType(fieldType)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Incompatible field_type for %s space. The field type must be knn_vector.", spaceName));
            }
            KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) fieldType;
            VectorDataType declaredType = knnVectorFieldType.getVectorDataType();
            VectorDataType vectorDataType = declaredType == null ? VectorDataType.FLOAT : declaredType;
            if (!supportingVectorDataTypes.contains(vectorDataType)) {
                throw new IllegalArgumentException(String.format(Locale.ROOT, "Incompatible field_type for %s space. The data type should be %s but got %s", spaceName, supportingVectorDataTypes, vectorDataType));
            }
            return knnVectorFieldType;
        }

        /**
         * Parses the raw query into a float vector of the field's expected length.
         *
         * <p>NOTE(review): this passes the field's raw data type, which may be null, whereas
         * {@code toKNNVectorFieldType} validated null as FLOAT — confirm parseToFloatArray treats null like FLOAT.
         */
        protected float[] getProcessedQuery(Object query, KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType) {
            return KNNScoringSpaceUtil.parseToFloatArray(query, KNNVectorFieldMapperUtil.getExpectedVectorLength(knnVectorFieldType), knnVectorFieldType.getVectorDataType());
        }

        /**
         * Builds the function that scores a (query, document) vector pair.
         *
         * @param processedQuery the parsed query vector; may be inspected or validated up front
         * @return the scoring function
         */
        protected abstract BiFunction<float[], float[], Float> getScoringMethod(float[] processedQuery);
    }
}

