/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.cv;

import ai.djl.ndarray.NDList;

/**
 * Configuration holder for a multi-box detection step that is executed by the engine.
 *
 * <p>Instances are immutable once built. All settings are captured from a {@link Builder} and
 * forwarded unchanged to the engine's {@code multiBoxDetection} operator in {@link
 * #detection(NDList)}.
 */
public class MultiBoxDetection {
    private final boolean clip;
    private final float threshold;
    private final int backgroundId;
    private final float nmsThreshold;
    private final boolean forceSuppress;
    private final int nmsTopK;

    /**
     * Creates a new instance from the settings held by the given builder.
     *
     * @param builder the builder whose current settings are copied into this instance
     */
    public MultiBoxDetection(Builder builder) {
        this.clip = builder.clip;
        this.threshold = builder.threshold;
        this.backgroundId = builder.backgroundId;
        this.nmsThreshold = builder.nmsThreshold;
        this.forceSuppress = builder.forceSuppress;
        this.nmsTopK = builder.nmsTopK;
    }

    /**
     * Runs multi-box detection on the given inputs using this instance's settings.
     *
     * <p>The computation is delegated to the {@code multiBoxDetection} operator of the internal
     * array of the first element of {@code inputs}. The whole list and every configured setting
     * are passed through unchanged.
     *
     * @param inputs a list of exactly three arrays. Per the validation message these are class
     *     probabilities, box predictions, and anchors. NOTE(review): the order is presumably the
     *     one listed; confirm against the engine operator.
     * @return the list produced by the engine operator
     * @throws IllegalArgumentException if {@code inputs} is {@code null} or does not contain
     *     exactly three arrays
     */
    public NDList detection(NDList inputs) {
        if (inputs == null || inputs.size() != 3) {
            throw new IllegalArgumentException("NDList must contain class probabilities, box predictions, and anchors");
        }
        return inputs.head().getNDArrayInternal().multiBoxDetection(inputs, this.clip, this.threshold, this.backgroundId, this.nmsThreshold, this.forceSuppress, this.nmsTopK);
    }

    /**
     * Creates a builder initialized with the default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Fluent builder for {@link MultiBoxDetection}. */
    public static final class Builder {
        // The enclosing class reads these directly, which Java permits for private
        // members of a nested class, so none of them needs wider visibility.
        private boolean clip = true;
        private float threshold = 0.01f;
        private int backgroundId;
        private float nmsThreshold = 0.5f;
        private boolean forceSuppress;
        private int nmsTopK = -1;

        Builder() {
        }

        /**
         * Sets the {@code clip} flag passed to the detection operator (default {@code true}).
         * NOTE(review): presumably controls clipping of out-of-bounds boxes; confirm against
         * the engine operator.
         *
         * @param clip the flag value
         * @return this builder
         */
        public Builder optClip(boolean clip) {
            this.clip = clip;
            return this;
        }

        /**
         * Sets the {@code forceSuppress} flag passed to the detection operator (default
         * {@code false}). NOTE(review): presumably makes suppression ignore class labels;
         * confirm against the engine operator.
         *
         * @param forceSuppress the flag value
         * @return this builder
         */
        public Builder optForceSuppress(boolean forceSuppress) {
            this.forceSuppress = forceSuppress;
            return this;
        }

        /**
         * Sets the background class id passed to the detection operator (default {@code 0}).
         *
         * @param backgroundId the background class id
         * @return this builder
         */
        public Builder optBackgroundId(int backgroundId) {
            this.backgroundId = backgroundId;
            return this;
        }

        /**
         * Sets the {@code nmsTopK} value passed to the detection operator (default
         * {@code -1}). NOTE(review): {@code -1} presumably means "no limit"; confirm against
         * the engine operator.
         *
         * @param nmsTopK the top-k value
         * @return this builder
         */
        public Builder optNmsTopK(int nmsTopK) {
            this.nmsTopK = nmsTopK;
            return this;
        }

        /**
         * Sets the score threshold passed to the detection operator (default {@code 0.01}).
         *
         * @param threshold the score threshold
         * @return this builder
         */
        public Builder optThreshold(float threshold) {
            this.threshold = threshold;
            return this;
        }

        /**
         * Sets the non-maximum-suppression threshold passed to the detection operator (default
         * {@code 0.5}).
         *
         * @param nmsThreshold the NMS threshold
         * @return this builder
         */
        public Builder optNmsThreshold(float nmsThreshold) {
            this.nmsThreshold = nmsThreshold;
            return this;
        }

        /**
         * Builds an immutable {@link MultiBoxDetection} from the current settings.
         *
         * @return a new {@link MultiBoxDetection}
         */
        public MultiBoxDetection build() {
            return new MultiBoxDetection(this);
        }
    }
}

