package org.encogx.neural.freeform.training;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.encogx.engine.network.activation.ActivationSigmoid;
import org.encogx.mathutil.error.ErrorCalculation;
import org.encogx.ml.MLMethod;
import org.encogx.ml.TrainingImplementationType;
import org.encogx.ml.data.MLData;
import org.encogx.ml.data.MLDataPair;
import org.encogx.ml.data.MLDataSet;
import org.encogx.ml.train.BasicTraining;
import org.encogx.neural.flat.FlatNetwork;
import org.encogx.neural.freeform.FreeformConnection;
import org.encogx.neural.freeform.FreeformNetwork;
import org.encogx.neural.freeform.FreeformNeuron;
import org.encogx.neural.freeform.task.ConnectionTask;

/* loaded from: input_file:org/encogx/neural/freeform/training/FreeformPropagationTraining.class */
/**
 * Base class for gradient-based ("propagation") training of a {@link FreeformNetwork}.
 *
 * <p>Each iteration runs every pattern of the training set through the network. It
 * back-propagates the output error and accumulates a gradient for every connection in that
 * connection's temp-training slot 0. Subclasses turn the accumulated gradient into a weight
 * change in {@link #learnConnection(FreeformConnection)}; the slot is reset to zero right after.
 *
 * <p>With a batch size of 0 (the default), weights are updated once per iteration, after all
 * patterns have been processed. With a positive batch size, they are updated after every
 * {@code batchSize} patterns, plus once more for a trailing partial batch.
 */
public abstract class FreeformPropagationTraining extends BasicTraining implements Serializable {
    private static final long serialVersionUID = 1;

    /**
     * Amount added to a sigmoid derivative when flat-spot fixing is enabled. It keeps saturated
     * sigmoid neurons from receiving a vanishing gradient.
     */
    public static final double FLAT_SPOT_CONST = 0.1d;

    private final FreeformNetwork network;
    private final MLDataSet training;
    private int iterationCount;
    private double error;

    /** Scratch set used by the backward pass so each neuron is visited at most once per pattern. */
    private final Set<FreeformNeuron> visited;

    /** Whether {@link #FLAT_SPOT_CONST} is added to sigmoid derivatives (misspelled name kept for compatibility). */
    private boolean fixFlatSopt;

    /** Patterns per weight update; 0 means a single update per iteration (pure batch). */
    private int batchSize;

    /**
     * Creates a trainer with a {@code null} network and {@code null} training set. Most methods
     * of this class dereference those fields, so an instance created this way cannot train.
     */
    public FreeformPropagationTraining() {
        super(TrainingImplementationType.Iterative);
        this.visited = new HashSet<FreeformNeuron>();
        this.fixFlatSopt = true;
        this.batchSize = 0;
        this.network = null;
        this.training = null;
    }

    /**
     * Creates a trainer for the given network and training set. Flat-spot fixing is enabled and
     * the batch size is 0 (pure batch).
     *
     * @param freeformNetwork the network whose connection weights are trained
     * @param mLDataSet the training patterns
     */
    public FreeformPropagationTraining(FreeformNetwork freeformNetwork, MLDataSet mLDataSet) {
        super(TrainingImplementationType.Iterative);
        this.visited = new HashSet<FreeformNeuron>();
        this.fixFlatSopt = true;
        this.batchSize = 0;
        this.network = freeformNetwork;
        this.training = mLDataSet;
    }

