package org.encogx.neural.networks.training.nm;

import org.encogx.ml.MLMethod;
import org.encogx.ml.TrainingImplementationType;
import org.encogx.ml.data.MLDataSet;
import org.encogx.ml.train.BasicTraining;
import org.encogx.neural.flat.FlatNetwork;
import org.encogx.neural.networks.BasicNetwork;
import org.encogx.neural.networks.structure.NetworkCODEC;
import org.encogx.neural.networks.training.propagation.TrainingContinuation;
import org.encogx.util.EngineArray;

/* loaded from: input_file:org/encogx/neural/networks/training/nm/NelderMeadTraining.class */
public class NelderMeadTraining extends BasicTraining {

    /** Contraction coefficient, applied towards the centroid when a reflection is not good enough. */
    private static final double CONTRACTION_COEFFICIENT = 0.5d;

    /** Expansion coefficient, applied beyond the reflected point when a reflection succeeds. */
    private static final double EXPANSION_COEFFICIENT = 2.0d;

    /** Reflection coefficient, used to mirror the worst vertex through the centroid. */
    private static final double REFLECTION_COEFFICIENT = 1.0d;

    /**
     * Relative size of the per-weight perturbation used by the local-minimum probe; also the
     * simplex scale factor used when the search restarts.
     */
    private static final double EPSILON = 0.001d;

    /** Number of simplex steps between two evaluations of the convergence criterion. */
    private static final int CONVERGENCE_CHECK_INTERVAL = 500;

    /** The network being trained; its weights are overwritten by every call to {@link #fn}. */
    private final BasicNetwork network;

    /** Number of simplex vertices: the number of weights plus one. */
    private final int nn;

    /**
     * Simplex vertices stored row by row: vertex {@code j} occupies
     * {@code p[j * n] .. p[j * n + n - 1]}, where {@code n} is the number of weights.
     */
    private final double[] p;

    /** Scratch point holding the reflected vertex. */
    private final double[] pstar;

    /** Scratch point holding the expanded or contracted vertex. */
    private final double[] p2star;

    /** Centroid of every vertex except the current worst one. */
    private final double[] pbar;

    /** Error of each vertex; {@code y[j]} belongs to vertex {@code j}. */
    private final double[] y;

    /**
     * Convergence threshold: the inner search stops once the sum of squared deviations of the
     * vertex errors from their mean is at or below this value.
     */
    private final double rq;

    /** Point the simplex is built around; replaced by the improved point when the search restarts. */
    private final double[] start;

    /** Best weights found; also used as scratch space while shrinking and probing. */
    private final double[] trainedWeights;

    /** Per-weight step used to build the simplex around {@link #start}. */
    private final double[] step;

    /** Scale factor applied to {@link #step} when the simplex is built. */
    private double del;

    /** Simplex steps remaining until the next convergence check; carried across iterations. */
    private int jcount;

    /** Error of the best vertex, i.e. {@code y[ilo]}. */
    private double ylo;

    /** Index of the best (lowest-error) vertex. */
    private int ilo;

    /** Set once the local-minimum probe finds no improvement; further iterations do nothing. */
    private boolean converged;

    /**
     * Creates a trainer that uses an initial simplex step of {@code 100.0} for every weight.
     *
     * @param network  the network to train; its current weights are the starting point
     * @param training the training data
     */
    public NelderMeadTraining(final BasicNetwork network, final MLDataSet training) {
        this(network, training, 100.0d);
    }

    /**
     * Creates a trainer.
     *
     * <p>Memory use grows quadratically with the number of weights: the simplex holds
     * {@code n + 1} copies of the {@code n}-element weight vector.
     *
     * @param network   the network to train; its current weights are the starting point
     * @param training  the training data
     * @param stepValue initial simplex step applied to every weight
     */
    public NelderMeadTraining(final BasicNetwork network, final MLDataSet training,
            final double stepValue) {
        super(TrainingImplementationType.OnePass);
        this.network = network;
        setTraining(training);
        this.start = NetworkCODEC.networkToArray(network);
        this.trainedWeights = NetworkCODEC.networkToArray(network);
        final int n = this.start.length;
        this.p = new double[n * (n + 1)];
        this.pstar = new double[n];
        this.p2star = new double[n];
        this.pbar = new double[n];
        this.y = new double[n + 1];
        this.nn = n + 1;
        this.del = 1.0d;
        this.rq = 1.0E-13d * n;
        this.step = new double[NetworkCODEC.networkSize(network)];
        this.jcount = CONVERGENCE_CHECK_INTERVAL;
        EngineArray.fill(this.step, stepValue);
    }

    /**
     * Reports whether training can be paused and resumed.
     *
     * @return always {@code false}; {@link #pause()} does not capture any state
     */
    @Override // org.encogx.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    /**
     * Objective function: loads {@code weights} into the network and returns the network's
     * error on the training set.
     *
     * <p>Side effect: the network's weights are left set to {@code weights}.
     *
     * @param weights flattened weight vector, in {@link NetworkCODEC} order
     * @return the network error on the training data for those weights
     */
    public double fn(final double[] weights) {
        NetworkCODEC.arrayToNetwork(weights, this.network);
        return this.network.calculateError(getTraining());
    }

