package cc.mallet.regression;

import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.OrthantWiseLimitedMemoryBFGS;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.io.File;
import java.text.NumberFormat;
import java.util.Arrays;

/* loaded from: input_file:cc/mallet/regression/LinearRegressionTrainer.class */
/**
 * Fits a {@link LinearRegression} by exposing its parameters to a gradient-based optimizer
 * (see {@link #main}, which uses {@link OrthantWiseLimitedMemoryBFGS}).
 *
 * <p>Parameter layout, as used by this class: one weight per feature index, then the intercept
 * at {@code length - 2}, then a positive scale parameter {@code p} ("precision") at
 * {@code length - 1}. The maximized objective is
 *
 * <pre>
 *   L = (n / 2) * log(p) - (p^2 / 2) * sum_i r_i^2,    r_i = y_i - w . x_i - b
 * </pre>
 *
 * where {@code n} is the number of training instances.
 *
 * <p>NOTE(review): a Gaussian with precision {@code p} would use {@code p} (not {@code p^2}) in
 * the residual term, and one with standard deviation {@code 1/p} would use {@code n log p} (not
 * {@code (n/2) log p}). Value and gradient below are mutually consistent, so the weights and
 * intercept still reach the least-squares optimum, but the fitted {@code p} is not the textbook
 * precision. Confirm the intended parameterization.
 */
public class LinearRegressionTrainer implements Optimizable.ByGradientValue {
    /** Lower bound substituted when a caller tries to set a non-positive (or NaN) precision. */
    private static final double MIN_PRECISION = 0.001d;

    LinearRegression regression;
    // Array obtained from regression.getParameters(); presumably shared with the regression so
    // optimizer updates are visible through it (TODO confirm against LinearRegression).
    double[] parameters;
    InstanceList trainingData;
    // residuals[i] = target_i - (w . x_i) - intercept, valid only while !cachedResidualsStale.
    double[] residuals;
    boolean cachedResidualsStale = true;
    // Not used within this class; kept for compatibility with package-level access.
    NumberFormat formatter;
    int precisionIndex;
    int interceptIndex;

    /**
     * Creates a trainer over {@code instanceList}, whose data are expected to be
     * {@link FeatureVector}s and whose targets {@link Double}s. The precision starts at 1.0.
     *
     * @param instanceList training instances
     */
    public LinearRegressionTrainer(InstanceList instanceList) {
        this.trainingData = instanceList;
        this.regression = new LinearRegression(this.trainingData.getDataAlphabet());
        this.parameters = this.regression.getParameters();
        this.interceptIndex = this.parameters.length - 2;
        this.precisionIndex = this.parameters.length - 1;
        this.residuals = new double[this.trainingData.size()];
        this.parameters[this.precisionIndex] = 1.0d;
        this.formatter = NumberFormat.getInstance();
        this.formatter.setMaximumFractionDigits(3);
    }

    /** Recomputes the residual of every training instance and marks the cache as fresh. */
    private void computeResiduals() {
        final double intercept = parameters[interceptIndex];
        for (int i = 0; i < trainingData.size(); i++) {
            Instance instance = trainingData.get(i);
            FeatureVector predictors = (FeatureVector) instance.getData();
            double residual = ((Double) instance.getTarget()).doubleValue();
            for (int loc = 0; loc < predictors.numLocations(); loc++) {
                residual -= parameters[predictors.indexAtLocation(loc)] * predictors.valueAtLocation(loc);
            }
            residuals[i] = residual - intercept;
        }
        cachedResidualsStale = false;
    }

    /** Recomputes residuals only if a parameter changed since the last computation. */
    private void ensureResiduals() {
        if (cachedResidualsStale) {
            computeResiduals();
        }
    }

    /**
     * Returns the objective {@code (n/2) log p - (p^2/2) sum r_i^2} at the current parameters.
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public double getValue() {
        ensureResiduals();
        final double precision = parameters[precisionIndex];
        final double precisionSquared = precision * precision;
        double value = (residuals.length / 2.0d) * Math.log(precision);
        for (double residual : residuals) {
            value -= (precisionSquared * (residual * residual)) / 2.0d;
        }
        return value;
    }

    /**
     * Writes the gradient of {@link #getValue()} into {@code gradient}. This includes every
     * feature weight, the intercept, and the precision.
     *
     * @param gradient output buffer of length {@link #getNumParameters()}; overwritten
     */
    @Override // cc.mallet.optimize.Optimizable.ByGradientValue
    public void getValueGradient(double[] gradient) {
        ensureResiduals();
        Arrays.fill(gradient, 0.0d);
        final double precision = parameters[precisionIndex];
        final double precisionSquared = precision * precision;
        // d/dp of (n/2) log p
        gradient[precisionIndex] += (0.5d * residuals.length) / precision;
        for (int i = 0; i < residuals.length; i++) {
            final double residual = residuals[i];
            FeatureVector predictors = (FeatureVector) trainingData.get(i).getData();
            // d r_i / d w_j = -x_ij, so d/dw_j of -(p^2/2) r_i^2 is p^2 r_i x_ij. Every feature
            // gets its gradient; the old code only updated feature index 3 (a debug leftover).
            for (int loc = 0; loc < predictors.numLocations(); loc++) {
                gradient[predictors.indexAtLocation(loc)] +=
                        precisionSquared * residual * predictors.valueAtLocation(loc);
            }
            gradient[interceptIndex] += precisionSquared * residual;
            // d/dp of -(p^2/2) r_i^2
            gradient[precisionIndex] -= (precision * residual) * residual;
        }
    }

    @Override // cc.mallet.optimize.Optimizable
    public int getNumParameters() {
        return this.parameters.length;
    }

    @Override // cc.mallet.optimize.Optimizable
    public double getParameter(int i) {
        return this.parameters[i];
    }

    /**
     * Copies the current parameters into {@code dArr}.
     *
     * @param dArr destination of length at least {@link #getNumParameters()}
     */
    @Override // cc.mallet.optimize.Optimizable
    public void getParameters(double[] dArr) {
        System.arraycopy(parameters, 0, dArr, 0, parameters.length);
    }

    /**
     * Returns {@code value}, except that a non-positive or NaN value at the precision index is
     * reported on stderr and replaced by {@link #MIN_PRECISION}. This keeps {@code log(p)} finite.
     */
    private double sanitize(int index, double value) {
        if (index == precisionIndex && !(value > 0.0d)) {
            System.err.println("attempted to set precision at or less than 0");
            return MIN_PRECISION;
        }
        return value;
    }

    @Override // cc.mallet.optimize.Optimizable
    public void setParameter(int i, double d) {
        this.cachedResidualsStale = true;
        this.parameters[i] = sanitize(i, d);
    }

    @Override // cc.mallet.optimize.Optimizable
    public void setParameters(double[] dArr) {
        this.cachedResidualsStale = true;
        for (int i = 0; i < this.parameters.length; i++) {
            this.parameters[i] = sanitize(i, dArr[i]);
        }
    }

    /**
     * Trains on the serialized {@link InstanceList} at {@code strArr[0]}.
     *
     * @param strArr command-line arguments; the first is the instance-list file
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 1) {
            throw new IllegalArgumentException("usage: LinearRegressionTrainer <instance-list-file>");
        }
        OrthantWiseLimitedMemoryBFGS optimizer =
                new OrthantWiseLimitedMemoryBFGS(new LinearRegressionTrainer(InstanceList.load(new File(strArr[0]))));
        // Called twice as in the original; presumably a restart after early termination (TODO confirm).
        optimizer.optimize();
        optimizer.optimize();
    }
}