    /**
     * Back-propagates the deltas already stored on {@code outputs}. Each connection's gradient
     * ({@code source activation * target delta}) is added to its temp-training slot 0.
     *
     * <p>Neurons are processed in reverse topological order, so a hidden neuron's delta is
     * computed only after the deltas of all of its reachable targets are final. Each
     * connection's gradient is accumulated exactly once per pattern.
     *
     * @param outputs the output neurons, whose deltas must already be set
     */
    private void backpropagate(final List<FreeformNeuron> outputs) {
        final List<FreeformNeuron> order = new ArrayList<FreeformNeuron>();
        this.visited.clear();
        for (final FreeformNeuron output : outputs) {
            collectPostOrder(output, order);
        }
        // 'order' lists each neuron after its sources; walking it backwards therefore visits
        // each neuron after all of its (reachable) targets.
        for (int i = order.size() - 1; i >= 0; i--) {
            final FreeformNeuron neuron = order.get(i);
            if (neuron.getInputSummation() == null) {
                // No incoming connections (e.g. an input neuron): nothing to train here.
                continue;
            }
            if (!outputs.contains(neuron)) {
                calculateHiddenDelta(neuron);
            }
            final double delta = neuron.getTempTraining(0);
            for (final FreeformConnection connection : neuron.getInputSummation().list()) {
                connection.addTempTraining(0, connection.getSource().getActivation() * delta);
            }
        }
    }

    /**
     * Depth-first walk from {@code neuron} towards the inputs. It appends each not-yet-visited
     * neuron to {@code order} after all of its sources.
     */
    private void collectPostOrder(final FreeformNeuron neuron, final List<FreeformNeuron> order) {
        if (!this.visited.add(neuron)) {
            return;
        }
        if (neuron.getInputSummation() != null) {
            for (final FreeformConnection connection : neuron.getInputSummation().list()) {
                collectPostOrder(connection.getSource(), order);
            }
        }
        order.add(neuron);
    }

    /**
     * Stores a hidden neuron's delta in its temp-training slot 0. The delta is the weighted sum
     * of its targets' deltas times the derivative of the neuron's OWN activation function.
     *
     * @param neuron a neuron with a non-null input summation
     */
    private void calculateHiddenDelta(final FreeformNeuron neuron) {
        double sum = 0.0d;
        for (final FreeformConnection connection : neuron.getOutputs()) {
            sum += connection.getTarget().getTempTraining(0) * connection.getWeight();
        }
        double derivative = neuron.getInputSummation().getActivationFunction()
                .derivativeFunction(neuron.getSum(), neuron.getActivation());
        if (this.fixFlatSopt
                && neuron.getInputSummation().getActivationFunction() instanceof ActivationSigmoid) {
            derivative += FLAT_SPOT_CONST;
        }
        neuron.setTempTraining(0, sum * derivative);
    }

    /**
     * Stores the delta of an output neuron in its temp-training slot 0.
     *
     * @param freeformNeuron the output neuron
     * @param d the significance-weighted error {@code ideal - actual} for this neuron
     */
    private void calculateOutputDelta(FreeformNeuron freeformNeuron, double d) {
        double derivative = freeformNeuron.getInputSummation().getActivationFunction()
                .derivativeFunction(freeformNeuron.getInputSummation().getSum(),
                        freeformNeuron.getActivation());
        if (this.fixFlatSopt
                && freeformNeuron.getInputSummation().getActivationFunction() instanceof ActivationSigmoid) {
            derivative += FLAT_SPOT_CONST;
        }
        freeformNeuron.setTempTraining(0, derivative * d);
    }

    /**
     * Computes one pattern: runs the network and records its error. It then sets all output
     * deltas before back-propagating once, so no hidden delta sees a stale output delta.
     */
    private void processPattern(final MLDataPair pair, final ErrorCalculation errorCalculation) {
        final MLData input = pair.getInput();
        final MLData ideal = pair.getIdeal();
        final MLData actual = this.network.compute(input);
        final double significance = pair.getSignificance();
        errorCalculation.updateError(actual.getData(), ideal.getData(), significance);

        final int outputCount = this.network.getOutputCount();
        final List<FreeformNeuron> outputs = new ArrayList<FreeformNeuron>(outputCount);
        for (int i = 0; i < outputCount; i++) {
            final FreeformNeuron output = this.network.getOutputLayer().getNeurons().get(i);
            calculateOutputDelta(output, (ideal.getData(i) - actual.getData(i)) * significance);
            outputs.add(output);
        }
        backpropagate(outputs);
    }

