package com.rapidminer.timeseriesanalysis.forecast.arima;

import com.github.lbfgs4j.LbfgsMinimizer;
import com.github.lbfgs4j.liblbfgs.Function;
import com.rapidminer.timeseriesanalysis.datamodel.TimeSeries;
import com.rapidminer.timeseriesanalysis.datamodel.ValueSeries;
import com.rapidminer.timeseriesanalysis.forecast.TimeSeriesForecast;
import com.rapidminer.timeseriesanalysis.forecast.TimeSeriesForecastTrainer;
import com.rapidminer.timeseriesanalysis.forecast.ValueSeriesForecast;
import com.rapidminer.timeseriesanalysis.forecast.ValueSeriesForecastTrainer;
import com.rapidminer.timeseriesanalysis.forecast.arima.utils.ArimaUtils;
import com.rapidminer.timeseriesanalysis.forecast.arima.utils.ArmaLogLikelihood;
import com.rapidminer.timeseriesanalysis.forecast.arima.utils.HannanRissanen;
import com.rapidminer.timeseriesanalysis.forecast.arima.utils.YuleWalker;
import com.rapidminer.timeseriesanalysis.forecast.modelevaluation.AkaikesInformationCriterion;
import com.rapidminer.timeseriesanalysis.forecast.modelevaluation.BayesianInformationCriterion;
import com.rapidminer.timeseriesanalysis.forecast.modelevaluation.CorrectedAkaikesInformationCriterion;
import com.rapidminer.timeseriesanalysis.methods.transformation.Differentiation;
import com.rapidminer.timeseriesanalysis.tools.SeriesUtils;
import java.security.InvalidParameterException;
import java.util.Arrays;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.SimplePointChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionMappingAdapter;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.util.Pair;

/* loaded from: input_file:com/rapidminer/timeseriesanalysis/forecast/arima/ArimaTrainer.class */
public class ArimaTrainer implements TimeSeriesForecastTrainer, ValueSeriesForecastTrainer {
    private int p;
    private int d;
    private int q;
    private boolean estimateConstant;
    private boolean transformParams;
    private ArimaUtils.TrainingAlgorithm trainingAlgorithm;
    private ArimaUtils.OptimizationMethod optimizationMethod;
    private int maxNumberOfIterations;
    private boolean useRegressionForBOBYQAParameters;
    private double[] parametersForOptimization;
    private double[] initialParameters;
    private boolean calculateStartParameters;
    private double[] finalParameters;
    private double finalLogLikelihood;
    private double finalAicValue;
    private double finalBicValue;
    private double finalCorrectedAicValue;

    private ArimaTrainer(int i, int i2, int i3, boolean z, boolean z2, ArimaUtils.TrainingAlgorithm trainingAlgorithm, ArimaUtils.OptimizationMethod optimizationMethod, int i4, boolean z3) {
        if (i < 0) {
            throw new InvalidParameterException("p has to be positive or 0. p: " + i);
        }
        if (i2 < 0) {
            throw new InvalidParameterException("d has to be positive or 0. d: " + i2);
        }
        if (i3 < 0) {
            throw new InvalidParameterException("q has to be positive or 0. q: " + i3);
        }
        if (i4 <= 0) {
            throw new InvalidParameterException("maxNumberOfIterations has to be positive. maxNumberOfIterations: " + i4);
        }
        if (i + i3 <= 0) {
            throw new InvalidParameterException("At least one AR or MA term has to provided.");
        }
        this.p = i;
        this.d = i2;
        this.q = i3;
        this.estimateConstant = z;
        this.transformParams = z2;
        this.trainingAlgorithm = trainingAlgorithm;
        this.maxNumberOfIterations = i4;
        this.optimizationMethod = optimizationMethod;
        this.useRegressionForBOBYQAParameters = z3;
        checkAndSetOptimizationParameters(getDefaultOptimizationParameters(optimizationMethod));
        this.calculateStartParameters = true;
        this.initialParameters = null;
    }

    public static ArimaTrainer create(int i, int i2, int i3) {
        return new ArimaTrainer(i, i2, i3, true, true, ArimaUtils.TrainingAlgorithm.CONDITIONAL_MAX_LOGLIKELIHOOD, ArimaUtils.OptimizationMethod.LBFGS, 1000, false);
    }

