package com.rapidminer.extension.pythonscripting.operator;

import com.rapidminer.adaption.belt.ContextAdapter;
import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.buffer.Buffers;
import com.rapidminer.belt.buffer.NumericBuffer;
import com.rapidminer.belt.execution.Context;
import com.rapidminer.belt.table.BeltConverter;
import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.Table;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.extension.pythonscripting.definition.ConfigurationTools;
import com.rapidminer.extension.pythonscripting.operator.OperatorSentinel;
import com.rapidminer.extension.pythonscripting.operator.PythonOperator;
import com.rapidminer.extension.pythonscripting.operator.scripting.ScriptRunner;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonNativeObject;
import com.rapidminer.extension.timeseries.operator.TableTimeSeriesOperator;
import com.rapidminer.extension.timeseries.operator.helper.TableTimeSeriesHelper;
import com.rapidminer.extension.timeseries.operator.helper.TimeSeriesHelperBuilder;
import com.rapidminer.extension.timeseries.operator.helper.WrongConfiguredHelperException;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.nio.file.BufferedFileObject;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPorts;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.studio.concurrency.internal.SequentialConcurrencyContext;
import com.rapidminer.timeseriesanalysis.datamodel.dimension.interfaces.IndexDimension;
import com.rapidminer.timeseriesanalysis.datamodel.series.SeriesBuilder;
import com.rapidminer.timeseriesanalysis.datamodel.series.interfaces.ISeries;
import com.rapidminer.tools.LogService;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:com/rapidminer/extension/pythonscripting/operator/PythonForecaster.class */
/**
 * Operator that trains a forecast model by running a user-supplied Python script.
 *
 * <p>The executed script is the user script followed by {@link #ENTRYPOINT}. It receives these inputs, in order:
 * <ol>
 *   <li>a one-column table holding the time index;</li>
 *   <li>the input table with the forecasted series moved to the first column and the index column removed;</li>
 *   <li>any additional (non-static) inputs;</li>
 *   <li>a JSON parameter file.</li>
 * </ol>
 *
 * <p>The script's first result must be a {@link PythonNativeObject} whose file becomes the serialized model. The
 * second result must be an {@link ExampleSet} whose {@code description} attribute supplies the model description.
 */
public class PythonForecaster extends PythonOperator {
    /** Operator type identifier returned by {@link #getType()}. */
    public static final String TYPE = "forecaster";
    /** Python code appended to the user script for training; loaded from {@code scripts/forecaster_train_entrypoint.py}. */
    public static final String ENTRYPOINT;
    public static final String INPUT_PORT_EXAMPLE_SET = "example set";
    public static final String OUTPUT_PORT_MODEL = "model";
    public static final String OUTPUT_PORT_ORIGINAL = "original";
    private static final String PARAMETER_INDEX_NAME = "index_name";
    private static final String PARAMETER_SERIES_NAME = "series_name";
    private static final Set<String> STATIC_INPUTPORTS = new HashSet<>();
    private static final Set<String> STATIC_OUTPUTPORTS = new HashSet<>();
    private static final Set<String> STATIC_PARAMETERS = new HashSet<>(PythonOperator.getStaticParameters());
    private static final String DECLARATION_TEMPLATE;
    private static final String SCRIPT_TEMPLATE;
    private static final int NUMBER_OF_STATIC_INPUTPORTS;
    private static final int NUMBER_OF_STATIC_OUTPUTPORTS;
    private final OutputPort modelOutput;
    private final OutputPort dataOut;
    private final Context context;
    private final TableTimeSeriesHelper<Operator, ISeries<Double, Double>> tableTimeSeriesHelper;

    /**
     * Creates the operator and reloads its declaration.
     *
     * @param operatorDescription the operator description
     * @throws WrongConfiguredHelperException if the time series helper cannot be built
     */
    public PythonForecaster(OperatorDescription operatorDescription) throws WrongConfiguredHelperException {
        this(operatorDescription, true);
    }