    /**
     * Returns the network being trained.
     *
     * @return the trained network
     */
    @Override // org.encogx.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * Reports whether training is finished.
     *
     * @return {@code true} once the simplex search has converged, otherwise whatever the base
     *         class decides
     */
    @Override // org.encogx.ml.train.BasicTraining, org.encogx.ml.train.MLTrain
    public boolean isTrainingDone() {
        return this.converged || super.isTrainingDone();
    }

    /**
     * Runs one full simplex search from {@link #start}.
     *
     * <p>Each call does the following:
     * <ol>
     *   <li>Rebuilds the simplex around {@link #start} and searches until it collapses.</li>
     *   <li>Probes each weight of the best vertex by {@code +/- step[i] * EPSILON}.</li>
     *   <li>If a probe improves the error, that point becomes the new start and the next call
     *       restarts from it with a small simplex.</li>
     *   <li>Otherwise the search is marked as converged.</li>
     * </ol>
     * The resulting weights are written to the network, and the error of exactly those weights
     * is reported through {@code setError}.
     */
    @Override // org.encogx.ml.train.MLTrain
    public void iteration() {
        if (this.converged) {
            return;
        }
        final int n = this.start.length;

        buildSimplex(n);
        findLowestVertex();
        searchUntilCollapsed(n);

        System.arraycopy(this.p, this.ilo * n, this.trainedWeights, 0, n);
        final double bestError = this.y[this.ilo];

        final double finalError = probeNeighbourhood(n, bestError);
        if (finalError < bestError) {
            // Not a local minimum yet: restart from the improved point with a small simplex.
            System.arraycopy(this.trainedWeights, 0, this.start, 0, n);
            this.del = EPSILON;
        } else {
            this.converged = true;
        }
        // Report the error of the weights actually installed below. On improvement these are the
        // probed weights, not the simplex's best vertex.
        setError(finalError);
        NetworkCODEC.arrayToNetwork(this.trainedWeights, this.network);
    }

    /**
     * Builds the initial simplex and evaluates the error at each vertex.
     *
     * <p>Vertex {@code n} is {@link #start} itself. Vertex {@code j < n} is {@link #start} with
     * weight {@code j} shifted by {@code step[j] * del}.
     */
    private void buildSimplex(final int n) {
        System.arraycopy(this.start, 0, this.p, n * n, n);
        this.y[n] = fn(this.start);
        for (int j = 0; j < n; j++) {
            final double saved = this.start[j];
            this.start[j] = this.start[j] + (this.step[j] * this.del);
            System.arraycopy(this.start, 0, this.p, j * n, n);
            this.y[j] = fn(this.start);
            this.start[j] = saved;
        }
    }

    /** Sets {@link #ylo}/{@link #ilo} to the lowest vertex error (first one wins ties). */
    private void findLowestVertex() {
        this.ylo = this.y[0];
        this.ilo = 0;
        for (int i = 1; i < this.nn; i++) {
            if (this.y[i] < this.ylo) {
                this.ylo = this.y[i];
                this.ilo = i;
            }
        }
    }

    /** Returns the index of the highest vertex error (first one wins ties). */
    private int findHighestVertex() {
        double highest = this.y[0];
        int index = 0;
        for (int i = 1; i < this.nn; i++) {
            if (highest < this.y[i]) {
                highest = this.y[i];
                index = i;
            }
        }
        return index;
    }

    /** Stores in {@link #pbar} the centroid of every vertex except {@code excluded}. */
    private void computeCentroid(final int n, final int excluded) {
        for (int i = 0; i < n; i++) {
            double sum = 0.0d;
            for (int j = 0; j < this.nn; j++) {
                sum += this.p[i + (j * n)];
            }
            sum -= this.p[i + (excluded * n)];
            this.pbar[i] = sum / n;
        }
    }

    /** Overwrites vertex {@code index} with {@code point} and records its error. */
    private void replaceVertex(final int n, final int index, final double[] point,
            final double value) {
        System.arraycopy(point, 0, this.p, index * n, n);
        this.y[index] = value;
    }

    /**
     * Moves every vertex halfway towards the best vertex and re-evaluates them. Then refreshes
     * {@link #ylo}/{@link #ilo}.
     */
    private void shrinkTowardsBest(final int n) {
        final int best = this.ilo * n;
        for (int j = 0; j < this.nn; j++) {
            final int row = j * n;
            for (int i = 0; i < n; i++) {
                this.p[row + i] = (this.p[row + i] + this.p[best + i]) * 0.5d;
                this.trainedWeights[i] = this.p[row + i];
            }
            this.y[j] = fn(this.trainedWeights);
        }
        findLowestVertex();
    }

