package com.rapidminer.timeseriesanalysis.forecast.arima.utils;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:com/rapidminer/timeseriesanalysis/forecast/arima/utils/KalmanFilter.class */
/**
 * Kalman filter over a state-space representation built from AR/MA coefficients.
 *
 * <p>The model has these parts:
 * <ul>
 *   <li>The state dimension is {@code m = max(p, q + 1)}.</li>
 *   <li>The transition matrix {@code Omega} holds the first {@code p} AR coefficients in its first
 *       column and ones on its super-diagonal.</li>
 *   <li>The observation vector is {@code h = (1, 0, ..., 0)}.</li>
 *   <li>The state disturbance covariance is {@code R = sigma^2 * g g'} with
 *       {@code g = (1, -maCoefficients[0], ..., -maCoefficients[q-1], 0, ...)}.</li>
 *   <li>The initial state covariance {@code S} solves {@code vec(S) = (I - Omega (x) Omega)^-1 vec(R)}.</li>
 * </ul>
 *
 * <p>Per observation {@code y}, call {@link #calculateFilter(double)}, optionally read
 * {@link #getResidual()} and {@link #getOneStepError()}, then call {@link #updateFilter()}.
 * {@code updateFilter()} uses the predicted state and covariance computed by
 * {@code calculateFilter}, so it must not be called first.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class KalmanFilter {
    private static final Logger LOGGER = Logger.getLogger(KalmanFilter.class.getName());

    /** State dimension {@code m = max(p, q + 1)}. */
    private final int stateDimension;
    /** Observation vector {@code h} as a 1 x m matrix. */
    private final RealMatrix observationRow;
    /** Observation vector {@code h = (1, 0, ..., 0)}. */
    private final RealVector observationVector;
    /** Transition matrix {@code Omega} (m x m). */
    private final RealMatrix transitionMatrix;
    /** State disturbance covariance {@code R = sigma^2 * g g'} (m x m). */
    private final RealMatrix disturbanceCovariance;
    /** Scale factor {@code sigma^2} passed to the constructor. */
    private final double sigmaSquare;

    /** Current (updated) state vector. */
    private RealVector state;
    /** State predicted from {@link #state} by {@link #calculateConditionalStateVector()}. */
    private RealVector predictedState;
    /** Current (updated) state covariance. */
    private RealMatrix stateCovariance;
    /** Covariance predicted by {@link #calculateConditionalCovarianceMatrix()}. */
    private RealMatrix predictedCovariance;
    /** Last residual computed by {@link #calculateResidual(double)}. */
    private double residual;
    /** Last value computed by {@link #calculateOneStepError()}: {@code h' P h / sigma^2}. */
    private double scaledPredictionVariance;

    /**
     * Builds the state-space matrices and the initial state covariance.
     *
     * @param p number of leading entries of {@code arCoefficients} placed in the first column of
     *     the transition matrix (presumably the AR order)
     * @param q number of entries of {@code maCoefficients} used in the disturbance loading vector
     *     (presumably the MA order)
     * @param arCoefficients coefficients for the transition matrix; needs at least {@code p}
     *     entries
     * @param maCoefficients coefficients that enter the loading vector negated; needs at least
     *     {@code q} entries. NOTE(review): confirm the caller's MA sign convention matches this.
     * @param sigmaSquare scale of the disturbance covariance; also divides the one-step error
     * @param initialState initial state vector, stored by reference (not copied); should have
     *     length {@code max(p, q + 1)}
     * @param unusedMatrix not read by this implementation
     * @throws ArrayIndexOutOfBoundsException if a coefficient array is shorter than required
     * @throws RuntimeException if the stationary covariance system cannot be solved (for example
     *     a singular-matrix exception from {@code MatrixUtils.inverse}); the exception is logged
     *     and rethrown unchanged
     */
    public KalmanFilter(int p, int q, double[] arCoefficients, double[] maCoefficients,
            double sigmaSquare, RealVector initialState, RealMatrix unusedMatrix) {
        this.stateDimension = Math.max(p, q + 1);
        this.state = initialState;
        this.sigmaSquare = sigmaSquare;

        double[] firstUnitVector = new double[stateDimension];
        firstUnitVector[0] = 1.0d;
        this.observationVector = MatrixUtils.createRealVector(firstUnitVector);
        this.observationRow = MatrixUtils.createRealMatrix(new double[][] {firstUnitVector});

        // AR and MA parts are filled in one pass so that out-of-range coefficient arrays fail
        // at the same index as before.
        double[][] transition = new double[stateDimension][stateDimension];
        double[] loading = new double[stateDimension];
        for (int row = 0; row < stateDimension; row++) {
            if (row < p) {
                transition[row][0] = arCoefficients[row];
            }
            if (row + 1 < stateDimension) {
                transition[row][row + 1] = 1.0d;
            }
            if (row == 0) {
                loading[row] = 1.0d;
            } else if (row - 1 < q) {
                loading[row] = -maCoefficients[row - 1];
            }
        }
        this.transitionMatrix = MatrixUtils.createRealMatrix(transition);

        RealVector loadingVector = MatrixUtils.createRealVector(loading);
        this.disturbanceCovariance = loadingVector.outerProduct(loadingVector.mapMultiply(sigmaSquare));

        double[] vecDisturbance = vectorize(disturbanceCovariance);
        try {
            this.stateCovariance = solveStationaryCovariance(vecDisturbance);
        } catch (RuntimeException e) {
            // Record which parameters made the system unsolvable, but leave handling of the
            // original exception (type and stack trace unchanged) to the caller.
            if (LOGGER.isLoggable(Level.WARNING)) {
                LOGGER.warning("Cannot compute initial state covariance for arCoefficients="
                        + Arrays.toString(arCoefficients) + ", maCoefficients="
                        + Arrays.toString(maCoefficients) + ", sigmaSquare=" + sigmaSquare
                        + ": " + e);
            }
            throw e;
        }
    }

    /** Stacks the columns of {@code matrix} into one array (column-major {@code vec}). */
    private static double[] vectorize(RealMatrix matrix) {
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        double[] stacked = new double[rows * cols];
        for (int col = 0; col < cols; col++) {
            for (int row = 0; row < rows; row++) {
                stacked[(col * rows) + row] = matrix.getEntry(row, col);
            }
        }
        return stacked;
    }

    /** Solves {@code vec(S) = (I - Omega (x) Omega)^-1 vec(R)} and reshapes the result to m x m. */
    private RealMatrix solveStationaryCovariance(double[] vecDisturbance) {
        int n = stateDimension;
        RealMatrix system = MatrixUtils.createRealIdentityMatrix(n * n)
                .subtract(getKroneckerProduct(transitionMatrix, transitionMatrix));
        RealVector vecCovariance = MatrixUtils.inverse(system)
                .operate(MatrixUtils.createRealVector(vecDisturbance));
        double[][] covariance = new double[n][n];
        for (int col = 0; col < n; col++) {
            for (int row = 0; row < n; row++) {
                covariance[row][col] = vecCovariance.getEntry((col * n) + row);
            }
        }
        return MatrixUtils.createRealMatrix(covariance);
    }

    /** Predicts the state: {@code x_{t|t-1} = Omega x_t}. */
    public void calculateConditionalStateVector() {
        this.predictedState = this.transitionMatrix.operate(this.state);
    }

    /** Predicts the covariance: {@code P = R + Omega (S Omega')}. */
    public void calculateConditionalCovarianceMatrix() {
        RealMatrix propagated = this.transitionMatrix.multiply(
                this.stateCovariance.multiply(this.transitionMatrix.transpose()));
        this.predictedCovariance = this.disturbanceCovariance.add(propagated);
    }

    /**
     * Computes the residual {@code e = y - h'(Omega x_t)}.
     *
     * <p>The prediction is recomputed from the current state instead of reading the cached
     * predicted state, so this method does not depend on call order.
     *
     * @param observation the observed value {@code y}
     */
    public void calculateResidual(double observation) {
        this.residual = observation
                - this.observationVector.dotProduct(this.transitionMatrix.operate(this.state));
    }

    /** Computes {@code v = h' P h / sigma^2} from the predicted covariance {@code P}. */
    public void calculateOneStepError() {
        this.scaledPredictionVariance = this.observationVector.dotProduct(
                this.predictedCovariance.operate(this.observationVector)) / this.sigmaSquare;
    }

    /**
     * Updates the state: {@code x_t = x_{t|t-1} + P h * e / (sigma^2 v)}.
     *
     * <p>If {@code sigma^2 * v} is zero, the correction is non-finite.
     */
    public void updateStateVector() {
        double gainScale = this.residual / (this.sigmaSquare * this.scaledPredictionVariance);
        RealVector correction = this.predictedCovariance.operate(this.observationVector.mapMultiply(gainScale));
        this.state = this.predictedState.add(correction);
    }

    /** Updates the covariance: {@code S = P - P (H' (H (P * (1 / (sigma^2 v)))))}. */
    public void updateCovarianceMatrix() {
        RealMatrix scaled = this.predictedCovariance.scalarMultiply(
                1.0d / (this.sigmaSquare * this.scaledPredictionVariance));
        RealMatrix correction = this.predictedCovariance.multiply(
                this.observationRow.transpose().multiply(this.observationRow.multiply(scaled)));
        this.stateCovariance = this.predictedCovariance.subtract(correction);
    }

    /**
     * Runs the prediction step for one observation: predicted state and covariance, then the
     * residual and the one-step error.
     *
     * @param observation the observed value
     */
    public void calculateFilter(double observation) {
        calculateConditionalStateVector();
        calculateConditionalCovarianceMatrix();
        calculateResidual(observation);
        calculateOneStepError();
    }

    /** Runs the update step using values from the last {@link #calculateFilter(double)} call. */
    public void updateFilter() {
        updateStateVector();
        updateCovarianceMatrix();
    }

    /** Returns the residual from the last {@link #calculateResidual(double)} call. */
    public double getResidual() {
        return this.residual;
    }

    /** Returns {@code h' P h / sigma^2} from the last {@link #calculateOneStepError()} call. */
    public double getOneStepError() {
        return this.scaledPredictionVariance;
    }

    /** Returns the Kronecker product {@code left (x) right}. */
    private static RealMatrix getKroneckerProduct(RealMatrix left, RealMatrix right) {
        int leftRows = left.getRowDimension();
        int leftCols = left.getColumnDimension();
        int rightRows = right.getRowDimension();
        int rightCols = right.getColumnDimension();
        RealMatrix product = MatrixUtils.createRealMatrix(leftRows * rightRows, leftCols * rightCols);
        for (int row = 0; row < leftRows; row++) {
            for (int col = 0; col < leftCols; col++) {
                double[][] tile = right.scalarMultiply(left.getEntry(row, col)).getData();
                product.setSubMatrix(tile, row * rightRows, col * rightCols);
            }
        }
        return product;
    }
}
