package org.encogx.neural.networks.training.propagation.quick;

import org.encogx.EncogError;
import org.encogx.ml.data.MLDataSet;
import org.encogx.neural.flat.FlatNetwork;
import org.encogx.neural.networks.ContainsFlat;
import org.encogx.neural.networks.training.LearningRate;
import org.encogx.neural.networks.training.TrainingError;
import org.encogx.neural.networks.training.propagation.Propagation;
import org.encogx.neural.networks.training.propagation.TrainingContinuation;
import org.encogx.util.EngineArray;
import org.encogx.util.validate.ValidateNetwork;

/* loaded from: input_file:org/encogx/neural/networks/training/propagation/quick/QuickPropagation.class */
/**
 * Trains a {@link ContainsFlat} network with the Quickprop weight-update rule.
 * <p>
 * For every weight, {@link #updateWeight(double[], double[], int)} combines the
 * current slope, the previous slope and the previous step:
 * <ul>
 * <li>if the previous step was zero, only the linear term {@code -eps * slope} is used;</li>
 * <li>otherwise the linear term is added when the slope's sign is opposite to the
 * previous step's sign, and then either the capped step
 * {@code learningRate * previousStep} or the quadratic estimate
 * {@code previousStep * slope / (previousSlope - slope)} is added, chosen by
 * comparing the slope with {@code shrink * previousSlope}.</li>
 * </ul>
 * Only full-batch training is supported; see {@link #setBatchSize(int)}.
 */
public class QuickPropagation extends Propagation implements LearningRate {

    /** Key under which {@link #pause()} stores the last-gradient array. */
    public static final String LAST_GRADIENTS = "LAST_GRADIENTS";

    /** Maximum growth factor of a step relative to the previous step. */
    private double learningRate;

    /** Previous step applied to each weight, indexed like the flat weight array. */
    private double[] lastDelta;

    /** Weight-decay coefficient; {@code decay * weight} is added to every slope. */
    private double decay = 1.0E-4d;

    /** Linear-term coefficient, derived in {@link #initOthers()}. */
    private double eps;

    /** Linear-term coefficient before division by the training-set record count. */
    private double outputEpsilon = 0.35d;

    /** Threshold ratio selecting between the capped step and the quadratic estimate. */
    private double shrink;

    /**
     * Creates a trainer with the default learning rate of 2.0.
     *
     * @param method   the network to train
     * @param training the training data
     */
    public QuickPropagation(final ContainsFlat method, final MLDataSet training) {
        this(method, training, 2.0d);
    }

    /**
     * Creates a trainer.
     *
     * @param method    the network to train
     * @param training  the training data; checked against the network with
     *                  {@code ValidateNetwork.validateMethodToData}
     * @param learnRate the maximum step growth factor
     */
    public QuickPropagation(final ContainsFlat method, final MLDataSet training, final double learnRate) {
        super(method, training);
        ValidateNetwork.validateMethodToData(method, training);
        this.learningRate = learnRate;
        this.lastDelta = new double[this.network.getFlat().getWeights().length];
    }

    /**
     * Reports whether training can be continued; this trainer always answers no.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Returns the previous step of every weight.
     *
     * @return the live internal array (not a copy)
     */
    public double[] getLastDelta() {
        return this.lastDelta;
    }