    /**
     * Repeats reflection, expansion, contraction and shrink steps until the simplex collapses.
     *
     * <p>The convergence test runs only every {@link #CONVERGENCE_CHECK_INTERVAL} steps. The
     * countdown in {@link #jcount} carries over between calls.
     */
    private void searchUntilCollapsed(final int n) {
        while (true) {
            final int ihi = findHighestVertex();
            computeCentroid(n, ihi);

            // Reflect the worst vertex through the centroid.
            for (int i = 0; i < n; i++) {
                this.pstar[i] = this.pbar[i]
                        + (REFLECTION_COEFFICIENT * (this.pbar[i] - this.p[i + (ihi * n)]));
            }
            final double ystar = fn(this.pstar);

            if (ystar < this.ylo) {
                // Reflection beat the best vertex: try expanding further in that direction.
                for (int i = 0; i < n; i++) {
                    this.p2star[i] = this.pbar[i]
                            + (EXPANSION_COEFFICIENT * (this.pstar[i] - this.pbar[i]));
                }
                final double y2star = fn(this.p2star);
                if (ystar < y2star) {
                    replaceVertex(n, ihi, this.pstar, ystar);
                } else {
                    replaceVertex(n, ihi, this.p2star, y2star);
                }
            } else {
                // Count how many vertices are worse than the reflected point.
                int worse = 0;
                for (int i = 0; i < this.nn; i++) {
                    if (ystar < this.y[i]) {
                        worse++;
                    }
                }

                if (worse > 1) {
                    replaceVertex(n, ihi, this.pstar, ystar);
                } else if (worse == 0) {
                    // Contract towards the worst vertex.
                    for (int i = 0; i < n; i++) {
                        this.p2star[i] = this.pbar[i] + (CONTRACTION_COEFFICIENT
                                * (this.p[i + (ihi * n)] - this.pbar[i]));
                    }
                    final double y2star = fn(this.p2star);
                    if (this.y[ihi] < y2star) {
                        shrinkTowardsBest(n);
                    } else {
                        replaceVertex(n, ihi, this.p2star, y2star);
                    }
                } else {
                    // Exactly one vertex is worse: contract on the reflected side.
                    for (int i = 0; i < n; i++) {
                        this.p2star[i] = this.pbar[i]
                                + (CONTRACTION_COEFFICIENT * (this.pstar[i] - this.pbar[i]));
                    }
                    final double y2star = fn(this.p2star);
                    if (y2star <= ystar) {
                        replaceVertex(n, ihi, this.p2star, y2star);
                    } else {
                        replaceVertex(n, ihi, this.pstar, ystar);
                    }
                }
            }

            if (this.y[ihi] < this.ylo) {
                this.ylo = this.y[ihi];
                this.ilo = ihi;
            }

            this.jcount--;
            if (this.jcount > 0) {
                continue;
            }
            this.jcount = CONVERGENCE_CHECK_INTERVAL;
            if (simplexHasCollapsed()) {
                return;
            }
        }
    }

    /**
     * Returns {@code true} when the spread of the vertex errors has fallen to {@link #rq} or
     * below.
     *
     * <p>Also returns {@code true} when the spread is NaN (a NaN or infinite vertex error).
     * Such a value can never satisfy {@code <= rq}, so without this check the search loop
     * could run forever.
     */
    private boolean simplexHasCollapsed() {
        double sum = 0.0d;
        for (int i = 0; i < this.nn; i++) {
            sum += this.y[i];
        }
        final double mean = sum / this.nn;

        double spread = 0.0d;
        for (int i = 0; i < this.nn; i++) {
            spread += Math.pow(this.y[i] - mean, 2.0d);
        }
        return spread <= this.rq || Double.isNaN(spread);
    }

    /**
     * Tests whether {@link #trainedWeights} is a local minimum.
     *
     * <p>Each weight is nudged by {@code +step[i] * EPSILON}, then by {@code -step[i] * EPSILON}.
     * The probe stops at the first point whose error is below {@code bestError} and leaves that
     * point in {@link #trainedWeights}.
     *
     * @return the error of the improving point, or {@code bestError} if no probe improved on it
     */
    private double probeNeighbourhood(final int n, final double bestError) {
        for (int i = 0; i < n; i++) {
            final double delta = this.step[i] * EPSILON;

            this.trainedWeights[i] = this.trainedWeights[i] + delta;
            double error = fn(this.trainedWeights);
            if (error < bestError) {
                return error;
            }

            this.trainedWeights[i] = (this.trainedWeights[i] - delta) - delta;
            error = fn(this.trainedWeights);
            if (error < bestError) {
                return error;
            }

            this.trainedWeights[i] = this.trainedWeights[i] + delta;
        }
        return bestError;
    }

    /**
     * Pausing is not supported by this trainer.
     *
     * @return always {@code null}
     */
    @Override // org.encogx.ml.train.MLTrain
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Resuming is not supported by this trainer; the call is ignored.
     *
     * @param trainingContinuation ignored
     */
    @Override // org.encogx.ml.train.MLTrain
    public void resume(final TrainingContinuation trainingContinuation) {
    }
}
