package neural;

import java.util.ArrayList;
import java.util.Iterator;

/* loaded from: input_file:neural/BackPropagation.class */
/**
 * Back-propagation learning algorithm with a momentum term.
 *
 * <p>For every synapse being trained, {@link #modifyWeights} adds
 * {@code ±learningRate * currentSlope + momentumRate * lastWeightChange} to its weight.
 * The slope term is subtracted when the train mode is {@code TrainMode.minimize} and
 * added otherwise.
 *
 * <p>In {@code TrainModel.batch} mode (the default), {@link #train} computes the slope
 * over the whole training set and then updates the weights once. With any other
 * {@link TrainModel}, it computes the slope and updates the weights after each
 * individual training pattern.
 */
public class BackPropagation implements ILearningAlgorithm {
    private static final double DEFAULT_LEARNING_RATE = 1.0d;
    private static final double DEFAULT_MOMENTUM_RATE = 0.1d;

    private double learningRate;
    private double momentumRate;
    /** Network this instance was constructed for. */
    private final NeuralNetwork neuralNetwork;
    /** Selects batch vs. per-pattern weight updates in {@link #train}. */
    private TrainModel model;

    /**
     * Creates a back-propagation trainer with the default learning rate (1.0)
     * and momentum rate (0.1), in batch mode.
     *
     * @param neuralNetwork the network this trainer is associated with
     */
    public BackPropagation(NeuralNetwork neuralNetwork) {
        this(neuralNetwork, DEFAULT_LEARNING_RATE, DEFAULT_MOMENTUM_RATE);
    }

    /**
     * Creates a back-propagation trainer in batch mode.
     *
     * @param neuralNetwork the network this trainer is associated with
     * @param d             learning rate applied to each synapse's current slope
     * @param d2            momentum rate applied to each synapse's last weight change
     */
    public BackPropagation(NeuralNetwork neuralNetwork, double d, double d2) {
        this.model = TrainModel.batch;
        this.neuralNetwork = neuralNetwork;
        this.learningRate = d;
        this.momentumRate = d2;
    }

    @Override // neural.ILearningAlgorithm
    public String getType() {
        return "Backpropagation";
    }

    /** Returns the factor applied to each synapse's current slope. */
    public double learningRate() {
        return this.learningRate;
    }

    /** Sets the factor applied to each synapse's current slope. */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** Returns the factor applied to each synapse's previous weight change. */
    public double momentumRate() {
        return this.momentumRate;
    }

    /** Sets the factor applied to each synapse's previous weight change. */
    public void setMomentumRate(double d) {
        this.momentumRate = d;
    }

    /** Returns the update schedule used by {@link #train}; {@code TrainModel.batch} by default. */
    public TrainModel trainModel() {
        return this.model;
    }

    /**
     * Sets the update schedule used by {@link #train}.
     *
     * @param trainModel {@code TrainModel.batch} for one update per training set; any other
     *                   value updates the weights after every training pattern
     * @throws IllegalArgumentException if {@code trainModel} is null
     */
    public void setTrainModel(TrainModel trainModel) {
        if (trainModel == null) {
            throw new IllegalArgumentException("trainModel must not be null");
        }
        this.model = trainModel;
    }

    /**
     * Applies one back-propagation step to every given synapse.
     *
     * @param trainMode  {@code TrainMode.minimize} subtracts the slope term; any other mode adds it
     * @param arrayList  the synapses whose weights are updated
     * @throws Exception if a synapse rejects the weight change; the original exception is
     *                   attached as the cause
     */
    @Override // neural.ILearningAlgorithm
    public void modifyWeights(TrainMode trainMode, ArrayList<Synapse> arrayList) throws Exception {
        for (Synapse synapse : arrayList) {
            // Descend the slope when minimizing, ascend it otherwise.
            double slopeStep = this.learningRate * synapse.currentSlope();
            if (trainMode == TrainMode.minimize) {
                slopeStep = -slopeStep;
            }
            // Momentum: carry over a fraction of the previous weight change.
            double momentumStep = this.momentumRate * synapse.lastWeightChange();
            try {
                synapse.addWeight(slopeStep + momentumStep);
            } catch (Exception e) {
                // Keep the original exception as the cause so its stack trace is not lost.
                throw new Exception("BackPropagation: modifyWeights -> " + e.getMessage(), e);
            }
        }
    }

    /**
     * Runs one training pass.
     *
     * <p>In batch mode, the slope is computed over the whole training set and the weights
     * are updated once. Otherwise, the slope is computed and the weights updated after
     * each training pattern in turn.
     *
     * @param trainingSet         the data to train on
     * @param slopeCalcParams     slope-calculation parameters; its {@code mode} and
     *                            {@code synapsesToTrain} drive the weight update
     * @param iSlopeCalcFunction  computes the slopes stored on the synapses
     * @throws Exception if slope calculation or a weight update fails
     */
    @Override // neural.ILearningAlgorithm
    public void train(TrainingSet trainingSet, SlopeCalcParams slopeCalcParams, ISlopeCalcFunction iSlopeCalcFunction) throws Exception {
        if (this.model == TrainModel.batch) {
            iSlopeCalcFunction.calculateSlope(slopeCalcParams, trainingSet);
            modifyWeights(slopeCalcParams.mode, slopeCalcParams.synapsesToTrain);
            return;
        }
        for (Iterator<TrainingPattern> it = trainingSet.getTraningSet().iterator(); it.hasNext(); ) {
            iSlopeCalcFunction.calculateSlope(slopeCalcParams, it.next());
            modifyWeights(slopeCalcParams.mode, slopeCalcParams.synapsesToTrain);
        }
    }
}