    /** @return always {@code false}; this trainer does not support pause/resume. */
    @Override
    public boolean canContinue() {
        return false;
    }

    /** Clears the network's temporary training data. */
    @Override
    public void finishTraining() {
        this.network.tempTrainingClear();
    }

    /** @return the error computed by the most recent iteration */
    @Override
    public double getError() {
        return this.error;
    }

    /** @return {@link TrainingImplementationType#Iterative} */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /** @return the number of iterations performed (or the value last set) */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    /** @return the network being trained */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /** @return the training set */
    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    /** @return whether {@link #FLAT_SPOT_CONST} is added to sigmoid derivatives */
    public boolean isFixFlatSopt() {
        return this.fixFlatSopt;
    }

    /** Correctly spelled alias of {@link #isFixFlatSopt()}. */
    public boolean isFixFlatSpot() {
        return this.fixFlatSopt;
    }

    /**
     * Performs one training iteration (one pass over the training set). It uses pure batch
     * mode when the batch size is 0 and mini-batches otherwise.
     */
    @Override
    public void iteration() {
        preIteration();
        this.iterationCount++;
        this.network.clearContext();
        if (this.batchSize == 0) {
            processPureBatch();
        } else {
            processBatches();
        }
        postIteration();
    }

    /**
     * Performs {@code i} training iterations.
     *
     * @param i number of iterations; values {@code <= 0} do nothing
     */
    @Override
    public void iteration(int i) {
        for (int i2 = 0; i2 < i; i2++) {
            iteration();
        }
    }

    /** Accumulates gradients over the whole training set, then applies a single weight update. */
    protected void processPureBatch() {
        final ErrorCalculation errorCalculation = new ErrorCalculation();
        for (final MLDataPair pair : this.training) {
            processPattern(pair, errorCalculation);
        }
        setError(errorCalculation.calculate());
        learn();
    }

    /**
     * Applies a weight update after every {@code batchSize} patterns, and once more for any
     * trailing partial batch. A negative batch size updates after every pattern.
     */
    protected void processBatches() {
        int inBatch = 0;
        final ErrorCalculation errorCalculation = new ErrorCalculation();
        for (final MLDataPair pair : this.training) {
            processPattern(pair, errorCalculation);
            inBatch++;
            if (inBatch >= this.batchSize) {
                inBatch = 0;
                learn();
            }
        }
        if (inBatch > 0) {
            learn();
        }
        setError(errorCalculation.calculate());
    }

    /**
     * Calls {@link #learnConnection(FreeformConnection)} for every connection. It then resets
     * that connection's accumulated gradient (temp-training slot 0) to zero.
     */
    protected void learn() {
        this.network.performConnectionTask(new ConnectionTask() {
            @Override
            public void task(FreeformConnection freeformConnection) {
                FreeformPropagationTraining.this.learnConnection(freeformConnection);
                freeformConnection.setTempTraining(0, 0.0d);
            }
        });
    }

    /**
     * Applies a weight update to one connection using the gradient accumulated in its
     * temp-training slot 0.
     *
     * @param freeformConnection the connection to update
     */
    protected abstract void learnConnection(FreeformConnection freeformConnection);

    /** @param d the error value to report */
    @Override
    public void setError(double d) {
        this.error = d;
    }

    /** @param z whether to add {@link #FLAT_SPOT_CONST} to sigmoid derivatives */
    public void setFixFlatSopt(boolean z) {
        this.fixFlatSopt = z;
    }

    /** Correctly spelled alias of {@link #setFixFlatSopt(boolean)}. */
    public void setFixFlatSpot(boolean z) {
        this.fixFlatSopt = z;
    }

    /** @param i the iteration counter value */
    @Override
    public void setIteration(int i) {
        this.iterationCount = i;
    }

    /** @return patterns per weight update; 0 means pure batch */
    public int getBatchSize() {
        return this.batchSize;
    }

    /** @param i patterns per weight update; 0 means pure batch */
    public void setBatchSize(int i) {
        this.batchSize = i;
    }
}
