/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;

/**
 * A block that multiplies its single input by a learned {@code weight} parameter.
 *
 * <p>For an input of shape {@code (N, d1, ..., dk)}, {@link #prepare(Shape[])} gives the weight the
 * shape {@code (units, 1, d1, ..., dk)}, and {@link #getOutputShapes(Shape[])} declares the output
 * shape {@code (units, N, d1, ..., dk)}. The product itself is computed with {@code NDArray.mul},
 * which presumably broadcasts the weight against the input to reach that shape (confirm against the
 * NDArray broadcasting rules).
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class Multiplication extends AbstractBlock {

    /** Encoding version handed to {@link AbstractBlock} and required by {@link #loadMetadata}. */
    private static final byte VERSION = 1;

    /** Leading dimension of both the weight and the output; must be positive. */
    private long units;

    /** Size of the input excluding its first dimension ({@code input.slice(1).size()}). */
    private long inputFeatures;

    /** Shape made of only the input's first dimension, recorded in {@link #beforeInitialize}. */
    private Shape inputShape;

    /** The learned multiplicative weight parameter. */
    private Parameter weight;

    /**
     * Creates the block from a validated builder.
     *
     * @param builder the builder supplying {@code units}
     */
    Multiplication(Builder builder) {
        super(VERSION);
        this.units = builder.units;
        this.weight =
                this.addParameter(
                        Parameter.builder()
                                .setName("weight")
                                .setType(Parameter.Type.WEIGHT)
                                .build());
    }

    /**
     * Multiplies the single input by the weight fetched for the input's device.
     *
     * @param parameterStore store used to obtain the weight value
     * @param inputs a list that must contain exactly one array
     * @param training whether the forward pass is for training
     * @param params unused extra parameters
     * @return a list holding the product
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray weightArr = parameterStore.getValue(this.weight, device, training);
        return this.multiply(input, weightArr);
    }

    /**
     * Returns the output shape, which is the input shape prefixed with {@code units}.
     *
     * @param inputs the input shapes; only the first one is used
     * @return a single shape {@code (units, *inputs[0])}
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        return new Shape[] {new Shape(this.units).addAll(inputs[0])};
    }

    /**
     * Describes the single expected input.
     *
     * @return one entry pairing the input name with the recorded first-dimension shape
     */
    @Override
    public PairList<String, Shape> describeInput() {
        // NOTE(review): the name appears to be copied from the Linear block; it is left unchanged
        // because callers may key on it.
        return new PairList<String, Shape>(
                Collections.singletonList("linearInput"),
                Collections.singletonList(this.inputShape));
    }

    /**
     * Validates that exactly one input is given and records per-sample metadata from it.
     *
     * @param inputShapes the input shapes; exactly one is required
     */
    @Override
    protected void beforeInitialize(Shape... inputShapes) {
        super.beforeInitialize(inputShapes);
        Preconditions.checkArgument(
                inputShapes.length == 1, "Multiplication block only supports 1 input");
        Shape input = inputShapes[0];
        this.inputFeatures = input.slice(1).size();
        this.inputShape = input.slice(0, 1);
    }

    /**
     * Sets the weight shape to {@code (units, 1, *input.slice(1))}. The singleton second dimension
     * lines up with the input's first dimension.
     *
     * @param inputShapes the input shapes; only the first one is used
     */
    @Override
    public void prepare(Shape[] inputShapes) {
        Shape input = inputShapes[0];
        this.weight.setShape(new Shape(this.units, 1L).addAll(input.slice(1)));
    }

    /**
     * Writes {@code units}, {@code inputFeatures} and the encoded {@code inputShape}, in that order.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeLong(this.units);
        os.writeLong(this.inputFeatures);
        os.write(this.inputShape.getEncoded());
    }

    /**
     * Reads the fields written by {@link #saveMetadata}, in the same order.
     *
     * @param loadVersion the encoding version of the stored block
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if {@code loadVersion} differs from {@link #VERSION}
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        if (loadVersion != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        this.units = is.readLong();
        this.inputFeatures = is.readLong();
        this.inputShape = Shape.decode(is);
    }

    /**
     * Multiplies {@code input} by {@code weight} using {@code NDArray.mul}.
     *
     * @param input the input array
     * @param weight the weight array
     * @return a list holding the product
     */
    public NDList multiply(NDArray input, NDArray weight) {
        NDArray resultArr = input.mul(weight);
        return new NDList(resultArr);
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link Multiplication}; {@link #setUnits(long)} must be called with a positive value. */
    public static final class Builder {
        private long units;

        Builder() {}

        /**
         * Sets the number of units, which is the leading dimension of the weight and the output.
         *
         * @param units the number of units; must be positive
         * @return this builder
         */
        public Builder setUnits(long units) {
            this.units = units;
            return this;
        }

        /**
         * Builds the block.
         *
         * <p>Presumably throws {@code IllegalArgumentException} via {@code
         * Preconditions.checkArgument} when {@code units} is not positive.
         *
         * @return the configured {@link Multiplication}
         */
        public Multiplication build() {
            Preconditions.checkArgument(this.units > 0L, "You must specify unit");
            return new Multiplication(this);
        }
    }
}