    public static ArimaTrainer create(int i, int i2, int i3, boolean z, Boolean bool, ArimaUtils.TrainingAlgorithm trainingAlgorithm, ArimaUtils.OptimizationMethod optimizationMethod, int i4, boolean z2) {
        return new ArimaTrainer(i, i2, i3, z, bool.booleanValue(), trainingAlgorithm, optimizationMethod, i4, z2);
    }

    @Override // com.rapidminer.timeseriesanalysis.forecast.ValueSeriesForecastTrainer
    public ValueSeriesForecast trainForecast(ValueSeries valueSeries) {
        if (valueSeries != null) {
            return trainArima(valueSeries.getValues());
        }
        throw new InvalidParameterException("The provided valueSeries object is null.");
    }

    @Override // com.rapidminer.timeseriesanalysis.forecast.TimeSeriesForecastTrainer
    public TimeSeriesForecast trainForecast(TimeSeries timeSeries) {
        if (timeSeries != null) {
            return trainArima(timeSeries.getValues());
        }
        throw new InvalidParameterException("The provided timeSeries object is null.");
    }

    public Arima trainArima(double[] dArr) {
        validateAllowedParameters(dArr.length);
        if (SeriesUtils.hasNaNValues(dArr)) {
            throw new RuntimeException("The provided Series contains missing Values, which are not supported for the application of ARIMA.");
        }
        if (SeriesUtils.hasInfiniteValues(dArr)) {
            throw new RuntimeException("The provided Series contains infinite Values, which are not supported for the application of ARIMA.");
        }
        double[] diff = Differentiation.diff(dArr, this.d);
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.HANNAN_RISSANEN && this.d == 0) {
            HannanRissanen hannanRissanen = new HannanRissanen(this.p, this.q, this.estimateConstant, diff, this.maxNumberOfIterations, 15);
            Arima trainArima = hannanRissanen.trainArima();
            this.finalParameters = hannanRissanen.getParameters(false);
            return trainArima;
        }
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.HANNAN_RISSANEN && this.d != 0) {
            throw new RuntimeException("Hannan-Rissanen algorithm for an ARIMA process with d != 0 is not supported.");
        }
        if (this.calculateStartParameters) {
            if (this.p == 0 && this.q > 0) {
                this.initialParameters = YuleWalker.create(this.q, diff, this.estimateConstant).computeCoefficients();
            } else if (this.p <= 0 || this.q != 0) {
                HannanRissanen hannanRissanen2 = new HannanRissanen(this.p, this.q, this.estimateConstant, diff);
                hannanRissanen2.trainArima();
                if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.EXACT_MAX_LOGLIKELIHOOD) {
                    this.initialParameters = hannanRissanen2.getParameters(true);
                } else {
                    this.initialParameters = hannanRissanen2.getParameters(false);
                }
            } else {
                this.initialParameters = YuleWalker.create(this.p, diff, this.estimateConstant).computeCoefficients();
            }
        }
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.CONDITIONAL_MAX_LOGLIKELIHOOD || this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.CONDITIONAL_THEN_EXACT_MAX_LOGLIKELIHOOD) {
            ArmaLogLikelihood armaLogLikelihood = new ArmaLogLikelihood(this.p, this.q, this.estimateConstant, this.transformParams, ArimaUtils.ArimaLogLikelihoodType.CONDITIONAL, diff);
            double[] dArr2 = this.initialParameters;
            if (this.transformParams) {
                dArr2 = ArimaUtils.inverseTransformParams(dArr2, this.p, this.q, this.estimateConstant);
            }
            double[] performOptimization = performOptimization(diff, armaLogLikelihood, false, dArr2);
            if (this.transformParams) {
                performOptimization = ArimaUtils.transformParams(performOptimization, this.p, this.q, this.estimateConstant);
            }
            if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.CONDITIONAL_MAX_LOGLIKELIHOOD) {
                armaLogLikelihood.setTransParams(false);
                return prepareFinalArima(performOptimization, armaLogLikelihood, diff.length);
            }
            if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.CONDITIONAL_THEN_EXACT_MAX_LOGLIKELIHOOD) {
                this.initialParameters = new double[performOptimization.length + 1];
                int i = 0;
                while (i < performOptimization.length) {
                    this.initialParameters[i] = performOptimization[i];
                    i++;
                }
                this.initialParameters[i] = armaLogLikelihood.getSigmaSquare();
            }
        }
        if (this.trainingAlgorithm != ArimaUtils.TrainingAlgorithm.EXACT_MAX_LOGLIKELIHOOD && this.trainingAlgorithm != ArimaUtils.TrainingAlgorithm.CONDITIONAL_THEN_EXACT_MAX_LOGLIKELIHOOD) {
            throw new RuntimeException("Provided TrainingAlgorithm is not one of: " + Arrays.toString(ArimaUtils.TrainingAlgorithm.values()));
        }
        ArmaLogLikelihood armaLogLikelihood2 = new ArmaLogLikelihood(this.p, this.q, this.estimateConstant, this.transformParams, ArimaUtils.ArimaLogLikelihoodType.EXACT, diff);
        double[] dArr3 = this.initialParameters;
        if (this.transformParams) {
            dArr3 = ArimaUtils.inverseTransformParams(dArr3, this.p, this.q, this.estimateConstant);
        }
        double[] performOptimization2 = performOptimization(diff, armaLogLikelihood2, true, dArr3);
        if (this.transformParams) {
            performOptimization2 = ArimaUtils.transformParams(performOptimization2, this.p, this.q, this.estimateConstant);
        }
        armaLogLikelihood2.setTransParams(false);
        return prepareFinalArima(performOptimization2, armaLogLikelihood2, diff.length);
    }

    public Pair<Boolean, String> validateNumberOfParameters(int i) {
        int i2 = this.p + this.d + this.q;
        if (this.estimateConstant) {
            i2++;
        }
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.EXACT_MAX_LOGLIKELIHOOD) {
            i2++;
        }
        return i2 > i - 3 ? new Pair<>(false, "The number of parameters exceeds the allowed number for this Series. Number of parameters: " + i2 + ",\t allowed number of parameters (series.length - 3): " + (i - 3)) : new Pair<>(true, "");
    }

    public Pair<Boolean, String> validateAllowedParametersForHannanRissanen(int i) {
        if (this.calculateStartParameters) {
            int round = (int) Math.round(12.0d * Math.pow(i / 100.0f, 0.25d));
            int max = Math.max(round + this.q, this.p);
            if (max >= i - this.d) {
                return new Pair<>(false, "The given parameters p,d,q are not valid for the given series to apply HannanRissanen (HR) to estimate Start Parameters.\nthe condition ystart &lt; (length - d) is not fulfilled:\nlength: " + i + ",\t d: " + this.d + ",\t q: " + this.q + ",\t p: " + this.p + "\n ystart = Math.max(maxOrderOfInitialARProcess + q, p) = " + max + "\t, maxOrderOfInitialARProcess of HR =  " + round);
            }
        }
        return new Pair<>(true, "");
    }

    public void validateAllowedParameters(int i) {
        Pair<Boolean, String> validateNumberOfParameters = validateNumberOfParameters(i);
        if (!((Boolean) validateNumberOfParameters.getFirst()).booleanValue()) {
            throw new InvalidParameterException((String) validateNumberOfParameters.getSecond());
        }
        Pair<Boolean, String> validateAllowedParametersForHannanRissanen = validateAllowedParametersForHannanRissanen(i);
        if (!((Boolean) validateAllowedParametersForHannanRissanen.getFirst()).booleanValue()) {
            throw new InvalidParameterException((String) validateAllowedParametersForHannanRissanen.getSecond());
        }
    }

    private Arima prepareFinalArima(double[] dArr, ArmaLogLikelihood armaLogLikelihood, int i) {
        AkaikesInformationCriterion akaikesInformationCriterion = new AkaikesInformationCriterion();
        BayesianInformationCriterion bayesianInformationCriterion = new BayesianInformationCriterion();
        CorrectedAkaikesInformationCriterion correctedAkaikesInformationCriterion = new CorrectedAkaikesInformationCriterion();
        this.finalParameters = dArr;
        this.finalLogLikelihood = armaLogLikelihood.value(dArr);
        double[] arCoefficients = armaLogLikelihood.getArCoefficients();
        double[] maCoefficients = armaLogLikelihood.getMaCoefficients();
        double constant = this.estimateConstant ? armaLogLikelihood.getConstant() : 0.0d;
        double[] residuals = armaLogLikelihood.getResiduals();
        double[] dArr2 = new double[this.q];
        for (int i2 = 0; i2 < this.q; i2++) {
            dArr2[i2] = residuals[(residuals.length - this.q) + i2];
        }
        this.finalAicValue = akaikesInformationCriterion.compute(this.finalLogLikelihood, armaLogLikelihood.getNumberOfParameters(), i);
        this.finalBicValue = bayesianInformationCriterion.compute(this.finalLogLikelihood, armaLogLikelihood.getNumberOfParameters(), i);
        this.finalCorrectedAicValue = correctedAkaikesInformationCriterion.compute(this.finalLogLikelihood, armaLogLikelihood.getNumberOfParameters(), i);
        return Arima.create(this.p, this.d, this.q, arCoefficients, maCoefficients, constant, dArr2);
    }

    private double[] performOptimization(double[] dArr, MultivariateFunction multivariateFunction, boolean z, double[] dArr2) {
        OptimizationData initialGuess = new InitialGuess(dArr2);
        OptimizationData createSimpleBounds = createSimpleBounds(z);
        if (this.parametersForOptimization == null) {
            throw new RuntimeException("useDefaultParametersForOptimization is false, but no optimizationParameters array was provided");
        }
        switch (this.optimizationMethod) {
            case BOBYQA:
                double d = this.parametersForOptimization[0];
                double d2 = this.parametersForOptimization[1];
                double d3 = this.parametersForOptimization[2];
                if (this.useRegressionForBOBYQAParameters) {
                    d = calculateParameterForBOBYQAbyRegression(dArr.length, this.p, this.q, 0);
                    d2 = calculateParameterForBOBYQAbyRegression(dArr.length, this.p, this.q, 1);
                    d3 = calculateParameterForBOBYQAbyRegression(dArr.length, this.p, this.q, 2);
                }
                return new BOBYQAOptimizer((int) (d * (this.p + this.q + 3)), d2, d3).optimize(new OptimizationData[]{new MaxEval(this.maxNumberOfIterations), GoalType.MAXIMIZE, new ObjectiveFunction(multivariateFunction), initialGuess, createSimpleBounds}).getPoint();
            case CMAES:
                double d4 = this.parametersForOptimization[0];
                double d5 = this.parametersForOptimization[1];
                double d6 = this.parametersForOptimization[2];
                double d7 = this.parametersForOptimization[3];
                double d8 = this.parametersForOptimization[4];
                double d9 = this.parametersForOptimization[5];
                double d10 = this.parametersForOptimization[6];
                double d11 = this.parametersForOptimization[7];
                double[] dArr3 = new double[this.p + this.q + 1];
                for (int i = 0; i < dArr3.length; i++) {
                    dArr3[i] = d4;
                }
                return new CMAESOptimizer(this.maxNumberOfIterations, d7, true, (int) d8, (int) d9, new MersenneTwister(), true, new SimplePointChecker(d10, d11)).optimize(new OptimizationData[]{new MaxEval(this.maxNumberOfIterations), GoalType.MAXIMIZE, new ObjectiveFunction(multivariateFunction), initialGuess, createSimpleBounds, new CMAESOptimizer.Sigma(dArr3), new CMAESOptimizer.PopulationSize(Math.round((float) (d5 + (d6 * Math.log(this.p + this.q + 1)))))}).getPoint();
            case NELDERMEAD:
                return new SimplexOptimizer(new SimplePointChecker(this.parametersForOptimization[0], this.parametersForOptimization[1])).optimize(new OptimizationData[]{new MaxEval(this.maxNumberOfIterations), GoalType.MAXIMIZE, new ObjectiveFunction(multivariateFunction), initialGuess, new NelderMeadSimplex(this.p + this.q + 1)}).getPoint();
            case POWELL:
                double d12 = this.parametersForOptimization[0];
                double d13 = this.parametersForOptimization[1];
                MultivariateFunctionMappingAdapter multivariateFunctionMappingAdapter = new MultivariateFunctionMappingAdapter(multivariateFunction, createSimpleBounds.getLower(), createSimpleBounds.getUpper());
                return multivariateFunctionMappingAdapter.unboundedToBounded(new PowellOptimizer(d12, d13).optimize(new OptimizationData[]{new MaxEval(this.maxNumberOfIterations), GoalType.MAXIMIZE, new ObjectiveFunction(multivariateFunctionMappingAdapter), initialGuess}).getPoint());
            case LBFGS:
                return new LbfgsMinimizer(false).minimize((Function) multivariateFunction, dArr2);
            default:
                throw new RuntimeException("Provided OptimizationMethod is not one of: " + Arrays.toString(ArimaUtils.OptimizationMethod.values()));
        }
    }

    private SimpleBounds createSimpleBounds(boolean z) {
        int i = this.p + this.q;
        if (this.estimateConstant) {
            i++;
        }
        if (z) {
            i++;
        }
        double[][] dArr = new double[2][i];
        int i2 = 0;
        while (i2 < this.p + this.q) {
            dArr[0][i2] = (-10.0d) + Math.ulp(-1.0d);
            dArr[1][i2] = 10.0d - Math.ulp(1.0d);
            i2++;
        }
        if (this.estimateConstant) {
            dArr[0][i2] = Double.NEGATIVE_INFINITY;
            dArr[1][i2] = Double.POSITIVE_INFINITY;
            i2++;
        }
        if (z) {
            dArr[0][i2] = 0.0d + Math.ulp(0.0d);
            dArr[1][i2] = Double.POSITIVE_INFINITY;
        }
        return new SimpleBounds(dArr[0], dArr[1]);
    }

    private double calculateParameterForBOBYQAbyRegression(int i, int i2, int i3, int i4) {
        double[][] dArr = {new double[]{1.2360566632877907d, 9.228543571328063E-7d, 0.015222052447694447d, 0.01020758884697153d}, new double[]{0.22780848175157706d, -7.553414912441716E-6d, -0.013150511834201704d, -0.014741139928110577d}, new double[]{0.0880210759444603d, -8.51594906210044E-6d, -0.007564360060476942d, -0.0022898447216298965d}};
        return dArr[i4][0] + (dArr[i4][1] * i) + (dArr[i4][2] * i2) + (dArr[i4][3] * i3);
    }

    private static double[] getDefaultOptimizationParameters(ArimaUtils.OptimizationMethod optimizationMethod) {
        double[] dArr = null;
        switch (optimizationMethod) {
            case BOBYQA:
                dArr = new double[]{1.5d, 0.01d, 1.0E-4d};
                break;
            case CMAES:
                dArr = new double[]{0.1d, 500.0d, 3.0d, -1000.0d, 1000.0d, 0.0d, 0.01d, -1.0d};
                break;
            case NELDERMEAD:
                dArr = new double[]{1.0E-4d, -1.0d};
                break;
            case POWELL:
                dArr = new double[]{1.0E-8d, 1.0E-9d};
                break;
            case LBFGS:
                dArr = new double[0];
                break;
        }
        return dArr;
    }

    public double[] getFinalParameters() {
        return this.finalParameters;
    }

    public double getFinalLogLikelihood() {
        return this.finalLogLikelihood;
    }

    public double getFinalAicValue() {
        return this.finalAicValue;
    }

    public double getFinalBicValue() {
        return this.finalBicValue;
    }

    public double getFinalCorrectedAicValue() {
        return this.finalCorrectedAicValue;
    }

    public void setOptimizationParameters(double[] dArr) {
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.HANNAN_RISSANEN) {
            throw new RuntimeException("HANNAN_RISSANEN as the training algorithm does not need optimization parameters.");
        }
        checkAndSetOptimizationParameters(dArr);
    }

    private void checkAndSetOptimizationParameters(double[] dArr) {
        int i = 0;
        switch (this.optimizationMethod) {
            case BOBYQA:
                i = 3;
                break;
            case CMAES:
                i = 8;
                break;
            case NELDERMEAD:
                i = 2;
                break;
            case POWELL:
                i = 2;
                break;
            case LBFGS:
                i = 0;
                break;
        }
        if (dArr.length != i) {
            throw new InvalidParameterException("Length of provided optimization parameters array (" + dArr.length + ") is not equal to the necessary length (" + i + ") for the used optimization method (" + this.optimizationMethod.toString() + ").");
        }
        this.parametersForOptimization = dArr;
    }

    public void setInitialParameters(double[] dArr) {
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.HANNAN_RISSANEN) {
            throw new RuntimeException("HANNAN_RISSANEN as the training algorithm does not need initial parameters.");
        }
        int i = this.p + this.q;
        String str = "[ p arCoefficients, q maCoefficients";
        if (this.estimateConstant) {
            i++;
            str = str + ", constant";
        }
        if (this.trainingAlgorithm == ArimaUtils.TrainingAlgorithm.EXACT_MAX_LOGLIKELIHOOD) {
            i++;
            str = str + ", sigmaSquare";
        }
        if (dArr.length != i) {
            throw new InvalidParameterException("Length of provided initialParametersArray (" + dArr.length + ") is not equal to the necessary length (" + i + "). Please provide the following parameters as initial Parameters: " + str + " ].");
        }
        this.initialParameters = dArr;
        this.calculateStartParameters = false;
    }
}
