package neural;

import java.util.ArrayList;
import java.util.Iterator;

/**
 * Resilient-backpropagation (Rprop) style learning algorithm.
 *
 * <p>Every {@link Synapse} carries its own step size. For each synapse, the sign of
 * {@code currentSlope() * previousSlope()} selects one of three updates:
 * <ul>
 *   <li><b>positive</b> – the step size is multiplied by {@code etaPlus} (capped at
 *       {@code stepSizeMax}) and the weight moves by that step against the sign of the current
 *       slope;</li>
 *   <li><b>negative</b> – the step size is multiplied by {@code etaMinus} (floored at
 *       {@code stepSizeMin}), the weight is left untouched and the current slope is set to 0;</li>
 *   <li><b>zero or NaN</b> – the weight moves by the unchanged step size against the sign of the
 *       current slope.</li>
 * </ul>
 *
 * <p>NOTE(review): nothing in this class copies {@code currentSlope} into {@code previousSlope};
 * presumably the {@link ISlopeCalcFunction} or {@link Synapse} does that — confirm.
 */
public class RProp implements ILearningAlgorithm {
    /** Step size assigned to every synapse when the algorithm is (re)initialised. */
    private static final double DEFAULT_INITIAL_STEP_SIZE = 0.1d;
    private static final double DEFAULT_ETA_MINUS = 0.5d;
    private static final double DEFAULT_ETA_PLUS = 1.2d;
    private static final double DEFAULT_STEP_SIZE_MIN = 1.0E-6d;
    private static final double DEFAULT_STEP_SIZE_MAX = 50.0d;

    /** Multiplier applied to a synapse's step size when its slope changes sign. */
    private double etaMinus = DEFAULT_ETA_MINUS;
    /** Multiplier applied to a synapse's step size when its slope keeps its sign. */
    private double etaPlus = DEFAULT_ETA_PLUS;
    /** Lower bound for a shrunk step size. */
    private double stepSizeMin = DEFAULT_STEP_SIZE_MIN;
    /** Upper bound for a grown step size. */
    private double stepSizeMax = DEFAULT_STEP_SIZE_MAX;
    /** Network this instance was created for; only stored, never read inside this class. */
    private NeuralNetwork network;
    /** Whether {@link #train} updates weights once per set (batch) or once per pattern. */
    private TrainModel model = TrainModel.batch;

    /**
     * Creates an Rprop trainer for {@code neuralNetwork} with the default eta factors
     * (0.5 / 1.2) and resets the step size of all its synapses to 0.1.
     *
     * @param neuralNetwork network whose {@code synapses()} are initialised
     */
    public RProp(NeuralNetwork neuralNetwork) {
        this(neuralNetwork, DEFAULT_ETA_MINUS, DEFAULT_ETA_PLUS);
    }

    /**
     * Creates an Rprop trainer for {@code neuralNetwork} with custom eta factors and resets the
     * step size of all its synapses to 0.1.
     *
     * @param neuralNetwork network whose {@code synapses()} are initialised
     * @param d             {@code etaMinus}, step-size multiplier used on a slope sign change
     * @param d2            {@code etaPlus}, step-size multiplier used when the slope keeps its sign
     */
    public RProp(NeuralNetwork neuralNetwork, double d, double d2) {
        this.etaMinus = d;
        this.etaPlus = d2;
        // Previously this constructor forgot to remember the network, unlike RProp(NeuralNetwork).
        this.network = neuralNetwork;
        resetStepSizes(neuralNetwork.synapses());
    }

    /**
     * Creates an Rprop trainer bound to an explicit list of synapses rather than a network and
     * resets each of their step sizes to 0.1.
     *
     * @param arrayList synapses to initialise
     */
    public RProp(ArrayList<Synapse> arrayList) {
        resetStepSizes(arrayList);
    }

    @Override // neural.ILearningAlgorithm
    public String getType() {
        return "RProp";
    }

