package org.encog.neural.networks.training.propagation;

import org.encog.engine.network.activation.ActivationFunction;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.neural.error.ErrorFunction;
import org.encog.neural.flat.FlatNetwork;
import org.encog.util.EngineArray;
import org.encog.util.concurrency.EngineTask;

/* loaded from: input_file:org/encog/neural/networks/training/propagation/GradientWorker.class */
/**
 * Computes error gradients for a contiguous slice {@code [low, high]} (inclusive) of a training
 * set and reports them to the owning {@link Propagation} trainer.
 *
 * <p>For each record the worker runs a forward pass through the {@link FlatNetwork}, records the
 * error, computes the output-layer deltas, and then back-propagates level by level from
 * {@code network.getBeginTraining()} to {@code network.getEndTraining()}, accumulating weight
 * gradients into a private buffer. Layer index 0 is the output layer: its deltas occupy the first
 * {@code getOutputCount()} entries and use activation function / flat spot index 0.
 *
 * <p>NOTE(review): the worker reads {@code layerOutput}/{@code layerSums} arrays captured from
 * the network after calling {@code network.compute(...)}, so presumably each worker must own its
 * own {@code FlatNetwork} instance when several workers run concurrently — confirm with the
 * caller that creates the workers.
 */
public class GradientWorker implements EngineTask {
    /** Network used for the forward pass; its layer arrays are read during back-propagation. */
    private final FlatNetwork network;
    /** Accumulates the error of every processed record. */
    private final ErrorCalculation errorCalculation = new ErrorCalculation();
    /** Scratch buffer receiving the network output for the current record. */
    private final double[] actual;
    /** Per-neuron delta values, laid out like {@code layerOutput}. */
    private final double[] layerDelta;
    private final int[] layerCounts;
    private final int[] layerFeedCounts;
    private final int[] layerIndex;
    private final int[] weightIndex;
    private final double[] layerOutput;
    private final double[] layerSums;
    /** Gradients accumulated for the current batch; cleared after each report. */
    private final double[] gradients;
    private final double[] weights;
    /** Reusable holder that training records are loaded into. */
    private final MLDataPair pair;
    private final MLDataSet training;
    /** First record index processed by {@link #run()} (inclusive). */
    private final int low;
    /** Last record index processed by {@link #run()} (inclusive). */
    private final int high;
    /** Trainer that receives gradients, errors and exceptions. */
    private final Propagation owner;
    /** Per-layer constant added to the activation derivative. */
    private final double[] flatSpot;
    /** Computes the raw output-layer error that seeds back-propagation. */
    private final ErrorFunction errorFunction;

    /**
     * Creates a worker for one slice of the training data.
     *
     * @param theNetwork the network whose gradients are computed
     * @param theOwner the trainer that gradients and errors are reported to
     * @param theTraining the training set
     * @param theLow first record index to process (inclusive)
     * @param theHigh last record index to process (inclusive)
     * @param theFlatSpot per-layer constants added to the activation derivative
     * @param theErrorFunction computes the output-layer error
     */
    public GradientWorker(final FlatNetwork theNetwork, final Propagation theOwner,
            final MLDataSet theTraining, final int theLow, final int theHigh,
            final double[] theFlatSpot, final ErrorFunction theErrorFunction) {
        this.network = theNetwork;
        this.training = theTraining;
        this.low = theLow;
        this.high = theHigh;
        this.owner = theOwner;
        this.flatSpot = theFlatSpot;
        this.errorFunction = theErrorFunction;
        this.layerDelta = new double[this.network.getLayerOutput().length];
        this.gradients = new double[this.network.getWeights().length];
        this.actual = new double[this.network.getOutputCount()];
        this.weights = this.network.getWeights();
        this.layerIndex = this.network.getLayerIndex();
        this.layerCounts = this.network.getLayerCounts();
        this.weightIndex = this.network.getWeightIndex();
        this.layerOutput = this.network.getLayerOutput();
        this.layerSums = this.network.getLayerSums();
        this.layerFeedCounts = this.network.getLayerFeedCounts();
        this.pair = BasicMLDataPair.createPair(this.network.getInputCount(),
                this.network.getOutputCount());
    }

    /** @return the network this worker computes gradients for */
    public FlatNetwork getNetwork() {
        return this.network;
    }

    /** @return the weight array obtained from the network at construction time */
    public double[] getWeights() {
        return this.weights;
    }