    /** @return the maximum step growth factor */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Checks whether a continuation can be resumed by this trainer: it must hold
     * a {@code double[]} under {@link #LAST_GRADIENTS} whose length equals the
     * network's weight count, and its training type must be this class's simple name.
     * Malformed continuations (missing type, missing or non-array data) yield
     * {@code false} instead of an exception.
     *
     * @param state the continuation to check
     * @return {@code true} if {@link #resume(TrainingContinuation)} will accept it
     */
    public boolean isValidResume(final TrainingContinuation state) {
        if (!state.getContents().containsKey(LAST_GRADIENTS)) {
            return false;
        }
        // Compare from the non-null side so a continuation without a type is rejected, not an NPE.
        if (!getClass().getSimpleName().equals(state.getTrainingType())) {
            return false;
        }
        final Object saved = state.get(LAST_GRADIENTS);
        if (!(saved instanceof double[])) {
            return false;
        }
        final int weightCount = ((ContainsFlat) getMethod()).getFlat().getWeights().length;
        return ((double[]) saved).length == weightCount;
    }

    /**
     * Captures the last gradients so training can later be resumed.
     *
     * @return a continuation holding a copy of the last-gradient array
     */
    @Override
    public TrainingContinuation pause() {
        final TrainingContinuation result = new TrainingContinuation();
        result.setTrainingType(getClass().getSimpleName());
        // Store a snapshot: handing out the live array would let later
        // iterations silently overwrite the saved state.
        final double[] lastGradient = getLastGradient();
        result.set(LAST_GRADIENTS, lastGradient == null ? null : lastGradient.clone());
        return result;
    }

    /**
     * Restores the last gradients from a continuation created by {@link #pause()}.
     *
     * @param state the continuation to restore from
     * @throws TrainingError if {@link #isValidResume(TrainingContinuation)} rejects it
     */
    @Override
    public void resume(final TrainingContinuation state) {
        if (!isValidResume(state)) {
            throw new TrainingError("Invalid training resume data length");
        }
        EngineArray.arrayCopy((double[]) state.get(LAST_GRADIENTS), getLastGradient());
    }

    /**
     * Sets the maximum step growth factor. The derived {@code shrink} threshold is
     * only recomputed by {@link #initOthers()}.
     *
     * @param rate the new learning rate
     */
    @Override
    public void setLearningRate(final double rate) {
        this.learningRate = rate;
    }

    /** @return the linear-term coefficient before division by the record count */
    public double getOutputEpsilon() {
        return this.outputEpsilon;
    }

    /** @return the threshold ratio between the capped and the quadratic step */
    public double getShrink() {
        return this.shrink;
    }

    /**
     * Sets the threshold ratio. Note that {@link #initOthers()} recomputes it as
     * {@code learningRate / (1 + learningRate)}, overwriting a value set here.
     *
     * @param s the new threshold ratio
     */
    public void setShrink(final double s) {
        this.shrink = s;
    }

    /**
     * Sets the linear-term coefficient. It takes effect through {@code eps} the
     * next time {@link #initOthers()} runs.
     *
     * @param epsilon the new coefficient
     */
    public void setOutputEpsilon(final double epsilon) {
        this.outputEpsilon = epsilon;
    }

    /** @return the weight-decay coefficient added to each slope as {@code decay * weight} */
    public double getDecay() {
        return this.decay;
    }

    /**
     * Sets the weight-decay coefficient (default {@code 1.0E-4}).
     *
     * @param d the new coefficient; {@code decay * weight} is added to each slope
     */
    public void setDecay(final double d) {
        this.decay = d;
    }

    /**
     * Derives {@code eps = outputEpsilon / recordCount} and
     * {@code shrink = learningRate / (1 + learningRate)}.
     * NOTE(review): presumably invoked by {@code Propagation} before weight updates — confirm.
     */
    @Override
    public void initOthers() {
        this.eps = this.outputEpsilon / getTraining().getRecordCount();
        this.shrink = this.learningRate / (1.0d + this.learningRate);
    }

    /**
     * Computes the Quickprop step for one weight and records the step and the
     * gradient for the next call.
     * <p>
     * The slope is {@code -gradients[index] + decay * weight}. It is read from the
     * {@code gradients} argument, the same array whose value is saved into the
     * last-gradient array below, so the current and previous slopes share one source.
     * NOTE(review): assumes gradients hold the negated error derivative — confirm
     * against {@code Propagation}.
     *
     * @param gradients    current gradients, indexed like the flat weight array
     * @param lastGradient gradients saved by the previous call
     * @param index        the weight index
     * @return the step to add to the weight
     */
    @Override
    public double updateWeight(final double[] gradients, final double[] lastGradient, final int index) {
        final double weight = this.network.getFlat().getWeights()[index];
        final double prevStep = this.lastDelta[index];
        final double slope = -gradients[index] + this.decay * weight;
        final double prevSlope = -lastGradient[index];
        double nextStep = 0.0d;

        if (prevStep < 0.0d) {
            if (slope > 0.0d) {
                // Slope opposes the last (negative) step: add the linear term.
                nextStep -= this.eps * slope;
            }
            if (slope >= this.shrink * prevSlope) {
                nextStep += this.learningRate * prevStep;
            } else {
                nextStep += quadraticStep(prevStep, slope, prevSlope);
            }
        } else if (prevStep > 0.0d) {
            if (slope < 0.0d) {
                // Slope opposes the last (positive) step: add the linear term.
                nextStep -= this.eps * slope;
            }
            if (slope <= this.shrink * prevSlope) {
                nextStep += this.learningRate * prevStep;
            } else {
                nextStep += quadraticStep(prevStep, slope, prevSlope);
            }
        } else {
            // No previous step to extrapolate from: linear term only.
            nextStep -= this.eps * slope;
        }

        this.lastDelta[index] = nextStep;
        getLastGradient()[index] = gradients[index];
        return nextStep;
    }

    /**
     * Quadratic estimate {@code prevStep * slope / (prevSlope - slope)}. When the
     * slope did not change, the estimate is undefined (division by zero), so the
     * capped step {@code learningRate * prevStep} is used instead.
     */
    private double quadraticStep(final double prevStep, final double slope, final double prevSlope) {
        final double denominator = prevSlope - slope;
        if (denominator == 0.0d) {
            return this.learningRate * prevStep;
        }
        return prevStep * slope / denominator;
    }

    /**
     * Only full-batch training is supported.
     *
     * @param batchSize must be 0
     * @throws EncogError if {@code batchSize} is not 0
     */
    @Override
    public void setBatchSize(final int batchSize) {
        if (batchSize != 0) {
            throw new EncogError("Online training is not supported for:" + getClass().getSimpleName());
        }
    }
}