    /**
     * Creates the operator, its model/original output ports and the time series helper reading "example set".
     *
     * @param operatorDescription the operator description
     * @param reloadOnCreation    whether {@code tryReloadDeclaration()} is called at the end of construction
     * @throws WrongConfiguredHelperException if the time series helper cannot be built
     */
    public PythonForecaster(OperatorDescription operatorDescription, boolean reloadOnCreation) throws WrongConfiguredHelperException {
        super(operatorDescription, true, NUMBER_OF_STATIC_INPUTPORTS, NUMBER_OF_STATIC_OUTPUTPORTS);
        this.modelOutput = getOutputPorts().createPort(OUTPUT_PORT_MODEL);
        this.dataOut = getOutputPorts().createPort(OUTPUT_PORT_ORIGINAL);
        this.context = ContextAdapter.adapt(new SequentialConcurrencyContext());
        // The helper is obtained from a throw-away TableTimeSeriesOperator subclass (a local class in the original
        // source) configured for this operator: input port "example set", optional indices, mixed value types.
        // NOTE(review): assumes the superclass populates its own tableTimeSeriesHelper via
        // initTableTimeSeriesOperator() during construction — confirm against TableTimeSeriesOperator.
        this.tableTimeSeriesHelper = new TableTimeSeriesOperator<ISeries<Double, Double>>(operatorDescription) {
            public TableTimeSeriesHelper<Operator, ISeries<Double, Double>> getTableTimeSeriesHelper() {
                return this.tableTimeSeriesHelper;
            }

            protected TableTimeSeriesHelper<Operator, ISeries<Double, Double>> initTableTimeSeriesOperator() throws WrongConfiguredHelperException {
                return new TimeSeriesHelperBuilder(PythonForecaster.this).asInputPortOperator(PythonForecaster.INPUT_PORT_EXAMPLE_SET).setIndiceHandling(getIndicesHandling()).setValuesType(getAllowedValuesType()).build();
            }

            protected TableTimeSeriesHelper.IndicesHandling getIndicesHandling() {
                return TableTimeSeriesHelper.IndicesHandling.OPTIONAL_INDICES;
            }

            protected SeriesBuilder.ValuesType getAllowedValuesType() {
                return SeriesBuilder.ValuesType.MIXED;
            }
        }.getTableTimeSeriesHelper();
        InputPort tableInputPort = this.tableTimeSeriesHelper.getTableInputPort();
        this.tableTimeSeriesHelper.addEmptyInputCheck();
        getTransformer().addGenerationRule(this.modelOutput, PythonForecastModel.class);
        getTransformer().addRule(new PassThroughRule(tableInputPort, this.dataOut, false));
        if (reloadOnCreation) {
            tryReloadDeclaration();
        }
    }

    /**
     * Creates the operator without reloading the declaration and passes both strings to {@code setImmutable}.
     * NOTE(review): the meaning of the two strings is defined by {@code setImmutable} — confirm there.
     */
    protected PythonForecaster(OperatorDescription operatorDescription, String str, String str2) throws WrongConfiguredHelperException {
        this(operatorDescription, false);
        setImmutable(str, str2);
    }

    /**
     * Returns the configured indices attribute, or {@code ""} when "has_indices" is not set.
     *
     * @return the value of "indices_attribute" if "has_indices" is true, otherwise an empty string
     * @throws UndefinedParameterError if a required parameter is undefined
     */
    public String getIndicesParameter() throws UndefinedParameterError {
        return getParameterAsBoolean("has_indices") ? getParameterAsString("indices_attribute") : "";
    }

    /** @return the shared, mutable set of parameter keys treated as static by this operator type */
    public static Set<String> getStaticParameters() {
        return STATIC_PARAMETERS;
    }

    /** @return the shared, mutable set of static input port names */
    public static Set<String> getStaticInputPorts() {
        return STATIC_INPUTPORTS;
    }

    /** @return the shared, mutable set of static output port names */
    public static Set<String> getStaticOutputPorts() {
        return STATIC_OUTPUTPORTS;
    }

