/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.ann;

import java.util.Objects;
import java.util.function.DoubleUnaryOperator;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.function.constant.PrimitiveMath;

/**
 * Mutable bundle of training options: dropouts, error function, learning rate and optional L1/L2
 * regularisation. Instances are compared by value ({@link #equals(Object)} / {@link #hashCode()}).
 */
final class TrainingConfiguration {

    /** Whether dropouts are enabled; affects the two {@code probability...} methods. */
    boolean dropouts = false;
    /** Error function used during training. */
    ArtificialNeuralNetwork.Error error = ArtificialNeuralNetwork.Error.HALF_SQUARED_DIFFERENCE;
    /** Learning rate; defaults to one. */
    double learningRate = PrimitiveMath.ONE;
    /** Whether the L1 term is included by {@link #regularisation()}. */
    boolean regularisationL1 = false;
    /** Magnitude of the L1 term. */
    double regularisationL1Factor = PrimitiveMath.ZERO;
    /** Whether the L2 term is included by {@link #regularisation()}. */
    boolean regularisationL2 = false;
    /** Multiplier of the L2 term. */
    double regularisationL2Factor = PrimitiveMath.ZERO;

    TrainingConfiguration() {
    }

    /**
     * Value equality over all seven configuration fields.
     * <p>
     * {@code Double.compare(a, b) == 0} holds exactly when {@code Double.doubleToLongBits} of both values
     * are equal (NaN equals NaN, 0.0 differs from -0.0), which keeps this consistent with {@link #hashCode()}.
     */
    @Override
    public boolean equals(final Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof TrainingConfiguration)) {
            return false;
        }
        final TrainingConfiguration other = (TrainingConfiguration) obj;
        return dropouts == other.dropouts
                && error == other.error
                && Double.compare(learningRate, other.learningRate) == 0
                && regularisationL1 == other.regularisationL1
                && Double.compare(regularisationL1Factor, other.regularisationL1Factor) == 0
                && regularisationL2 == other.regularisationL2
                && Double.compare(regularisationL2Factor, other.regularisationL2Factor) == 0;
    }

    /**
     * Hash over the same fields {@link #equals(Object)} compares.
     * <p>
     * {@code Objects.hash} folds {@code 31 * result + element.hashCode()} starting from 1, and
     * {@code Boolean.hashCode} (1231/1237) and {@code Double.hashCode} (bits XOR upper bits) match the
     * previous hand-written terms, so the resulting value is unchanged.
     */
    @Override
    public int hashCode() {
        return Objects.hash(dropouts, error, learningRate, regularisationL1, regularisationL1Factor, regularisationL2,
                regularisationL2Factor);
    }

    @Override
    public String toString() {
        return "TrainingConfiguration [dropouts=" + dropouts + ", error=" + error + ", learningRate=" + learningRate
                + ", regularisationL1=" + regularisationL1 + ", regularisationL1Factor=" + regularisationL1Factor
                + ", regularisationL2=" + regularisationL2 + ", regularisationL2Factor=" + regularisationL2Factor + "]";
    }

    /**
     * L1 term: {@link #regularisationL1Factor} carrying the sign of {@code current}.
     * <p>
     * Only strictly negative values give the negated factor. Because {@code current < ZERO} is false for
     * 0.0, -0.0 and NaN, those inputs yield the positive factor.
     */
    private double doL1(final double current) {
        return current < PrimitiveMath.ZERO ? -regularisationL1Factor : regularisationL1Factor;
    }

    /** L2 term: {@link #regularisationL2Factor} times {@code current}. */
    private double doL2(final double current) {
        return regularisationL2Factor * current;
    }

    /**
     * Probability that the input to {@code layer} was kept.
     *
     * @return {@code HALF} when dropouts are enabled and {@code layer != 0}, otherwise {@code ONE}
     */
    double probabilityDidKeepInput(final int layer) {
        return dropouts && layer != 0 ? PrimitiveMath.HALF : PrimitiveMath.ONE;
    }

    /**
     * Probability that the output of {@code layer} will be kept.
     *
     * @param layer layer index
     * @param depth presumably the number of layers (NOTE(review): confirm against callers)
     * @return {@code HALF} when dropouts are enabled and {@code layer < depth - 1}, otherwise {@code ONE}
     */
    double probabilityWillKeepOutput(final int layer, final int depth) {
        return dropouts && layer < depth - 1 ? PrimitiveMath.HALF : PrimitiveMath.ONE;
    }

    /**
     * Builds the regularisation function applied to a value (presumably a weight — confirm with callers).
     * <p>
     * The returned operator calls back into this instance, so it sees the factor values current at
     * invocation time, not at creation time.
     *
     * @return the sum of the L1 and L2 terms when both are enabled, the single enabled term otherwise,
     *         or {@code null} when neither is enabled
     */
    DoubleUnaryOperator regularisation() {
        if (regularisationL1 && regularisationL2) {
            return current -> this.doL1(current) + this.doL2(current);
        }
        if (regularisationL2) {
            return this::doL2;
        }
        if (regularisationL1) {
            return this::doL1;
        }
        return null;
    }
}