    /**
     * Runs a forward pass for one training pair, records its error, and back-propagates deltas,
     * adding this pair's contribution to {@link #gradients}.
     *
     * @param dataPair the input/ideal pair to process
     */
    private void process(final MLDataPair dataPair) {
        this.network.compute(dataPair.getInputArray(), this.actual);
        this.errorCalculation.updateError(this.actual, dataPair.getIdealArray(),
                dataPair.getSignificance());
        // Seeds layerDelta with the raw output error.
        this.errorFunction.calculateError(dataPair.getIdealArray(), this.actual, this.layerDelta);

        // Scale the raw output error by the output layer's (flat-spot adjusted) derivative and by
        // the record's significance weight.
        final ActivationFunction outputActivation = this.network.getActivationFunctions()[0];
        final double outputFlatSpot = this.flatSpot[0];
        final double significance = dataPair.getSignificance();
        for (int i = 0; i < this.actual.length; i++) {
            this.layerDelta[i] = (outputActivation.derivativeFunction(this.layerSums[i],
                    this.layerOutput[i]) + outputFlatSpot) * this.layerDelta[i] * significance;
        }

        for (int level = this.network.getBeginTraining(); level < this.network.getEndTraining();
                level++) {
            processLevel(level);
        }
    }

    /**
     * Back-propagates one level: accumulates gradients for the weights feeding layer
     * {@code currentLevel} from layer {@code currentLevel + 1}, and computes the deltas of layer
     * {@code currentLevel + 1}.
     *
     * @param currentLevel the layer whose deltas are already known
     */
    private void processLevel(final int currentLevel) {
        final int fromLayerIndex = this.layerIndex[currentLevel + 1];
        final int toLayerIndex = this.layerIndex[currentLevel];
        final int fromLayerSize = this.layerCounts[currentLevel + 1];
        final int toLayerSize = this.layerFeedCounts[currentLevel];
        final int index = this.weightIndex[currentLevel];
        // The deltas produced here belong to layer currentLevel + 1, so the activation derivative
        // and the flat-spot constant must both come from that layer (layer k <-> index k, as in
        // process()). Using activationFunctions[currentLevel] gave wrong hidden-layer deltas
        // whenever adjacent layers used different activation functions.
        final ActivationFunction activation =
                this.network.getActivationFunctions()[currentLevel + 1];
        final double currentFlatSpot = this.flatSpot[currentLevel + 1];

        // Field arrays are copied into locals to avoid a field load per access in the hot loop.
        final double[] delta = this.layerDelta;
        final double[] w = this.weights;
        final double[] grads = this.gradients;
        final double[] outputs = this.layerOutput;
        final double[] sums = this.layerSums;

        final int toLayerEnd = toLayerIndex + toLayerSize;
        int yi = fromLayerIndex;
        for (int y = 0; y < fromLayerSize; y++, yi++) {
            final double output = outputs[yi];
            double sum = 0.0;
            // Weights are laid out so that consecutive to-neurons are fromLayerSize apart.
            int wi = index + y;
            for (int xi = toLayerIndex; xi < toLayerEnd; xi++, wi += fromLayerSize) {
                grads[wi] += output * delta[xi];
                sum += w[wi] * delta[xi];
            }
            delta[yi] = sum * (activation.derivativeFunction(sums[yi], outputs[yi])
                    + currentFlatSpot);
        }
    }

    /**
     * Processes records {@code low..high} (inclusive) and reports the accumulated gradients and
     * the resulting error to the owner. Any failure is reported to the owner instead of being
     * thrown. The gradient buffer is always cleared afterwards, so a failed batch cannot leak
     * partial sums into the next one.
     */
    @Override
    public final void run() {
        try {
            this.errorCalculation.reset();
            for (int i = this.low; i <= this.high; i++) {
                this.training.getRecord(i, this.pair);
                process(this.pair);
            }
            this.owner.report(this.gradients, this.errorCalculation.calculate(), null);
        } catch (final Throwable t) {
            this.owner.report(null, 0.0d, t);
        } finally {
            EngineArray.fill(this.gradients, 0.0d);
        }
    }

    /**
     * Processes a single record and reports its gradients to the owner with an error of 0.0.
     * The error is still accumulated into {@link #getErrorCalculation()}, which this method does
     * not reset. Exceptions propagate to the caller. The gradient buffer is cleared even on
     * failure.
     *
     * @param index the record to process
     */
    public final void run(final int index) {
        try {
            this.training.getRecord(index, this.pair);
            process(this.pair);
            this.owner.report(this.gradients, 0.0d, null);
        } finally {
            EngineArray.fill(this.gradients, 0.0d);
        }
    }

    /** @return the error accumulator updated by every processed record */
    public ErrorCalculation getErrorCalculation() {
        return this.errorCalculation;
    }
}