    /**
     * Resets the step size of every given synapse to the initial value (0.1).
     *
     * @param arrayList synapses to reset
     */
    public void setInitialStepSize(ArrayList<Synapse> arrayList) {
        resetStepSizes(arrayList);
    }

    /**
     * Selects batch or per-pattern weight updates for subsequent {@link #train} calls.
     *
     * @param trainModel the update model to use
     */
    public void setModel(TrainModel trainModel) {
        this.model = trainModel;
    }

    /**
     * Runs one training pass over {@code trainingSet}.
     *
     * <p>With {@link TrainModel#batch} the slope is computed once for the whole set, followed by a
     * single weight update. With any other model the slope is computed and the weights updated
     * after every individual pattern.
     *
     * @param trainingSet        patterns to train on
     * @param slopeCalcParams    parameters passed to the slope calculation; its {@code mode} and
     *                           {@code synapsesToTrain} drive the weight update
     * @param iSlopeCalcFunction function that fills in the synapse slopes
     * @throws Exception if slope calculation or a weight update fails
     */
    @Override // neural.ILearningAlgorithm
    public void train(TrainingSet trainingSet, SlopeCalcParams slopeCalcParams,
            ISlopeCalcFunction iSlopeCalcFunction) throws Exception {
        if (this.model == TrainModel.batch) {
            iSlopeCalcFunction.calculateSlope(slopeCalcParams, trainingSet);
            modifyWeights(slopeCalcParams.mode, slopeCalcParams.synapsesToTrain);
            return;
        }
        for (TrainingPattern pattern : trainingSet.getTraningSet()) {
            iSlopeCalcFunction.calculateSlope(slopeCalcParams, pattern);
            modifyWeights(slopeCalcParams.mode, slopeCalcParams.synapsesToTrain);
        }
    }

    /**
     * Applies one Rprop update (see class documentation) to each synapse in {@code arrayList}.
     *
     * @param trainMode not used by this implementation
     * @param arrayList synapses whose step sizes and weights are updated
     * @throws Exception wrapping (as cause) any exception raised by {@code Synapse.addWeight}
     */
    @Override // neural.ILearningAlgorithm
    public void modifyWeights(TrainMode trainMode, ArrayList<Synapse> arrayList) throws Exception {
        for (Synapse synapse : arrayList) {
            double currentSlope = synapse.currentSlope();
            double slopeProduct = currentSlope * synapse.previousSlope();
            if (slopeProduct > 0.0d) {
                // Same direction as last time: accelerate (bounded above) and take the step.
                double grown = Math.min(synapse.stepSize() * this.etaPlus, this.stepSizeMax);
                synapse.setStepSize(grown);
                stepAgainstSlope(synapse, currentSlope, grown);
            } else if (slopeProduct < 0.0d) {
                // Overshot a minimum: decelerate (bounded below) and skip the weight change.
                // Zeroing the slope makes the next update fall into the final branch.
                synapse.setStepSize(Math.max(synapse.stepSize() * this.etaMinus, this.stepSizeMin));
                synapse.setCurrentSlope(0.0d);
            } else {
                // No sign information from the product: step with the current size.
                stepAgainstSlope(synapse, currentSlope, synapse.stepSize());
            }
        }
    }

    /** Sets the step size of every synapse to {@link #DEFAULT_INITIAL_STEP_SIZE}. */
    private static void resetStepSizes(Iterable<Synapse> synapses) {
        for (Synapse synapse : synapses) {
            synapse.setStepSize(DEFAULT_INITIAL_STEP_SIZE);
        }
    }

    /**
     * Adds {@code -signum(slope) * stepSize} to the synapse weight, wrapping any failure in an
     * {@code Exception} whose message keeps the historic prefix and whose cause is the original.
     */
    private static void stepAgainstSlope(Synapse synapse, double slope, double stepSize)
            throws Exception {
        try {
            synapse.addWeight(-Math.signum(slope) * stepSize);
        } catch (Exception e) {
            throw new Exception("RProp: modifyWeights -> " + e.getMessage(), e);
        }
    }
}
