package de.tu_dortmund.sfb876.optimplugin.costfunctions;

import org.apache.commons.math3.analysis.function.Exp;
import org.apache.commons.math3.analysis.function.Inverse;
import org.apache.commons.math3.analysis.function.Log;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:de/tu_dortmund/sfb876/optimplugin/costfunctions/LogLoss.class */
/**
 * Logistic-regression cost function (binary cross-entropy, a.k.a. log loss).
 *
 * <p>Throughout, {@code X} is the design matrix with one sample per row ({@code m} rows),
 * {@code y} the label vector and {@code theta} the parameter vector. The model's hypothesis is
 * {@code h = sigmoid(X * theta)}.
 */
public class LogLoss implements CostFunction {

    /** Probability at or above which {@link #predict} assigns class 1. */
    private static final double DECISION_THRESHOLD = 0.5d;

    /**
     * Computes the mean log loss
     * {@code (1/m) * sum_i [ -y_i*log(h_i) - (1-y_i)*log(1-h_i) ]}.
     *
     * <p>The loss is evaluated in the algebraically equivalent form
     * {@code (1/m) * sum_i [ softplus(z_i) - y_i*z_i ]} with {@code z = X*theta}. The direct
     * formula evaluates {@code 0 * log(0) = NaN} as soon as the sigmoid rounds to exactly 1
     * (z above roughly 37) or exactly 0 (z below roughly -710). That would make a confidently
     * correct model report a NaN cost.
     *
     * <p>The result equals the plain sum of per-sample losses. It is non-negative whenever every
     * label lies in [0, 1]. NOTE(review): labels are presumably 0/1; confirm against callers.
     *
     * @param realMatrix design matrix {@code X}, one sample per row
     * @param realVector labels {@code y}, one per row of {@code X}
     * @param realVector2 parameters {@code theta}
     * @return the mean log loss over the {@code m} rows
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction
    public double computeCost(RealMatrix realMatrix, RealVector realVector, RealVector realVector2) {
        RealVector scores = realMatrix.operate(realVector2);
        // ebeMultiply both yields y_i * z_i and rejects a label/row-count mismatch exactly as the
        // original vector arithmetic did.
        RealVector labelTimesScore = realVector.ebeMultiply(scores);
        double total = 0.0d;
        for (int i = 0; i < scores.getDimension(); i++) {
            total += softplus(scores.getEntry(i)) - labelTimesScore.getEntry(i);
        }
        return total / realMatrix.getRowDimension();
    }

    /**
     * Predicts hard 0/1 class labels by thresholding the sigmoid output at 0.5.
     *
     * @param realMatrix design matrix {@code X}, one sample per row
     * @param realVector parameters {@code theta}
     * @return a new vector holding 1.0 where {@code sigmoid(X*theta) >= 0.5}, otherwise 0.0
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction
    public RealVector predict(RealMatrix realMatrix, RealVector realVector) {
        RealVector labels = sigmoid(realMatrix, realVector);
        for (int i = 0; i < labels.getDimension(); i++) {
            // setEntry touches only element i. RealVector.set(double) would overwrite every entry
            // and corrupt the comparisons for all later indices.
            labels.setEntry(i, labels.getEntry(i) >= DECISION_THRESHOLD ? 1.0d : 0.0d);
        }
        return labels;
    }

    /**
     * Element-wise logistic function {@code 1 / (1 + exp(-X*theta))}.
     *
     * @param realMatrix design matrix {@code X}
     * @param realVector parameters {@code theta}
     * @return a new vector of probabilities in [0, 1], one per row of {@code X}
     */
    private RealVector sigmoid(RealMatrix realMatrix, RealVector realVector) {
        return realMatrix.operate(realVector).mapMultiply(-1.0d).map(new Exp()).mapAdd(1.0d).map(new Inverse());
    }

    /**
     * Numerically stable {@code log(1 + exp(x))}.
     *
     * <p>Written as {@code max(x, 0) + log1p(exp(-|x|))} so that {@code exp} never overflows and
     * small tails are not lost to rounding.
     */
    private static double softplus(double x) {
        return Math.max(x, 0.0d) + Math.log1p(Math.exp(-Math.abs(x)));
    }

    /**
     * Computes the gradient of {@link #computeCost} with respect to {@code theta}:
     * {@code X^T * (sigmoid(X*theta) - y) / m}.
     *
     * @param realMatrix design matrix {@code X}, one sample per row
     * @param realVector labels {@code y}
     * @param realVector2 parameters {@code theta}
     * @return the gradient vector, one entry per column of {@code X}
     */
    @Override // de.tu_dortmund.sfb876.optimplugin.costfunctions.CostFunction
    public RealVector getGradient(RealMatrix realMatrix, RealVector realVector, RealVector realVector2) {
        RealVector residual = sigmoid(realMatrix, realVector2).subtract(realVector);
        // preMultiply(v) computes v^T * X == (X^T * v)^T without materialising the transpose.
        return realMatrix.preMultiply(residual).mapDivide(realMatrix.getRowDimension());
    }
}
