package org.encogx.neural.networks.training.simple;

import org.encogx.mathutil.error.ErrorCalculation;
import org.encogx.ml.MLMethod;
import org.encogx.ml.TrainingImplementationType;
import org.encogx.ml.data.MLData;
import org.encogx.ml.data.MLDataPair;
import org.encogx.ml.data.MLDataSet;
import org.encogx.ml.train.BasicTraining;
import org.encogx.neural.NeuralNetworkError;
import org.encogx.neural.networks.BasicNetwork;
import org.encogx.neural.networks.training.LearningRate;
import org.encogx.neural.networks.training.propagation.TrainingContinuation;

/* loaded from: input_file:org/encogx/neural/networks/training/simple/TrainAdaline.class */
/**
 * Iterative trainer for an ADALINE network held in a {@link BasicNetwork}.
 *
 * <p>Each call to {@link #iteration()} walks the training set once. For every
 * pair it computes the network output and then, for every output value, adds
 * {@code learningRate * (ideal - actual) * input} to each weight addressed by
 * {@code addWeight(0, inputIndex, outputIndex, ...)}. The weights are changed
 * pair by pair, not accumulated over the whole pass.
 *
 * <p>Pausing and resuming are not supported: {@link #canContinue()} returns
 * {@code false}, {@link #pause()} returns {@code null} and
 * {@link #resume(TrainingContinuation)} does nothing.
 */
public class TrainAdaline extends BasicTraining implements LearningRate {
    /** Network whose weights are adjusted in place by {@link #iteration()}. */
    private final BasicNetwork network;

    /** Data set that is traversed once per call to {@link #iteration()}. */
    private final MLDataSet training;

    /** Factor by which every weight correction is scaled. */
    private double learningRate;

    /**
     * Creates a trainer for the given network and data.
     *
     * @param basicNetwork the network to train; rejected when it reports more
     *                     than two layers
     * @param mLDataSet    the training pairs (input plus ideal output)
     * @param d            the initial learning rate
     * @throws NeuralNetworkError if {@code basicNetwork.getLayerCount()} is
     *                            greater than 2
     */
    public TrainAdaline(BasicNetwork basicNetwork, MLDataSet mLDataSet, double d) {
        super(TrainingImplementationType.Iterative);
        final boolean tooManyLayers = basicNetwork.getLayerCount() > 2;
        if (tooManyLayers) {
            throw new NeuralNetworkError("An ADALINE network only has two layers.");
        }
        this.learningRate = d;
        this.training = mLDataSet;
        this.network = basicNetwork;
    }

    /**
     * Runs one pass over the training set, updating the weights after every
     * pair and finally storing the pass error through {@code setError}.
     *
     * <p>The error for a pair is taken from the output computed before that
     * pair's weight corrections were applied.
     */
    @Override
    public void iteration() {
        final ErrorCalculation passError = new ErrorCalculation();

        for (final MLDataPair sample : this.training) {
            final MLData input = sample.getInput();
            final MLData actual = this.network.compute(input);
            final MLData ideal = sample.getIdeal();

            for (int target = 0; target < actual.size(); target++) {
                // Correction shared by every weight feeding this output value.
                final double step =
                        this.learningRate * (ideal.getData(target) - actual.getData(target));

                // Indices below the input count are paired with the matching
                // input value ...
                final int biasIndex = this.network.getInputCount();
                for (int source = 0; source < biasIndex; source++) {
                    this.network.addWeight(0, source, target, step * input.getData(source));
                }

                // ... and the index equal to the input count is driven by a
                // constant 1.0 (presumably the bias connection -- confirm
                // against BasicNetwork.addWeight).
                this.network.addWeight(0, biasIndex, target, step * 1.0d);
            }

            passError.updateError(actual.getData(), ideal.getData(), sample.getSignificance());
        }

        setError(passError.calculate());
    }

    /**
     * Returns the network being trained.
     *
     * @return the network passed to the constructor
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * Returns the current learning rate.
     *
     * @return the factor applied to each weight correction
     */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Replaces the learning rate used by subsequent iterations.
     *
     * @param d the new learning rate
     */
    @Override
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /**
     * Reports whether training can be paused and later continued.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Pausing is not supported by this trainer.
     *
     * @return always {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported by this trainer; the argument is ignored.
     *
     * @param trainingContinuation ignored
     */
    @Override
    public void resume(TrainingContinuation trainingContinuation) {
        // Intentionally empty: this trainer keeps no state to restore.
    }
}
