package org.encogx.ensemble;

import java.util.Objects;
import org.encogx.ensemble.data.EnsembleDataSet;
import org.encogx.ml.MLMethod;
import org.encogx.ml.data.MLData;
import org.encogx.ml.train.MLTrain;
import org.encogx.neural.networks.BasicNetwork;
import org.encogx.util.EngineArray;

/* loaded from: input_file:org/encogx/ensemble/GenericEnsembleML.class */
/**
 * An {@link EnsembleML} member that wraps a single {@link BasicNetwork} together with the trainer
 * used to fit it and a descriptive label.
 *
 * <p>Prediction calls ({@link #classify}, {@link #compute}, {@link #getInputCount},
 * {@link #getOutputCount}, {@link #getError}) are delegated directly to the wrapped network.
 * Training calls ({@link #train}, {@link #trainStep}) use the trainer supplied through
 * {@link #setTraining(MLTrain)}.
 */
public class GenericEnsembleML implements EnsembleML {

    /** Iteration cap used by {@link #train(double, boolean)} (historical value: 2000). */
    public static final int DEFAULT_MAX_ITERATIONS = 2000;

    /**
     * Training stops early when the error grows by at least this amount between two consecutive
     * iterations, i.e. when {@code previousError - error <= -DIVERGENCE_THRESHOLD}.
     */
    private static final double DIVERGENCE_THRESHOLD = 0.1d;

    private EnsembleDataSet trainingSet;
    private BasicNetwork ml;
    private MLTrain trainer;
    private final String label;

    /**
     * Creates an ensemble member wrapping {@code method}.
     *
     * @param method the model to wrap; must be a {@link BasicNetwork} (it is stored via
     *     {@link #setMl(MLMethod)})
     * @param label label returned by {@link #getLabel()}
     * @throws ClassCastException if {@code method} is not a {@link BasicNetwork}
     */
    public GenericEnsembleML(MLMethod method, String label) {
        setMl(method);
        this.label = label;
    }

    /**
     * Stores the data set associated with this member. The value is only kept for
     * {@link #getTrainingSet()}; {@link #train} uses the configured trainer, not this field.
     */
    @Override // org.encogx.ensemble.EnsembleML
    public void setTrainingSet(EnsembleDataSet dataSet) {
        this.trainingSet = dataSet;
    }

    /** Returns the data set last passed to {@link #setTrainingSet}, or {@code null} if none. */
    @Override // org.encogx.ensemble.EnsembleML
    public EnsembleDataSet getTrainingSet() {
        return this.trainingSet;
    }

    /**
     * Trains the wrapped network for at most {@link #DEFAULT_MAX_ITERATIONS} iterations.
     *
     * @param targetError error at or below which training stops
     * @param verbose if {@code true}, prints the iteration number and error after each iteration
     * @throws NullPointerException if no trainer has been set via {@link #setTraining(MLTrain)}
     * @see #train(double, boolean, int)
     */
    @Override // org.encogx.ensemble.EnsembleML
    public void train(double targetError, boolean verbose) {
        train(targetError, verbose, DEFAULT_MAX_ITERATIONS);
    }

