/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.AbstractCompositeLoss;
import ai.djl.training.loss.Loss;
import ai.djl.util.Pair;
import java.util.Arrays;

/**
 * Composite loss for single-shot detection (SSD) models.
 *
 * <p>Combines two components: a softmax cross-entropy "ClassLoss" over per-anchor class
 * predictions, and an L1 "BoundingBoxLoss" over bounding-box predictions. The training targets
 * for both are derived from the anchors, the ground-truth labels and the class predictions via a
 * {@link MultiBoxTarget} built with default settings.
 */
public class SingleShotDetectionLoss extends AbstractCompositeLoss {

    /** Position of the class (softmax cross-entropy) loss in {@code components}. */
    private static final int CLASS_LOSS_INDEX = 0;

    /** Position of the bounding-box (L1) loss in {@code components}. */
    private static final int BOUNDING_BOX_LOSS_INDEX = 1;

    private final MultiBoxTarget multiBoxTarget = MultiBoxTarget.builder().build();

    /**
     * Creates the loss with its two components. The order here must match {@link
     * #CLASS_LOSS_INDEX} and {@link #BOUNDING_BOX_LOSS_INDEX}.
     */
    public SingleShotDetectionLoss() {
        super("SingleShotDetectionLoss");
        this.components =
                Arrays.asList(
                        Loss.softmaxCrossEntropyLoss("ClassLoss"),
                        Loss.l1Loss("BoundingBoxLoss"));
    }

    /**
     * Builds the (labels, predictions) pair fed to one loss component.
     *
     * @param componentIndex {@code 0} for the class loss, {@code 1} for the bounding-box loss
     * @param labels ground-truth labels; only {@code labels.head()} is used
     * @param predictions model outputs, read as {@code [anchors, classPredictions,
     *     boundingBoxPredictions]} by position
     * @return the label list and prediction list for the requested component
     * @throws IllegalArgumentException if {@code componentIndex} is neither 0 nor 1
     */
    @Override
    protected Pair<NDList, NDList> inputForComponent(
            int componentIndex, NDList labels, NDList predictions) {
        // Fail fast: the MultiBoxTarget computation below is comparatively expensive and
        // pointless for an index we are going to reject anyway.
        if (componentIndex != CLASS_LOSS_INDEX && componentIndex != BOUNDING_BOX_LOSS_INDEX) {
            throw new IllegalArgumentException("Invalid component index: " + componentIndex);
        }

        NDArray anchors = predictions.get(0);
        NDArray classPredictions = predictions.get(1);
        // NOTE(review): axes 1 and 2 are swapped before handing the class predictions to
        // MultiBoxTarget; presumably it expects (batch, classes, anchors) while the model emits
        // (batch, anchors, classes) — confirm against the MultiBoxTarget contract.
        // NOTE(review): targets are recomputed on every call, i.e. once per component per step.
        NDList targets =
                multiBoxTarget.target(
                        new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));

        if (componentIndex == CLASS_LOSS_INDEX) {
            NDArray classLabels = targets.get(2);
            return new Pair<>(new NDList(classLabels), new NDList(classPredictions));
        }

        // Bounding-box component: both labels and predictions are multiplied element-wise by the
        // target masks so that masked-out entries contribute equally (presumably zero) to the
        // L1 loss. Assumes the mask is 0/1-valued — TODO confirm with MultiBoxTarget docs.
        NDArray boundingBoxPredictions = predictions.get(2);
        NDArray boundingBoxLabels = targets.get(0);
        NDArray boundingBoxMasks = targets.get(1);
        return new Pair<>(
                new NDList(boundingBoxLabels.mul(boundingBoxMasks)),
                new NDList(boundingBoxPredictions.mul(boundingBoxMasks)));
    }
}