    /**
     * Reads the input time series and runs the training script.
     *
     * <p>Delivers the input table on "original" and a {@link PythonForecastModel} on "model". The remaining
     * (dynamic) script results are forwarded to the dynamic output ports.
     *
     * @throws OperatorException if loading, script execution, or result interpretation fails
     */
    @Override
    public void doWork() throws OperatorException {
        try {
            reloadDeclaration();
            this.tableTimeSeriesHelper.resetHelper();
            this.tableTimeSeriesHelper.readInputData(this.tableTimeSeriesHelper.getTableInputPort());
            IOTable inputIoTable = this.tableTimeSeriesHelper.getInputIoTable();
            this.dataOut.deliver(inputIoTable);
            ISeries inputSeries = this.tableTimeSeriesHelper.getInputISeriesFromPort();
            Table table = inputIoTable.getTable();
            Table valueTable = selectSeriesFirst(table, inputSeries);
            Table indexTable = buildIndexTable(table, inputSeries, getParameterAsBoolean("has_indices"));
            IOObject parameterFile = new BufferedFileObject(compileParametersAsJson(
                    new PythonOperator.ParameterKeyValue(PARAMETER_INDEX_NAME, inputSeries.getIndexName()),
                    new PythonOperator.ParameterKeyValue(PARAMETER_SERIES_NAME, inputSeries.getSeriesNames()[0])));
            String trainingScript = getTrainingScript();
            List<IOObject> scriptInputs = new ArrayList<>(Arrays.asList(
                    BeltConverter.convertSequentially(new IOTable(indexTable)),
                    BeltConverter.convertSequentially(new IOTable(valueTable))));
            ScriptRunner scriptRunner = EnvironmentTools.getScriptRunner(this, trainingScript, null, false);
            scriptRunner.registerLogger(LogService.getRoot());
            // Only inputs beyond the static "example set" port are forwarded verbatim to the script.
            List<IOObject> extraInputs = (List) checkInputTypes(scriptRunner).stream().skip(NUMBER_OF_STATIC_INPUTPORTS).collect(Collectors.toList());
            checkExampleSet(extraInputs);
            scriptInputs.addAll(extraInputs);
            scriptInputs.add(parameterFile);
            OutputPorts outputPorts = getOutputPorts();
            int numberOfPorts = outputPorts.getNumberOfPorts();
            try {
                List<IOObject> results;
                OperatorSentinel.Sentinel sentinel = OperatorSentinel.scheduleSentinel(this, scriptRunner::cancel);
                try {
                    results = scriptRunner.run(scriptInputs, numberOfPorts);
                } finally {
                    // Close the sentinel on every path, including a failing script. Previously it was only closed
                    // after a successful run, leaving it (and its cancel callback) scheduled on errors.
                    if (sentinel != null) {
                        sentinel.close();
                    }
                }
                if (results.size() != numberOfPorts) {
                    throw new UserError(this, "python_scripting.mismatching_outputs", new Object[]{Integer.valueOf(numberOfPorts), Integer.valueOf(results.size())});
                }
                try {
                    PythonNativeObject serializedModel = (PythonNativeObject) results.get(0);
                    // The description is the value mapped to index 0 of the "description" attribute of result #2.
                    String description = ((ExampleSet) results.get(1)).getAttributes().get("description").getMapping().mapIndex(0);
                    PythonForecast pythonForecast = new PythonForecast(getDeclaration().getName(), description, getScript(), Files.readAllBytes(serializedModel.getFile().toPath()), this);
                    setWrapper(pythonForecast);
                    this.modelOutput.deliver(new PythonForecastModel(pythonForecast, inputSeries, getIndicesParameter()));
                    for (int port = NUMBER_OF_STATIC_OUTPUTPORTS; port < numberOfPorts; port++) {
                        outputPorts.getPortByIndex(port).deliver(results.get(port));
                    }
                } catch (IOException readError) {
                    throw new UserError(this, readError, "python_scripting.failed_to_read_result");
                } catch (ClassCastException typeError) {
                    throw new UserError(this, typeError, "python_scripting.invalid_training_result_type");
                }
            } catch (OperatorException operatorError) {
                // Already user-facing; do not wrap it as a runner error.
                throw operatorError;
            } catch (Exception runnerError) {
                checkForStop();
                throw new UserError(this, runnerError, "python_scripting.python_runner_error");
            }
        } catch (IOException loadError) {
            throw new UserError(this, loadError, "python_scripting.loading_json_failed");
        }
    }