    /**
     * Runs trainer iterations until a stop condition holds, then calls
     * {@code finishTraining()} on the trainer.
     *
     * <p>These conditions are checked after every iteration, in this order:
     * <ol>
     *   <li>the trainer's error is {@code <= targetError};</li>
     *   <li>the trainer reports that it cannot continue;</li>
     *   <li>the error increased by at least {@code 0.1} compared with the previous iteration
     *       (not checked after the first iteration);</li>
     *   <li>{@code maxIterations} iterations have been run.</li>
     * </ol>
     * At least one iteration is always performed.
     *
     * @param targetError error at or below which training stops
     * @param verbose if {@code true}, prints {@code "<iteration> <error>"} to standard output after
     *     each iteration
     * @param maxIterations upper bound on the number of iterations; must be at least 1
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     * @throws NullPointerException if no trainer has been set via {@link #setTraining(MLTrain)}
     */
    public void train(double targetError, boolean verbose, int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("maxIterations must be >= 1, was " + maxIterations);
        }
        final MLTrain activeTrainer = requireTrainer();
        double error = 0.0d;
        double errorDelta = 0.0d;
        int iteration = 0;
        do {
            activeTrainer.iteration();
            iteration++;
            final double previousError = error;
            error = activeTrainer.getError();
            // There is no previous error to compare against on the first iteration.
            if (iteration > 1) {
                errorDelta = previousError - error;
            }
            if (verbose) {
                System.out.println(iteration + " " + error);
            }
            // Written as an explicit break (not a while-condition) so that a NaN error behaves
            // exactly as before: NaN fails every comparison and training keeps going.
            if (error <= targetError
                    || !activeTrainer.canContinue()
                    || errorDelta <= -DIVERGENCE_THRESHOLD) {
                break;
            }
        } while (iteration < maxIterations);
        activeTrainer.finishTraining();
    }

    /**
     * Replaces the wrapped model.
     *
     * @param method the new model; must be a {@link BasicNetwork} (or {@code null})
     * @throws ClassCastException if {@code method} is not a {@link BasicNetwork}
     */
    @Override // org.encogx.ensemble.EnsembleML
    public void setMl(MLMethod method) {
        this.ml = (BasicNetwork) method;
    }

    /** Returns the wrapped network. */
    @Override // org.encogx.ensemble.EnsembleML
    public MLMethod getMl() {
        return this.ml;
    }

    /** Delegates to the wrapped network's {@code classify}. */
    @Override // org.encogx.ml.MLClassification
    public int classify(MLData input) {
        return this.ml.classify(input);
    }

    /** Delegates to the wrapped network's {@code compute}. */
    @Override // org.encogx.ml.MLRegression
    public MLData compute(MLData input) {
        return this.ml.compute(input);
    }

    /** Returns the wrapped network's input count. */
    @Override // org.encogx.ml.MLInput
    public int getInputCount() {
        return this.ml.getInputCount();
    }

    /** Returns the wrapped network's output count. */
    @Override // org.encogx.ml.MLOutput
    public int getOutputCount() {
        return this.ml.getOutputCount();
    }

    /**
     * Trains quietly (no console output); equivalent to {@code train(targetError, false)}.
     *
     * @param targetError error at or below which training stops
     */
    @Override // org.encogx.ensemble.EnsembleML
    public void train(double targetError) {
        train(targetError, false);
    }

    /**
     * Returns the index of the largest element of {@code output.getData()}, as computed by
     * {@link EngineArray#maxIndex}.
     *
     * @param output model output to inspect
     * @return index of the maximum element
     */
    public int winner(MLData output) {
        return EngineArray.maxIndex(output.getData());
    }

    /** Sets the trainer used by {@link #train} and {@link #trainStep()}. */
    @Override // org.encogx.ensemble.EnsembleML
    public void setTraining(MLTrain training) {
        this.trainer = training;
    }

    /** Returns the configured trainer, or {@code null} if none has been set. */
    @Override // org.encogx.ensemble.EnsembleML
    public MLTrain getTraining() {
        return this.trainer;
    }

    /**
     * Runs a single trainer iteration, without calling {@code finishTraining()}.
     *
     * @throws NullPointerException if no trainer has been set via {@link #setTraining(MLTrain)}
     */
    @Override // org.encogx.ensemble.EnsembleML
    public void trainStep() {
        requireTrainer().iteration();
    }

    /** Returns the label given at construction time. */
    @Override // org.encogx.ensemble.EnsembleML
    public String getLabel() {
        return this.label;
    }

    /**
     * Delegates to the wrapped network's {@code calculateError} on {@code dataSet}.
     * NOTE(review): presumably {@link EnsembleDataSet} is an {@code MLDataSet}; confirm.
     *
     * @param dataSet data to evaluate the network against
     * @return the error reported by the network
     */
    @Override // org.encogx.ensemble.EnsembleML
    public double getError(EnsembleDataSet dataSet) {
        return this.ml.calculateError(dataSet);
    }

    /** Returns the configured trainer, failing with a descriptive NPE if none was set. */
    private MLTrain requireTrainer() {
        return Objects.requireNonNull(
                this.trainer, "No trainer set; call setTraining(MLTrain) before training");
    }
}
