/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.extension.pythonscripting.operator;

import com.rapidminer.adaption.belt.ContextAdapter;
import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.buffer.Buffers;
import com.rapidminer.belt.buffer.NumericBuffer;
import com.rapidminer.belt.column.Column;
import com.rapidminer.belt.execution.Context;
import com.rapidminer.belt.table.BeltConverter;
import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.Table;
import com.rapidminer.core.concurrency.ConcurrencyContext;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.extension.pythonscripting.definition.ConfigurationTools;
import com.rapidminer.extension.pythonscripting.operator.EnvironmentTools;
import com.rapidminer.extension.pythonscripting.operator.OperatorSentinel;
import com.rapidminer.extension.pythonscripting.operator.PythonForecast;
import com.rapidminer.extension.pythonscripting.operator.PythonForecastModel;
import com.rapidminer.extension.pythonscripting.operator.PythonOperator;
import com.rapidminer.extension.pythonscripting.operator.scripting.ScriptRunner;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonNativeObject;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonScriptingOperator;
import com.rapidminer.extension.timeseries.operator.TableTimeSeriesOperator;
import com.rapidminer.extension.timeseries.operator.helper.TableTimeSeriesHelper;
import com.rapidminer.extension.timeseries.operator.helper.TimeSeriesHelperBuilder;
import com.rapidminer.extension.timeseries.operator.helper.WrongConfiguredHelperException;
import com.rapidminer.gui.tools.VersionNumber;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.nio.file.BufferedFileObject;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPorts;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.studio.concurrency.internal.SequentialConcurrencyContext;
import com.rapidminer.timeseriesanalysis.datamodel.dimension.interfaces.IndexDimension;
import com.rapidminer.timeseriesanalysis.datamodel.series.SeriesBuilder;
import com.rapidminer.timeseriesanalysis.datamodel.series.interfaces.ISeries;
import com.rapidminer.tools.LogService;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Operator that trains a forecasting model by executing a user-supplied Python script on a single
 * time series.
 *
 * <p>The time series is read through a {@link TableTimeSeriesHelper}. The script receives, in this
 * order:
 * <ol>
 *   <li>an index table (the configured indices attribute, or a generated numeric index);</li>
 *   <li>a value table (the series column first, followed by all remaining non-index columns);</li>
 *   <li>any inputs from the dynamic input ports;</li>
 *   <li>the compiled operator parameters.</li>
 * </ol>
 *
 * <p>The script's first result must be a {@link PythonNativeObject} whose file holds the model bytes.
 * Its second result must contain a {@code description} attribute/column. Together they are wrapped
 * into a {@link PythonForecastModel}. Remaining results are delivered on the dynamic output ports.
 */
