/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.securityanalytics.correlation.index.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentLocation;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.securityanalytics.correlation.index.mapper.CorrelationVectorFieldMapper;
import org.opensearch.securityanalytics.correlation.index.query.CorrelationQueryFactory;

/**
 * Builder for the {@code correlation} query.
 *
 * <p>The query targets a single correlation vector field. It carries a query vector, a neighbour
 * count {@code k} and an optional filter query. The Lucene query itself is produced by
 * {@link CorrelationQueryFactory}.
 */
public class CorrelationQueryBuilder
extends AbstractQueryBuilder<CorrelationQueryBuilder> {
    private static final Logger log = LogManager.getLogger(CorrelationQueryBuilder.class);
    public static final ParseField VECTOR_FIELD = new ParseField("vector", new String[0]);
    public static final ParseField K_FIELD = new ParseField("k", new String[0]);
    public static final ParseField FILTER_FIELD = new ParseField("filter", new String[0]);
    /** Largest accepted value of {@code k} (inclusive). */
    public static int K_MAX = 10000;
    public static final String NAME = "correlation";
    private final String fieldName;
    private final float[] vector;
    private final int k;
    private final QueryBuilder filter;

    /**
     * Creates an unfiltered correlation query.
     *
     * @param fieldName name of the correlation vector field to search
     * @param vector    query vector; must be non-null and non-empty
     * @param k         number of neighbours; must satisfy {@code 0 < k <= K_MAX}
     * @throws IllegalArgumentException if any argument is invalid
     */
    public CorrelationQueryBuilder(String fieldName, float[] vector, int k) {
        this(fieldName, vector, k, null);
    }

    /**
     * Creates a correlation query, optionally restricted by a filter.
     *
     * @param fieldName name of the correlation vector field to search
     * @param vector    query vector; must be non-null and non-empty
     * @param k         number of neighbours; must satisfy {@code 0 < k <= K_MAX}
     * @param filter    optional filter query; may be {@code null}
     * @throws IllegalArgumentException if any argument is invalid
     */
    public CorrelationQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
        if (Strings.isNullOrEmpty(fieldName)) {
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "[%s] requires fieldName", NAME));
        }
        if (vector == null) {
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "[%s] requires query vector", NAME));
        }
        if (vector.length == 0) {
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "[%s] query vector is empty", NAME));
        }
        if (k <= 0) {
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "[%s] requires k > 0", NAME));
        }
        if (k > K_MAX) {
            // Report both the query name and the actual limit.
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "[%s] requires k <= %d", NAME, K_MAX));
        }
        this.fieldName = fieldName;
        this.vector = vector;
        this.k = k;
        this.filter = filter;
    }

    /**
     * Deserializes a builder written by {@link #doWriteTo(StreamOutput)}.
     *
     * <p>Fields are read in the same order in which {@code doWriteTo} writes them.
     *
     * @param sin the stream to read from
     * @throws IOException if the superclass fails to read its common fields
     * @throws RuntimeException wrapping any {@link IOException} raised while reading this class's fields
     */
    public CorrelationQueryBuilder(StreamInput sin) throws IOException {
        super(sin);
        try {
            this.fieldName = sin.readString();
            this.vector = sin.readFloatArray();
            this.k = sin.readInt();
            this.filter = sin.readOptionalNamedWriteable(QueryBuilder.class);
        }
        catch (IOException ex) {
            throw new RuntimeException("Unable to create CorrelationQueryBuilder", ex);
        }
    }

    /**
     * Converts a parsed JSON list into a float array.
     *
     * @param objs parsed list elements, each expected to be a {@link Number}
     * @return the elements narrowed to {@code float}
     * @throws IllegalArgumentException if an element is null or not numeric
     */
    private static float[] objectsToFloats(List<Object> objs) {
        float[] floats = new float[objs.size()];
        for (int i = 0; i < objs.size(); ++i) {
            Object element = objs.get(i);
            if (!(element instanceof Number)) {
                throw new IllegalArgumentException(String.format(Locale.getDefault(),
                        "[%s] query vector element at index %d is not a number: %s", NAME, i, element));
            }
            floats[i] = ((Number) element).floatValue();
        }
        return floats;
    }

    /**
     * Parses a {@code correlation} query body.
     *
     * <p>The long form is {@code {"<field>": {"vector": [...], "k": N, "filter": {...}, "boost": B, "_name": S}}}.
     * The short form {@code {"<field>": [...]}} supplies only the vector.
     *
     * @param parser parser positioned inside the query object
     * @return the parsed builder
     * @throws IOException on parser failure
     * @throws ParsingException if the body contains unsupported fields or tokens
     * @throws IllegalArgumentException if the parsed values fail constructor validation,
     *         including a missing vector
     */
    public static CorrelationQueryBuilder fromXContent(XContentParser parser) throws IOException {
        XContentParser.Token token;
        String fieldName = null;
        List<Object> vector = null;
        float boost = 1.0f;
        int k = 0;
        QueryBuilder filter = null;
        String queryName = null;
        String currentFieldName = null;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
                continue;
            }
            if (token == XContentParser.Token.START_OBJECT) {
                // Long form: the object is keyed by the target field name.
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
                fieldName = currentFieldName;
                while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                    if (token == XContentParser.Token.FIELD_NAME) {
                        currentFieldName = parser.currentName();
                        continue;
                    }
                    if (token.isValue() || token == XContentParser.Token.START_ARRAY) {
                        if (VECTOR_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            vector = parser.list();
                            continue;
                        }
                        if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            boost = parser.floatValue();
                            continue;
                        }
                        if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
                            continue;
                        }
                        if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            queryName = parser.text();
                            continue;
                        }
                        throw new ParsingException(parser.getTokenLocation(), "[correlation] query does not support [" + currentFieldName + "]", new Object[0]);
                    }
                    if (token == XContentParser.Token.START_OBJECT) {
                        String tokenName = parser.currentName();
                        if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
                            filter = parseInnerQueryBuilder(parser);
                            continue;
                        }
                        throw new ParsingException(parser.getTokenLocation(), "[correlation] unknown token [" + token + "]", new Object[0]);
                    }
                    throw new ParsingException(parser.getTokenLocation(), "[correlation] unknown token [" + token + "] after [" + currentFieldName + "]", new Object[0]);
                }
                continue;
            }
            // Short form: the field name maps directly to the vector.
            throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName());
            fieldName = parser.currentName();
            vector = parser.list();
        }
        // The assert this replaces did nothing in production. A null vector now reaches the
        // constructor, which rejects it with a clear message instead of an NPE.
        float[] queryVector = vector == null ? null : objectsToFloats(vector);
        CorrelationQueryBuilder correlationQueryBuilder = new CorrelationQueryBuilder(fieldName, queryVector, k, filter);
        correlationQueryBuilder.queryName(queryName);
        correlationQueryBuilder.boost(boost);
        return correlationQueryBuilder;
    }

    /** @return the name of the vector field this query targets */
    public String fieldName() {
        return this.fieldName;
    }

    /** @return the query vector, as the internal {@code float[]} (not a copy) */
    public Object vector() {
        return this.vector;
    }

    /** @return the requested number of neighbours */
    public int getK() {
        return this.k;
    }

    /** @return the optional filter query, or {@code null} if none was set */
    public QueryBuilder getFilter() {
        return this.filter;
    }

    /** Writes this class's fields in the order read by {@link #CorrelationQueryBuilder(StreamInput)}. */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeFloatArray(this.vector);
        out.writeInt(this.k);
        out.writeOptionalNamedWriteable((NamedWriteable) this.filter);
    }

    /** Renders the query in the long form accepted by {@link #fromXContent(XContentParser)}. */
    @Override
    public void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(this.fieldName);
        builder.field(VECTOR_FIELD.getPreferredName(), (Object) this.vector);
        builder.field(K_FIELD.getPreferredName(), this.k);
        if (this.filter != null) {
            builder.field(FILTER_FIELD.getPreferredName(), (ToXContent) this.filter);
        }
        this.printBoostAndQueryName(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Builds the Lucene query after validating the target field.
     *
     * @throws IllegalArgumentException if the field is not a correlation vector field, or its
     *         dimension differs from the query vector length
     */
    @Override
    protected Query doToQuery(QueryShardContext context) throws IOException {
        MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);
        if (!(mappedFieldType instanceof CorrelationVectorFieldMapper.CorrelationVectorFieldType)) {
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "Field '%s' is not knn_vector type.", this.fieldName));
        }
        CorrelationVectorFieldMapper.CorrelationVectorFieldType correlationVectorFieldType = (CorrelationVectorFieldMapper.CorrelationVectorFieldType) mappedFieldType;
        int fieldDimension = correlationVectorFieldType.getDimension();
        if (fieldDimension != this.vector.length) {
            throw new IllegalArgumentException(String.format(Locale.getDefault(), "Query vector has invalid dimension: %d. Dimension should be: %d", this.vector.length, fieldDimension));
        }
        String indexName = context.index().getName();
        CorrelationQueryFactory.CreateQueryRequest createQueryRequest = new CorrelationQueryFactory.CreateQueryRequest(indexName, this.fieldName, this.vector, this.k, this.filter, context);
        return CorrelationQueryFactory.create(createQueryRequest);
    }

    /**
     * Compares field name, vector contents, k and filter. The filter is included because it
     * changes the resulting query.
     */
    @Override
    protected boolean doEquals(CorrelationQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName)
                && Arrays.equals(this.vector, other.vector)
                && this.k == other.k
                && Objects.equals(this.filter, other.filter);
    }

    /**
     * Hash consistent with {@link #doEquals}. The vector is hashed by content via
     * {@link Arrays#hashCode(float[])}, not by array identity.
     */
    @Override
    protected int doHashCode() {
        return Objects.hash(this.fieldName, Arrays.hashCode(this.vector), this.k, this.filter);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }
}

