package org.encogx.ml.bayesian.training;

import org.encogx.ml.MLMethod;
import org.encogx.ml.TrainingImplementationType;
import org.encogx.ml.bayesian.BayesianEvent;
import org.encogx.ml.bayesian.BayesianNetwork;
import org.encogx.ml.bayesian.training.estimator.BayesEstimator;
import org.encogx.ml.bayesian.training.estimator.SimpleEstimator;
import org.encogx.ml.bayesian.training.search.k2.BayesSearch;
import org.encogx.ml.bayesian.training.search.k2.SearchK2;
import org.encogx.ml.data.MLDataSet;
import org.encogx.ml.train.BasicTraining;
import org.encogx.neural.networks.training.propagation.TrainingContinuation;

/* loaded from: input_file:org/encogx/ml/bayesian/training/TrainBayesian.class */
/**
 * Iterative trainer for a {@link BayesianNetwork}.
 *
 * <p>Training is driven as a small state machine. Every call to
 * {@link #iteration()} performs one step of the current {@link Phase} and,
 * when that phase is complete, advances to the next one:
 * {@code Init -> Search -> SearchDone -> Probability -> Finish -> Terminated}.
 * Once {@code Terminated} is reached {@link #isTrainingDone()} reports
 * {@code true}.
 *
 * <p>This trainer does not support pause/resume: {@link #canContinue()} always
 * returns {@code false}.
 */
public class TrainBayesian extends BasicTraining {
    /** Phase the state machine is currently in. */
    private Phase p = Phase.Init;

    /** Training data handed to the search, the estimator and the final error calculation. */
    private final MLDataSet data;

    /** Network being trained. */
    private final BayesianNetwork network;

    /** Maximum-parents value supplied at construction; exposed via {@link #getMaximumParents()}. */
    private final int maximumParents;

    /** Structure search that is stepped during {@link Phase#Search}. */
    private final BayesSearch search;

    /** Estimator that is stepped during {@link Phase#Probability}. */
    private final BayesEstimator estimator;

    /**
     * How the network structure is prepared in {@link Phase#Init}. Holds
     * {@code InitNaiveBayes} until the constructor stores the caller's choice,
     * which happens only after the search and estimator have been initialised.
     */
    private BayesianInit initNetwork = BayesianInit.InitNaiveBayes;

    /**
     * Classification structure read from the network at the start of training
     * and handed back to it in {@link Phase#Finish}.
     */
    private String holdQuery;

    /** Steps of the training state machine, listed in the order they are visited. */
    public enum Phase {
        Init,
        Search,
        SearchDone,
        Probability,
        Finish,
        Terminated;

        /**
         * Returns a freshly allocated array holding every phase in declaration
         * order. Retained so existing callers of this name keep working;
         * {@link #values()} already returns a new array on each call.
         *
         * @return all phases
         */
        public static Phase[] valuesCustom() {
            return values();
        }
    }

    /**
     * Creates a trainer that uses {@link BayesianInit#InitNaiveBayes}, a new
     * {@link SearchK2} and a new {@link SimpleEstimator}.
     *
     * @param bayesianNetwork the network to train
     * @param mLDataSet the training data
     * @param i the maximum-parents value reported by {@link #getMaximumParents()}
     */
    public TrainBayesian(BayesianNetwork bayesianNetwork, MLDataSet mLDataSet, int i) {
        this(bayesianNetwork, mLDataSet, i, BayesianInit.InitNaiveBayes, new SearchK2(), new SimpleEstimator());
    }

    /**
     * Creates a trainer with an explicit initialisation mode, search and estimator.
     * Both the search and the estimator are initialised here with this trainer,
     * the network and the data. The initial error is set to {@code 1.0}.
     *
     * @param bayesianNetwork the network to train
     * @param mLDataSet the training data
     * @param i the maximum-parents value reported by {@link #getMaximumParents()}
     * @param bayesianInit how the network structure is prepared before searching
     * @param bayesSearch the structure search to step during the search phase
     * @param bayesEstimator the estimator to step during the probability phase
     */
    public TrainBayesian(BayesianNetwork bayesianNetwork, MLDataSet mLDataSet, int i, BayesianInit bayesianInit, BayesSearch bayesSearch, BayesEstimator bayesEstimator) {
        super(TrainingImplementationType.Iterative);
        this.network = bayesianNetwork;
        this.data = mLDataSet;
        this.maximumParents = i;
        this.search = bayesSearch;
        this.search.init(this, bayesianNetwork, mLDataSet);
        this.estimator = bayesEstimator;
        this.estimator.init(this, bayesianNetwork, mLDataSet);
        // Stored only after both init calls, so they observe the default mode.
        this.initNetwork = bayesianInit;
        setError(1.0d);
    }