public class PythonForecaster
extends PythonOperator {
    public static final String TYPE = "forecaster";
    public static final String INPUT_PORT_EXAMPLE_SET = "example set";
    public static final String OUTPUT_PORT_MODEL = "model";
    public static final String OUTPUT_PORT_ORIGINAL = "original";
    private static final String PARAMETER_INDEX_NAME = "index_name";
    private static final String PARAMETER_SERIES_NAME = "series_name";
    /** Attribute/column of the second script result that carries the model description text. */
    private static final String DESCRIPTION_COLUMN = "description";
    /** Python code appended to the user script to form the training script. */
    private static final String ENTRYPOINT = PythonForecaster.loadEntrypointScript();
    private static final String DECLARATION_TEMPLATE = PythonForecaster.loadDeclarationTemplate();
    private static final String SCRIPT_TEMPLATE = PythonForecaster.loadScriptTemplate();
    private final OutputPort modelOutput = (OutputPort)this.getOutputPorts().createPort(OUTPUT_PORT_MODEL);
    private final OutputPort dataOut = (OutputPort)this.getOutputPorts().createPort(OUTPUT_PORT_ORIGINAL);
    /** Sequential Belt context, used when building the generated index table. */
    private final Context context = ContextAdapter.adapt(new SequentialConcurrencyContext());
    private final TableTimeSeriesHelper<Operator, ISeries<Double, Double>> tableTimeSeriesHelper;

    /**
     * Creates the operator and tries to reload its declaration.
     *
     * @param description the operator description
     * @throws WrongConfiguredHelperException if the time series helper cannot be built
     */
    public PythonForecaster(OperatorDescription description) throws WrongConfiguredHelperException {
        this(description, true);
    }

    /**
     * Creates the operator.
     *
     * @param description       the operator description
     * @param reloadDeclaration whether {@code tryReloadDeclaration()} is invoked after setup
     * @throws WrongConfiguredHelperException if the time series helper cannot be built
     */
    public PythonForecaster(OperatorDescription description, boolean reloadDeclaration) throws WrongConfiguredHelperException {
        super(description, true, PythonForecaster.getStaticInputPorts().size(), PythonForecaster.getStaticOutputPorts().size());
        this.tableTimeSeriesHelper = this.initializeTimeSeriesHelper(description);
        this.setupTransformer();
        if (reloadDeclaration) {
            this.tryReloadDeclaration();
        }
    }

    /**
     * Creates the operator with a fixed declaration and definition (passed to {@code setImmutable}).
     * The declaration is not reloaded.
     */
    protected PythonForecaster(OperatorDescription description, String declaration, String definition) throws WrongConfiguredHelperException {
        this(description, false);
        this.setImmutable(declaration, definition);
    }

    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        List<OperatorVersion> versions = new ArrayList<>(Arrays.asList(super.getIncompatibleVersionChanges()));
        versions.add(PythonScriptingOperator.VERSION_HDF5_DATE_TIME_BUG);
        versions.add(PythonScriptingOperator.VERSION_ARROW_SERIALIZATION);
        return versions.toArray(new OperatorVersion[0]);
    }

    /**
     * Returns the configured indices attribute, or an empty string when {@code has_indices} is off.
     */
    public String getIndicesParameter() throws UndefinedParameterError {
        return this.getParameterAsBoolean("has_indices") ? this.getParameterAsString("indices_attribute") : "";
    }

    /**
     * Reads the input series, runs the training script and delivers the model, the original data
     * and any dynamic outputs.
     */
    @Override
    public void doWork() throws OperatorException {
        try {
            this.reloadDeclaration();
        }
        catch (IOException e) {
            throw new UserError(this, e, "python_scripting.loading_json_failed");
        }
        this.tableTimeSeriesHelper.resetHelper();
        this.tableTimeSeriesHelper.readInputData(this.tableTimeSeriesHelper.getTableInputPort());
        IOTable data = this.tableTimeSeriesHelper.getInputIoTable();
        this.dataOut.deliver(data);
        // Single, scoped unchecked cast instead of repeating it at every use site.
        @SuppressWarnings("unchecked")
        ISeries<Double, Double> series = (ISeries<Double, Double>)this.tableTimeSeriesHelper.getInputISeriesFromPort();
        Table valueTable = this.extractValueTable(data.getTable(), series);
        Table indexTable = this.extractIndexTable(data.getTable(), series);
        BufferedFileObject parameters = this.compileParameters(series);
        // One runner serves both the dynamic input type check and the actual execution.
        ScriptRunner runner = this.createScriptRunner(this.getTrainingScript());
        List<IOObject> inputs = this.prepareScriptInputs(runner, indexTable, valueTable, parameters);
        List<IOObject> results = this.executeScript(runner, inputs);
        this.processScriptResults(results, series);
    }

    /**
     * Returns the input table with the series column moved to the front and the index column removed.
     */
    private Table extractValueTable(Table inputTable, ISeries<Double, Double> series) {
        String indexName = series.getIndexName();
        String seriesName = series.getSeriesNames()[0];
        // toCollection guarantees a mutable list; Collectors.toList() makes no such promise.
        List<String> columns = inputTable.labels().stream()
                .filter(label -> !label.equals(indexName) && !label.equals(seriesName))
                .collect(Collectors.toCollection(ArrayList::new));
        columns.add(0, seriesName);
        return inputTable.columns(columns);
    }

    /**
     * Returns the index column of the input when {@code has_indices} is set, otherwise a generated
     * index table.
     */
    private Table extractIndexTable(Table inputTable, ISeries<Double, Double> series) {
        if (this.getParameterAsBoolean("has_indices")) {
            return inputTable.columns(Collections.singletonList(series.getIndexName()));
        }
        return this.createIndexTable(series);
    }

    /** Builds a one-column real table from the series' index dimension values. */
    private Table createIndexTable(ISeries<Double, Double> series) {
        IndexDimension indexDimension = series.getIndexDimension();
        int length = (int)series.getLength();
        NumericBuffer indexBuffer = Buffers.realBuffer(length);
        for (int i = 0; i < length; ++i) {
            indexBuffer.set(i, indexDimension.getIndexValueAsDouble(i));
        }
        Column indexColumn = indexBuffer.toColumn();
        return Builders.newTableBuilder(length).add(indexDimension.getName(), indexColumn).build(this.context);
    }

    /** Compiles the operator parameters, including the index and series names, into a file object. */
    private BufferedFileObject compileParameters(ISeries<Double, Double> series) throws OperatorException {
        return new BufferedFileObject(this.compileParametersAsJson(
                new PythonOperator.ParameterKeyValue(this, PARAMETER_INDEX_NAME, series.getIndexName()),
                new PythonOperator.ParameterKeyValue(this, PARAMETER_SERIES_NAME, series.getSeriesNames()[0])));
    }

    /**
     * Assembles the script inputs: index table, value table, dynamic inputs (type-checked against
     * {@code runner}) and finally the parameters.
     */
    private List<IOObject> prepareScriptInputs(ScriptRunner runner, Table indexTable, Table valueTable, BufferedFileObject parameters) throws UserError {
        List<IOObject> inputs = new ArrayList<>();
        inputs.add(BeltConverter.convertSequentially(new IOTable(indexTable)));
        inputs.add(BeltConverter.convertSequentially(new IOTable(valueTable)));
        List<IOObject> dynamicInputs = this.checkInputTypes(runner).stream()
                .skip(PythonForecaster.getStaticInputPorts().size())
                .collect(Collectors.toList());
        this.checkExampleSet(dynamicInputs);
        inputs.addAll(dynamicInputs);
        inputs.add(parameters);
        return inputs;
    }

    /** Creates a script runner whose serialization depends on the operator's compatibility level. */
    private ScriptRunner createScriptRunner(String script) throws UserError {
        boolean useArrowSerialization = this.getCompatibilityLevel().isAbove((VersionNumber)PythonScriptingOperator.VERSION_ARROW_SERIALIZATION);
        ScriptRunner runner = EnvironmentTools.getScriptRunner(this, script, null, false, useArrowSerialization);
        runner.registerLogger(LogService.getRoot());
        return runner;
    }

    /**
     * Runs the script with a sentinel that cancels the runner, closing the sentinel exactly once.
     *
     * @throws OperatorException a {@code python_runner_error} {@link UserError} wrapping any failure,
     *                           unless {@code checkForStop()} throws first
     */
    private List<IOObject> executeScript(ScriptRunner runner, List<IOObject> inputs) throws OperatorException {
        int nOutputs = this.getOutputPorts().getNumberOfPorts();
        // try-with-resources skips close() for a null resource, matching the previous null check.
        try (OperatorSentinel.Sentinel ignored = OperatorSentinel.scheduleSentinel(this, runner::cancel)) {
            return runner.run(inputs, nOutputs);
        }
        catch (Exception e) {
            this.checkForStop();
            throw new UserError(this, e, "python_scripting.python_runner_error");
        }
    }

    /**
     * Validates the script results, builds and delivers the forecast model, then delivers the
     * dynamic outputs.
     *
     * @throws UserError if the number of results does not match the output ports, or if the model or
     *                   description result is malformed
     */
    private void processScriptResults(List<IOObject> results, ISeries<Double, Double> series) throws OperatorException {
        int expectedOutputs = this.getOutputPorts().getNumberOfPorts();
        if (results.size() != expectedOutputs) {
            throw new UserError(this, "python_scripting.mismatching_outputs", new Object[]{expectedOutputs, results.size()});
        }
        try {
            PythonNativeObject model = (PythonNativeObject)results.get(0);
            String descriptionText = this.extractDescriptionText(results.get(1));
            byte[] modelBytes = Files.readAllBytes(model.getFile().toPath());
            PythonForecast forecast = new PythonForecast(this.getDeclaration().getName(), descriptionText, this.getScript(), modelBytes, this);
            this.setWrapper(forecast);
            this.modelOutput.deliver(new PythonForecastModel(forecast, series, this.getIndicesParameter()));
        }
        catch (IOException | ClassCastException | IllegalArgumentException e) {
            // IllegalArgumentException covers an unsupported or incomplete description result, so it
            // is reported as a user error rather than escaping as a raw runtime exception.
            throw new UserError(this, e, "python_scripting.invalid_training_result_type");
        }
        this.deliverDynamicOutputs(results);
    }

    /**
     * Reads the first value of the {@code description} attribute/column of the given result.
     *
     * @throws IllegalArgumentException if the object is neither an ExampleSet nor an IOTable, or
     *                                  lacks the description attribute/column
     */
    private String extractDescriptionText(IOObject descriptionObj) {
        if (descriptionObj instanceof ExampleSet) {
            ExampleSet exampleSet = (ExampleSet)descriptionObj;
            if (exampleSet.getAttributes().get(DESCRIPTION_COLUMN) == null) {
                throw new IllegalArgumentException("Missing attribute '" + DESCRIPTION_COLUMN + "' in description result");
            }
            // Nominal mappings are 0-based.
            return exampleSet.getAttributes().get(DESCRIPTION_COLUMN).getMapping().mapIndex(0);
        }
        if (descriptionObj instanceof IOTable) {
            IOTable ioTable = (IOTable)descriptionObj;
            // Belt dictionary index 0 is reserved for missing values, so the first category is 1.
            return ioTable.getTable().column(DESCRIPTION_COLUMN).getDictionary().get(1);
        }
        throw new IllegalArgumentException("Unsupported type for description: " + (descriptionObj == null ? null : descriptionObj.getClass()));
    }

    /** Delivers every result beyond the static outputs to the matching dynamic output port. */
    private void deliverDynamicOutputs(List<IOObject> results) {
        OutputPorts outputs = this.getOutputPorts();
        int firstDynamic = PythonForecaster.getStaticOutputPorts().size();
        for (int i = firstDynamic; i < outputs.getNumberOfPorts(); ++i) {
            OutputPort port = (OutputPort)outputs.getPortByIndex(i);
            port.deliver(results.get(i));
        }
    }

    /** Returns the time series helper parameters followed by the inherited parameters. */
    @Override
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> types = this.tableTimeSeriesHelper.getParameterTypes(new ArrayList<>());
        types.addAll(super.getParameterTypes());
        return types;
    }

    @Override
    public String getDefaultDeclaration() {
        return DECLARATION_TEMPLATE;
    }

    @Override
    public String getDefaultScript() {
        return SCRIPT_TEMPLATE;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /** Returns the user script followed by the training entry point. */
    private String getTrainingScript() throws UndefinedParameterError {
        return this.getScript() + ENTRYPOINT;
    }

    /**
     * Builds the time series helper.
     *
     * <p>A throw-away {@link TableTimeSeriesOperator} subclass is instantiated only to run its helper
     * initialisation. The helper is bound to this operator via the builder and uses the
     * {@value #INPUT_PORT_EXAMPLE_SET} input port, optional indices and mixed value types.
     */
    private TableTimeSeriesHelper<Operator, ISeries<Double, Double>> initializeTimeSeriesHelper(OperatorDescription description) throws WrongConfiguredHelperException {
        class PythonForecasterTableTimeSeriesHelper
        extends TableTimeSeriesOperator<ISeries<Double, Double>> {
            PythonForecasterTableTimeSeriesHelper(OperatorDescription description) throws WrongConfiguredHelperException {
                super(description);
            }

            /** Exposes the helper field inherited from TableTimeSeriesOperator (not the outer one). */
            public TableTimeSeriesHelper<Operator, ISeries<Double, Double>> getTableTimeSeriesHelper() {
                return this.tableTimeSeriesHelper;
            }

            protected TableTimeSeriesHelper<Operator, ISeries<Double, Double>> initTableTimeSeriesOperator() throws WrongConfiguredHelperException {
                TimeSeriesHelperBuilder builder = new TimeSeriesHelperBuilder((Operator)PythonForecaster.this);
                return builder.asInputPortOperator(PythonForecaster.INPUT_PORT_EXAMPLE_SET).setIndiceHandling(this.getIndicesHandling()).setValuesType(this.getAllowedValuesType()).build();
            }

            protected TableTimeSeriesHelper.IndicesHandling getIndicesHandling() {
                return TableTimeSeriesHelper.IndicesHandling.OPTIONAL_INDICES;
            }

            protected SeriesBuilder.ValuesType getAllowedValuesType() {
                return SeriesBuilder.ValuesType.MIXED;
            }
        }
        return new PythonForecasterTableTimeSeriesHelper(description).getTableTimeSeriesHelper();
    }

    /**
     * Registers the empty input check, generation of the model output, and pass-through of the
     * input to the original port.
     */
    private void setupTransformer() throws WrongConfiguredHelperException {
        InputPort exampleSetInputPort = this.tableTimeSeriesHelper.getTableInputPort();
        this.tableTimeSeriesHelper.addEmptyInputCheck();
        this.getTransformer().addGenerationRule(this.modelOutput, PythonForecastModel.class);
        this.getTransformer().addRule((MDTransformationRule)new PassThroughRule(exampleSetInputPort, this.dataOut, false));
    }

    private static String loadEntrypointScript() {
        return ConfigurationTools.loadTextFile("scripts/forecaster_train_entrypoint.py");
    }

    private static String loadDeclarationTemplate() {
        return ConfigurationTools.loadTextFile("scripts/forecaster_declaration_template.json");
    }

    private static String loadScriptTemplate() {
        return ConfigurationTools.loadTextFile("scripts/forecaster_definition_template.py");
    }

    /** Returns a mutable set of the inherited static parameters plus those owned by this operator. */
    public static Set<String> getStaticParameters() {
        Set<String> staticParameters = new HashSet<>(PythonOperator.getStaticParameters());
        staticParameters.addAll(Arrays.asList("time_series_attribute", "has_indices", "indices_attribute", PARAMETER_INDEX_NAME, PARAMETER_SERIES_NAME, "sort_time_series"));
        return staticParameters;
    }

    /** Returns a mutable set with the names of the static input ports. */
    public static Set<String> getStaticInputPorts() {
        return new HashSet<>(List.of(INPUT_PORT_EXAMPLE_SET));
    }

    /** Returns a mutable set with the names of the static output ports. */
    public static Set<String> getStaticOutputPorts() {
        return new HashSet<>(Arrays.asList(OUTPUT_PORT_MODEL, OUTPUT_PORT_ORIGINAL));
    }
}

