package org.encogx.ml.fitting.linear;

import org.encogx.EncogError;
import org.encogx.ml.MLMethod;
import org.encogx.ml.TrainingImplementationType;
import org.encogx.ml.data.MLDataPair;
import org.encogx.ml.data.MLDataSet;
import org.encogx.ml.train.BasicTraining;
import org.encogx.neural.networks.training.propagation.TrainingContinuation;
import org.encogx.util.simple.EncogUtility;

/* loaded from: input_file:org/encogx/ml/fitting/linear/TrainLinearRegression.class */
/**
 * Trains a {@link LinearRegression} by fitting the straight line {@code y = w0 + w1 * x} to a
 * training set with ordinary least squares.
 *
 * <p>The fit is computed in closed form during a single pass over the data. Only the first input
 * value and the first ideal value of each record are used. The intercept is written to
 * {@code getWeights()[0]} and the slope to {@code getWeights()[1]} of the trained model.
 */
public class TrainLinearRegression extends BasicTraining {

    /** The model whose weights are overwritten by {@link #iteration()}. */
    private final LinearRegression method;

    /** The records the line is fitted to. */
    private final MLDataSet training;

    /**
     * Creates a trainer for the given model and data.
     *
     * @param linearRegression the model to train; its first two weights are overwritten by
     *     {@link #iteration()}
     * @param mLDataSet the training data; only the first input and first ideal value of each
     *     record are read
     */
    public TrainLinearRegression(LinearRegression linearRegression, MLDataSet mLDataSet) {
        // A single-input model is reported as one-pass; any other input count as iterative.
        super(linearRegression.getInputCount() == 1
                ? TrainingImplementationType.OnePass
                : TrainingImplementationType.Iterative);
        this.method = linearRegression;
        this.training = mLDataSet;
    }

    /**
     * Returns the data set this trainer fits the model to.
     *
     * @return the training data supplied to the constructor
     */
    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    /**
     * Fits the least-squares line to the training data, stores intercept and slope in the model's
     * weights, and records the resulting regression error via {@code setError}.
     *
     * <p>Means and the sums of squared/cross deviations are accumulated with Welford's online
     * update rather than the textbook {@code (n*Σxy - Σx*Σy) / (n*Σx² - (Σx)²)} form. Both give
     * the same line in exact arithmetic. The textbook form subtracts two large, nearly equal
     * numbers, and so loses precision when the inputs are large relative to their spread.
     *
     * <p>Only {@code getInputArray()[0]} and {@code getIdealArray()[0]} of each record are read.
     * NOTE(review): if the model has more than one input, the extra inputs are ignored.
     *
     * @throws EncogError if the training set is empty, or if every record has the same input
     *     value (including the single-record case), because the slope is then undefined
     */
    @Override
    public void iteration() {
        long count = 0;
        double meanX = 0.0;
        double meanY = 0.0;
        double sumSqDevX = 0.0;  // running Σ (x - meanX)^2
        double sumCoDevXY = 0.0; // running Σ (x - meanX) * (y - meanY)

        for (MLDataPair pair : this.training) {
            double x = pair.getInputArray()[0];
            double y = pair.getIdealArray()[0];
            count++;
            double dx = x - meanX; // deviation from the mean of the previous records
            meanX += dx / count;
            meanY += (y - meanY) / count;
            // Welford update: old x-deviation times the deviation from the updated mean.
            sumSqDevX += dx * (x - meanX);
            sumCoDevXY += dx * (y - meanY);
        }

        if (count == 0) {
            throw new EncogError("Cannot fit a linear regression to an empty training set.");
        }
        if (sumSqDevX == 0.0) {
            throw new EncogError("Cannot fit a linear regression: all " + count
                    + " training records have the same input value, so the slope is undefined.");
        }

        double slope = sumCoDevXY / sumSqDevX;
        // Writes go straight into the model's weight array, as in the original implementation.
        double[] weights = this.method.getWeights();
        // The least-squares line passes through the point (meanX, meanY).
        weights[0] = meanY - slope * meanX;
        weights[1] = slope;

        setError(EncogUtility.calculateRegressionError(this.method, this.training));
    }

    /**
     * Reports whether training can be paused and resumed.
     *
     * @return always {@code false}; the fit is computed in closed form
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Pausing is not supported by this trainer.
     *
     * @return always {@code null}, since there is no continuation state
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported by this trainer.
     *
     * @param trainingContinuation ignored
     * @throws EncogError always
     */
    @Override
    public void resume(TrainingContinuation trainingContinuation) {
        throw new EncogError("Not supported");
    }

    /**
     * Returns the model being trained.
     *
     * @return the {@link LinearRegression} supplied to the constructor
     */
    @Override
    public MLMethod getMethod() {
        return this.method;
    }
}
