/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.math.FeedForwardParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.util.HeapMerger;
import org.tribuo.math.util.Merger;

/**
 * {@link FeedForwardParameters} backed by a single {@link DenseMatrix} of weights,
 * suitable for linear models.
 * <p>
 * The weights are held twice: as a one-element {@link Tensor} array (the form
 * exchanged through {@link #get()}, {@link #set(Tensor[])} and {@link #update(Tensor[])})
 * and as a typed {@link DenseMatrix} reference used by {@link #predict(SGDVector)}.
 * Both are intended to refer to the same matrix object.
 */
public class LinearParameters
implements FeedForwardParameters {
    private static final long serialVersionUID = 1L;

    /** Merger used by {@link #merge(Tensor[][], int)} to combine sparse gradient matrices. */
    private static final Merger MERGER = new HeapMerger();

    /** One-element array whose single entry is expected to be {@link #weightMatrix}. */
    private Tensor[] weights;

    /** Typed view of {@code weights[0]}, used for prediction. */
    private DenseMatrix weightMatrix;

    /**
     * Creates parameters with a freshly allocated weight matrix.
     *
     * @param numFeatures the number of input features (matrix columns).
     * @param numLabels   the number of outputs (matrix rows).
     */
    public LinearParameters(int numFeatures, int numLabels) {
        this.weights = new Tensor[1];
        this.weightMatrix = new DenseMatrix(numLabels, numFeatures);
        this.weights[0] = this.weightMatrix;
    }

    /**
     * Wraps an existing weight matrix. The matrix is not copied, so later updates
     * through this object mutate the supplied matrix.
     *
     * @param weightMatrix the weights to use.
     */
    public LinearParameters(DenseMatrix weightMatrix) {
        this.weightMatrix = weightMatrix;
        this.weights = new Tensor[1];
        this.weights[0] = weightMatrix;
    }

    /**
     * Computes the model output for one example.
     *
     * @param example the input features.
     * @return the result of {@code weightMatrix.leftMultiply(example)}
     *         (presumably one value per weight-matrix row, i.e. per label).
     */
    @Override
    public DenseVector predict(SGDVector example) {
        return this.weightMatrix.leftMultiply(example);
    }

    /**
     * Computes the weight gradient for one example as the outer product of the
     * score vector and the features.
     *
     * @param score    a pair whose second element is the per-output vector
     *                 (assumed to be the loss gradient w.r.t. each output; confirm with the caller).
     * @param features the input features of the example.
     * @return a one-element array holding {@code score.getB().outer(features)}.
     */
    @Override
    public Tensor[] gradients(Pair<Double, SGDVector> score, SGDVector features) {
        return new Tensor[]{score.getB().outer(features)};
    }

    /**
     * Returns a newly allocated matrix with the same dimensions as the weights,
     * for use as a gradient accumulator.
     *
     * @return a one-element array holding the new matrix.
     */
    @Override
    public Tensor[] getEmptyCopy() {
        DenseMatrix matrix = new DenseMatrix(this.weightMatrix.getDimension1Size(), this.weightMatrix.getDimension2Size());
        return new Tensor[]{matrix};
    }

    /**
     * Returns the internal parameter array (not a copy); mutations are visible to this object.
     *
     * @return the one-element weight array.
     */
    @Override
    public Tensor[] get() {
        return this.weights;
    }

    /**
     * Returns the internal weight matrix (not a copy).
     *
     * @return the weight matrix.
     */
    public DenseMatrix getWeightMatrix() {
        return this.weightMatrix;
    }

    /**
     * Replaces the parameters. Arrays whose length differs from the current one are
     * silently ignored.
     *
     * @param newWeights the replacement parameters; element 0 must be a {@link DenseMatrix}.
     * @throws ClassCastException if {@code newWeights[0]} is not a {@link DenseMatrix};
     *                            in that case this object is left unchanged.
     */
    @Override
    public void set(Tensor[] newWeights) {
        if (newWeights.length != this.weights.length) {
            // Historical behaviour: mismatched inputs are a no-op rather than an error.
            return;
        }
        // Cast before assigning any field so a failed cast cannot leave `weights`
        // pointing at the new array while `weightMatrix` still holds the old matrix.
        DenseMatrix newMatrix = (DenseMatrix) newWeights[0];
        this.weights = newWeights;
        this.weightMatrix = newMatrix;
    }

    /**
     * Adds each gradient into the matching weight tensor in place, via
     * {@link Tensor#intersectAndAddInPlace(Tensor)}.
     *
     * @param gradients the gradients to apply, aligned index-wise with {@link #get()}.
     */
    @Override
    public void update(Tensor[] gradients) {
        for (int i = 0; i < gradients.length; ++i) {
            this.weights[i].intersectAndAddInPlace(gradients[i]);
        }
    }

    /**
     * Combines the first {@code size} gradient arrays into a single gradient.
     * <p>
     * Dense gradients are accumulated in place into {@code gradients[0][0]}, which is
     * then returned. Sparse gradients are combined with a {@link HeapMerger}.
     *
     * @param gradients the per-example gradient arrays; only element 0 of each is read.
     * @param size      how many leading entries of {@code gradients} to merge.
     * @return a one-element array holding the merged gradient.
     * @throws IllegalStateException if the gradients are neither {@link DenseMatrix}
     *                               nor {@link DenseSparseMatrix}.
     */
    @Override
    public Tensor[] merge(Tensor[][] gradients, int size) {
        Tensor first = gradients[0][0];
        if (first instanceof DenseMatrix) {
            for (int i = 1; i < size; ++i) {
                first.intersectAndAddInPlace(gradients[i][0]);
            }
            return new Tensor[]{first};
        }
        if (first instanceof DenseSparseMatrix) {
            DenseSparseMatrix[] updates = new DenseSparseMatrix[size];
            for (int i = 0; i < size; ++i) {
                updates[i] = (DenseSparseMatrix) gradients[i][0];
            }
            return new Tensor[]{MERGER.merge(updates)};
        }
        throw new IllegalStateException("Unexpected gradient type, expected DenseMatrix or DenseSparseMatrix, received " + first.getClass().getName());
    }

    /**
     * Creates an independent copy of these parameters by copying the weight matrix.
     *
     * @return a new {@code LinearParameters} wrapping a copy of the weights.
     */
    @Override
    public LinearParameters copy() {
        return new LinearParameters(this.weightMatrix.copy());
    }
}