    /**
     * Replaces the current structure with one in which the classification
     * target event has a dependency to every other event, then finalises the
     * structure.
     */
    private void initNaiveBayes() {
        this.network.removeAllRelations();
        BayesianEvent classificationTargetEvent = this.network.getClassificationTargetEvent();
        for (BayesianEvent bayesianEvent : this.network.getEvents()) {
            if (bayesianEvent != classificationTargetEvent) {
                this.network.createDependency(classificationTargetEvent, bayesianEvent);
            }
        }
        this.network.finalizeStructure();
    }

    /**
     * {@link Phase#Init}: remembers the network's classification structure,
     * prepares the network according to {@link #getInitNetwork()} and moves on
     * to {@link Phase#Search}.
     */
    private void iterationInit() {
        this.holdQuery = this.network.getClassificationStructure();
        switch (this.initNetwork) {
            case InitEmpty:
                this.network.removeAllRelations();
                this.network.finalizeStructure();
                break;
            case InitNaiveBayes:
                initNaiveBayes();
                break;
            case InitNoChange:
            default:
                // Leave the existing structure untouched.
                break;
        }
        this.p = Phase.Search;
    }

    /**
     * {@link Phase#Search}: runs one search iteration and moves on to
     * {@link Phase#SearchDone} once the search returns {@code false}
     * (presumably meaning it has no further work; confirm against
     * {@code BayesSearch}).
     */
    private void iterationSearch() {
        if (!this.search.iteration()) {
            this.p = Phase.SearchDone;
        }
    }

    /**
     * {@link Phase#SearchDone}: finalises the network structure, resets the
     * network and moves on to {@link Phase#Probability}.
     */
    private void iterationSearchDone() {
        this.network.finalizeStructure();
        this.network.reset();
        this.p = Phase.Probability;
    }

    /**
     * {@link Phase#Probability}: runs one estimator iteration and moves on to
     * {@link Phase#Finish} once the estimator returns {@code false}
     * (presumably meaning it has no further work; confirm against
     * {@code BayesEstimator}).
     */
    private void iterationProbability() {
        if (!this.estimator.iteration()) {
            this.p = Phase.Finish;
        }
    }

    /**
     * {@link Phase#Finish}: hands the remembered classification structure back
     * to the network, records the network's error on the training data and
     * moves on to {@link Phase#Terminated}.
     */
    private void iterationFinish() {
        this.network.defineClassificationStructure(this.holdQuery);
        setError(this.network.calculateError(this.data));
        this.p = Phase.Terminated;
    }

    /**
     * @return {@code true} when the superclass reports training as done or the
     *         state machine has reached {@link Phase#Terminated}
     */
    @Override // org.encogx.ml.train.BasicTraining, org.encogx.ml.train.MLTrain
    public boolean isTrainingDone() {
        return super.isTrainingDone() || this.p == Phase.Terminated;
    }

    /**
     * Performs one step of the current phase, bracketed by the inherited
     * {@code preIteration()} and {@code postIteration()} hooks. In
     * {@link Phase#Terminated} only the hooks run.
     */
    @Override // org.encogx.ml.train.MLTrain
    public void iteration() {
        // The "pre" hook must run before the work and the "post" hook after it.
        preIteration();
        switch (this.p) {
            case Init:
                iterationInit();
                break;
            case Search:
                iterationSearch();
                break;
            case SearchDone:
                iterationSearchDone();
                break;
            case Probability:
                iterationProbability();
                break;
            case Finish:
                iterationFinish();
                break;
            case Terminated:
            default:
                // Nothing left to do.
                break;
        }
        postIteration();
    }

    /**
     * @return always {@code false}; this trainer cannot be paused and resumed
     */
    @Override // org.encogx.ml.train.MLTrain
    public boolean canContinue() {
        return false;
    }

    /**
     * Not supported by this trainer.
     *
     * @return always {@code null}
     */
    @Override // org.encogx.ml.train.MLTrain
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Not supported by this trainer; the call is ignored.
     *
     * @param trainingContinuation ignored
     */
    @Override // org.encogx.ml.train.MLTrain
    public void resume(TrainingContinuation trainingContinuation) {
    }

    /**
     * @return the network being trained
     */
    @Override // org.encogx.ml.train.MLTrain
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return the network being trained
     */
    public BayesianNetwork getNetwork() {
        return this.network;
    }

    /**
     * @return the maximum-parents value supplied at construction
     */
    public int getMaximumParents() {
        return this.maximumParents;
    }

    /**
     * @return the structure search used by this trainer
     */
    public BayesSearch getSearch() {
        return this.search;
    }

    /**
     * @return how the network structure is prepared in the init phase
     */
    public BayesianInit getInitNetwork() {
        return this.initNetwork;
    }

    /**
     * Changes how the network structure is prepared. The value is read when the
     * {@link Phase#Init} step runs, so changing it afterwards has no effect on
     * the current training run.
     *
     * @param bayesianInit the initialisation mode to use
     */
    public void setInitNetwork(BayesianInit bayesianInit) {
        this.initNetwork = bayesianInit;
    }
}
