package org.encogx.ml.hmm.train.bw;

import java.util.EnumSet;
import java.util.Iterator;
import org.encogx.ml.data.MLDataPair;
import org.encogx.ml.data.MLDataSet;
import org.encogx.ml.data.MLSequenceSet;
import org.encogx.ml.hmm.HiddenMarkovModel;
import org.encogx.ml.hmm.alog.ForwardBackwardCalculator;
import org.encogx.ml.hmm.alog.ForwardBackwardScaledCalculator;

/* loaded from: input_file:org/encogx/ml/hmm/train/bw/TrainBaumWelchScaled.class */
/**
 * Baum-Welch trainer for a {@link HiddenMarkovModel} whose forward-backward
 * quantities come from a {@link ForwardBackwardScaledCalculator}.
 *
 * <p>The generic re-estimation loop lives in {@link BaseBaumWelch}. This subclass
 * only supplies the scaled forward-backward calculator and the matching xi
 * (transition posterior) estimate.
 */
public class TrainBaumWelchScaled extends BaseBaumWelch {

    /**
     * Creates a scaled Baum-Welch trainer.
     *
     * @param hiddenMarkovModel the model to train; handed straight to {@link BaseBaumWelch}
     * @param mLSequenceSet     the training sequences; handed straight to {@link BaseBaumWelch}
     */
    public TrainBaumWelchScaled(HiddenMarkovModel hiddenMarkovModel, MLSequenceSet mLSequenceSet) {
        super(hiddenMarkovModel, mLSequenceSet);
    }

    /**
     * Builds the xi table for one observation sequence.
     *
     * <p>Entry {@code [t][from][to]} is the product
     * {@code alpha(t, from) * P(from -> to) * P(observation[t + 1] | to) * beta(t + 1, to)},
     * taken from the supplied calculator and model. No normalization is applied
     * here. Presumably the scaled alpha/beta values already make these entries
     * conditional probabilities; confirm against {@link ForwardBackwardScaledCalculator}.
     *
     * @param mLDataSet                 the observation sequence; must contain at least two items
     * @param forwardBackwardCalculator source of the alpha and beta values for this sequence
     * @param hiddenMarkovModel         model providing transition probabilities and state distributions
     * @return an array of shape {@code [size - 1][stateCount][stateCount]}
     * @throws IllegalArgumentException if the sequence has fewer than two observations
     */
    @Override
    public double[][][] estimateXi(MLDataSet mLDataSet, ForwardBackwardCalculator forwardBackwardCalculator, HiddenMarkovModel hiddenMarkovModel) {
        final int sequenceLength = mLDataSet.size();
        if (sequenceLength < 2) {
            throw new IllegalArgumentException("Must have more than one observation");
        }

        final int stateCount = hiddenMarkovModel.getStateCount();
        final int transitionCount = sequenceLength - 1;
        final double[][][] xi = new double[transitionCount][stateCount][stateCount];

        final Iterator<MLDataPair> observations = mLDataSet.iterator();
        // Slice t pairs with the observation at position t + 1, so discard the first one up front.
        observations.next();

        int t = 0;
        while (t < transitionCount) {
            final MLDataPair upcoming = observations.next();
            final double[][] slice = xi[t];
            for (int from = 0; from < stateCount; from++) {
                final double alpha = forwardBackwardCalculator.alphaElement(t, from);
                final double[] row = slice[from];
                for (int to = 0; to < stateCount; to++) {
                    final double transition = hiddenMarkovModel.getTransitionProbability(from, to);
                    final double emission = hiddenMarkovModel.getStateDistribution(to).probability(upcoming);
                    final double beta = forwardBackwardCalculator.betaElement(t + 1, to);
                    // Same left-to-right association as ((alpha * transition) * emission) * beta.
                    row[to] = alpha * transition * emission * beta;
                }
            }
            t++;
        }
        return xi;
    }

    /**
     * Creates the forward-backward calculator used by this trainer.
     *
     * @param mLDataSet         the observation sequence to evaluate
     * @param hiddenMarkovModel the model to evaluate it against
     * @return a {@link ForwardBackwardScaledCalculator} with every
     *         {@link ForwardBackwardCalculator.Computation} flag requested
     */
    @Override
    public ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet mLDataSet, HiddenMarkovModel hiddenMarkovModel) {
        final EnumSet<ForwardBackwardCalculator.Computation> allComputations =
                EnumSet.allOf(ForwardBackwardCalculator.Computation.class);
        return new ForwardBackwardScaledCalculator(mLDataSet, hiddenMarkovModel, allComputations);
    }
}
