/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang.ArrayUtils;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.training.TrainingDataConsumer;
import org.opensearch.search.SearchHit;

/**
 * {@link TrainingDataConsumer} for float-valued vector fields.
 *
 * <p>Each search hit's vector field is converted to a boxed {@code Float[]}. The batch is then
 * unboxed and transferred through {@link JNIService#transferVectors}, and the address that call
 * returns is stored back on the training-data allocation.
 */
public class FloatTrainingDataConsumer
extends TrainingDataConsumer {
    /**
     * Creates a consumer that writes into the given training-data allocation.
     *
     * @param trainingDataAllocation allocation whose memory address is passed to, and updated from,
     *     {@link JNIService#transferVectors}
     */
    public FloatTrainingDataConsumer(NativeMemoryAllocation.TrainingDataAllocation trainingDataAllocation) {
        super(trainingDataAllocation);
    }

    /**
     * Unboxes a batch of vectors and transfers it via {@link JNIService#transferVectors}, then
     * records the returned address on the allocation.
     *
     * @param floats batch of vectors; every element must be a {@code Float[]}
     * @throws ClassCastException if an element is not a {@code Float[]}
     * @throws NullPointerException if a vector contains a {@code null} component
     */
    @Override
    public void accept(List<?> floats) {
        float[][] primitiveVectors = floats.stream()
                .map(vector -> ArrayUtils.toPrimitive((Float[]) vector))
                .toArray(float[][]::new);
        this.trainingDataAllocation.setMemoryAddress(
                JNIService.transferVectors(this.trainingDataAllocation.getMemoryAddress(), primitiveVectors));
    }

    /**
     * Extracts the vector field from up to {@code vectorsToAdd} hits and forwards them to
     * {@link #accept(List)}.
     *
     * <p>Hits whose extracted value is not a {@link List} are skipped. The running total tracked
     * by the superclass is increased by the number of vectors actually collected.
     *
     * @param searchResponse response whose hits carry the training vectors
     * @param vectorsToAdd maximum number of hits to read; values larger than the number of hits
     *     returned are clamped to the hit count
     * @param fieldName dot-separated path to the vector field, e.g. {@code "a.b.vec"}
     * @throws ClassCastException if a vector list contains a non-{@link Number} element
     */
    @Override
    public void processTrainingVectors(SearchResponse searchResponse, int vectorsToAdd, String fieldName) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        String[] fieldPath = fieldName.split("\\.");
        // Never index past the hits actually present in this response.
        int hitsToProcess = Math.min(vectorsToAdd, hits.length);
        List<Float[]> vectors = new ArrayList<>(Math.max(hitsToProcess, 0));
        for (int i = 0; i < hitsToProcess; ++i) {
            // NOTE(review): extractFieldValue is inherited; presumably it walks fieldPath through
            // the hit's source and may return null when the path is missing — confirm in superclass.
            Object fieldValue = this.extractFieldValue(hits[i], fieldPath);
            if (!(fieldValue instanceof List)) {
                continue;
            }
            vectors.add(toFloatArray((List<?>) fieldValue));
        }
        this.setTotalVectorsCountAdded(this.getTotalVectorsCountAdded() + vectors.size());
        this.accept(vectors);
    }

    /**
     * Converts a list of numbers to a boxed float array, preserving order.
     *
     * @param values list whose elements are expected to be {@link Number}s
     * @return one {@code Float} per element, via {@link Number#floatValue()}
     * @throws ClassCastException if an element is not a {@link Number}
     * @throws NullPointerException if an element is {@code null}
     */
    private static Float[] toFloatArray(List<?> values) {
        Float[] result = new Float[values.size()];
        int index = 0;
        for (Object value : values) {
            result[index++] = ((Number) value).floatValue();
        }
        return result;
    }
}