    /**
     * Returns {@code table} with the first series column moved to the front and the index column removed.
     * Collects into an explicit {@link ArrayList} because the list is mutated afterwards.
     */
    private static Table selectSeriesFirst(Table table, ISeries series) {
        String indexName = series.getIndexName();
        String seriesName = series.getSeriesNames()[0];
        List<String> labels = table.labels().stream()
                .filter(label -> !label.equals(indexName) && !label.equals(seriesName))
                .collect(Collectors.toCollection(ArrayList::new));
        labels.add(0, seriesName);
        return table.columns(labels);
    }

    /**
     * Returns a single-column table holding the time index. When {@code hasIndices} is true, this is the existing
     * index column of {@code table}. Otherwise a real-valued column is built from the series' index dimension.
     */
    private Table buildIndexTable(Table table, ISeries series, boolean hasIndices) {
        if (hasIndices) {
            return table.columns(Arrays.asList(series.getIndexName()));
        }
        IndexDimension indexDimension = series.getIndexDimension();
        int length = series.getLength();
        NumericBuffer indexValues = Buffers.realBuffer(length);
        for (int i = 0; i < length; i++) {
            indexValues.set(i, indexDimension.getIndexValueAsDouble(i));
        }
        return Builders.newTableBuilder(length).add(indexDimension.getName(), indexValues.toColumn()).build(this.context);
    }

    /** Returns the time series helper's parameters followed by those of {@link PythonOperator}. */
    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = this.tableTimeSeriesHelper.getParameterTypes(new ArrayList<>());
        parameterTypes.addAll(super.getParameterTypes());
        return parameterTypes;
    }

    /** @return the declaration template loaded from {@code scripts/forecaster_declaration_template.json} */
    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getDefaultDeclaration() {
        return DECLARATION_TEMPLATE;
    }

    /** @return the script template loaded from {@code scripts/forecaster_definition_template.py} */
    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getDefaultScript() {
        return SCRIPT_TEMPLATE;
    }

    /** @return {@link #TYPE} */
    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getType() {
        return TYPE;
    }

    /** Returns the user script with the training {@link #ENTRYPOINT} appended. */
    private String getTrainingScript() throws UndefinedParameterError {
        return getScript() + ENTRYPOINT;
    }

    static {
        STATIC_INPUTPORTS.add(INPUT_PORT_EXAMPLE_SET);
        STATIC_OUTPUTPORTS.add(OUTPUT_PORT_MODEL);
        STATIC_OUTPUTPORTS.add(OUTPUT_PORT_ORIGINAL);
        NUMBER_OF_STATIC_INPUTPORTS = STATIC_INPUTPORTS.size();
        NUMBER_OF_STATIC_OUTPUTPORTS = STATIC_OUTPUTPORTS.size();
        STATIC_PARAMETERS.add("time_series_attribute");
        STATIC_PARAMETERS.add("has_indices");
        STATIC_PARAMETERS.add("indices_attribute");
        STATIC_PARAMETERS.add(PARAMETER_INDEX_NAME);
        STATIC_PARAMETERS.add(PARAMETER_SERIES_NAME);
        STATIC_PARAMETERS.add("sort_time_series");
        ENTRYPOINT = ConfigurationTools.loadTextFile("scripts/forecaster_train_entrypoint.py");
        DECLARATION_TEMPLATE = ConfigurationTools.loadTextFile("scripts/forecaster_declaration_template.json");
        SCRIPT_TEMPLATE = ConfigurationTools.loadTextFile("scripts/forecaster_definition_template.py");
    }
}
