package org.encogx.neural.networks.training.propagation.back;

import org.encogx.ml.data.MLDataSet;
import org.encogx.neural.flat.FlatNetwork;
import org.encogx.neural.networks.ContainsFlat;
import org.encogx.neural.networks.training.LearningRate;
import org.encogx.neural.networks.training.Momentum;
import org.encogx.neural.networks.training.TrainingError;
import org.encogx.neural.networks.training.propagation.Propagation;
import org.encogx.neural.networks.training.propagation.TrainingContinuation;
import org.encogx.neural.networks.training.strategy.SmartLearningRate;
import org.encogx.neural.networks.training.strategy.SmartMomentum;
import org.encogx.util.validate.ValidateNetwork;

/* loaded from: input_file:org/encogx/neural/networks/training/propagation/back/Backpropagation.class */
/**
 * Backpropagation trainer with learning rate and momentum.
 *
 * <p>Each weight step computed by {@link #updateWeight(double[], double[], int)} is
 * {@code gradient * learningRate + previousStep * momentum}; the previous step of every weight
 * is kept in {@link #getLastDelta()} and can be saved/restored via {@link #pause()} and
 * {@link #resume(TrainingContinuation)}.
 */
public class Backpropagation extends Propagation implements Momentum, LearningRate {
    /** Key under which the previous weight deltas are stored in a {@link TrainingContinuation}. */
    public static final String LAST_DELTA = "LAST_DELTA";

    private double learningRate;
    private double momentum;

    /** Previous weight step for each weight; sized to the flat network's weight array. */
    private double[] lastDelta;

    /**
     * Creates a trainer with learning rate and momentum both starting at zero, and registers
     * {@link SmartLearningRate} and {@link SmartMomentum} strategies (presumably these tune the
     * two values during training — confirm against the strategy implementations).
     *
     * @param network the network to train
     * @param training the training data
     */
    public Backpropagation(ContainsFlat network, MLDataSet training) {
        // Explicit zero start values. The decompiled source referenced
        // FlatNetwork.NO_BIAS_ACTIVATION here only because its value happens to be 0.0.
        this(network, training, 0.0, 0.0);
        addStrategy(new SmartLearningRate());
        addStrategy(new SmartMomentum());
    }

    /**
     * Creates a trainer with a fixed learning rate and momentum.
     *
     * @param network the network to train
     * @param training the training data; validated against the network
     * @param learnRate the learning rate applied to each gradient
     * @param momentum the fraction of the previous weight step carried into the next one
     */
    public Backpropagation(ContainsFlat network, MLDataSet training, double learnRate, double momentum) {
        super(network, training);
        ValidateNetwork.validateMethodToData(network, training);
        this.momentum = momentum;
        this.learningRate = learnRate;
        this.lastDelta = new double[network.getFlat().getWeights().length];
    }

    /**
     * Reports whether this trainer can continue from a saved state.
     *
     * @return always {@code false} (kept as-is for compatibility, even though
     *     {@link #pause()} / {@link #resume(TrainingContinuation)} are implemented)
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Returns the live array of previous weight steps (not a copy; mutations affect training).
     *
     * @return the previous weight step for each weight
     */
    public double[] getLastDelta() {
        return this.lastDelta;
    }

    /** @return the current learning rate */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the current momentum */
    @Override
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Checks whether a saved state can be resumed by this trainer.
     *
     * <p>The state is valid when it contains a {@code double[]} under {@link #LAST_DELTA} whose
     * length matches the network's weight count, and its training type equals this class's
     * simple name. Malformed states yield {@code false} rather than an exception.
     *
     * @param trainingContinuation the saved state to check
     * @return {@code true} if {@link #resume(TrainingContinuation)} would accept the state
     */
    public boolean isValidResume(TrainingContinuation trainingContinuation) {
        if (!trainingContinuation.getContents().containsKey(LAST_DELTA)) {
            return false;
        }
        // Null-safe: a state whose training type was never set is simply not ours.
        if (!getClass().getSimpleName().equals(trainingContinuation.getTrainingType())) {
            return false;
        }
        Object saved = trainingContinuation.get(LAST_DELTA);
        if (!(saved instanceof double[])) {
            return false;
        }
        int weightCount = ((ContainsFlat) getMethod()).getFlat().getWeights().length;
        return ((double[]) saved).length == weightCount;
    }

    /**
     * Captures the trainer's state so training can later be resumed.
     *
     * @return a continuation holding this class's simple name and a copy of the previous steps
     */
    @Override
    public TrainingContinuation pause() {
        TrainingContinuation state = new TrainingContinuation();
        state.setTrainingType(getClass().getSimpleName());
        // Snapshot: updateWeight writes into lastDelta in place, which would otherwise keep
        // changing the "paused" state after it was taken.
        state.set(LAST_DELTA, this.lastDelta.clone());
        return state;
    }

    /**
     * Restores previous weight steps from a saved state.
     *
     * @param trainingContinuation a state produced by {@link #pause()}
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects the state
     */
    @Override
    public void resume(TrainingContinuation trainingContinuation) {
        if (!isValidResume(trainingContinuation)) {
            throw new TrainingError("Invalid training resume data length");
        }
        // Copy so further training does not mutate the caller's continuation contents.
        this.lastDelta = ((double[]) trainingContinuation.get(LAST_DELTA)).clone();
    }

    /** @param d the new learning rate */
    @Override
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** @param d the new momentum */
    @Override
    public void setMomentum(double d) {
        this.momentum = d;
    }

    /**
     * Computes the step for one weight and records it as that weight's previous step.
     *
     * @param gradients current gradients, indexed by weight
     * @param lastGradient previous gradients (unused by plain backpropagation)
     * @param index the weight index
     * @return {@code gradients[index] * learningRate + lastDelta[index] * momentum}
     */
    @Override
    public double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double delta = (gradients[index] * this.learningRate) + (this.lastDelta[index] * this.momentum);
        this.lastDelta[index] = delta;
        return delta;
    }

    /** No additional initialization is needed for plain backpropagation. */
    @Override
    public void initOthers() {
    }
}
