package com.rapidminer.extension.pythonscripting.operator;

import com.rapidminer.adaption.belt.ContextAdapter;
import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.buffer.Buffers;
import com.rapidminer.belt.buffer.NumericBuffer;
import com.rapidminer.belt.execution.Context;
import com.rapidminer.belt.table.BeltConverter;
import com.rapidminer.belt.table.Builders;
import com.rapidminer.belt.table.Table;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.extension.pythonscripting.definition.ConfigurationTools;
import com.rapidminer.extension.pythonscripting.operator.OperatorSentinel;
import com.rapidminer.extension.pythonscripting.operator.PythonOperator;
import com.rapidminer.extension.pythonscripting.operator.scripting.ScriptRunner;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonScriptingOperator;
import com.rapidminer.extension.timeseries.operator.TableTimeSeriesOperator;
import com.rapidminer.extension.timeseries.operator.helper.TableTimeSeriesHelper;
import com.rapidminer.extension.timeseries.operator.helper.TimeSeriesHelperBuilder;
import com.rapidminer.extension.timeseries.operator.helper.WrongConfiguredHelperException;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.nio.file.BufferedFileObject;
import com.rapidminer.operator.nio.file.FileObject;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPorts;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.search.MarketplaceGlobalSearchManager;
import com.rapidminer.studio.concurrency.internal.SequentialConcurrencyContext;
import com.rapidminer.timeseriesanalysis.datamodel.dimension.interfaces.IndexDimension;
import com.rapidminer.timeseriesanalysis.datamodel.series.SeriesBuilder;
import com.rapidminer.timeseriesanalysis.datamodel.series.interfaces.ISeries;
import com.rapidminer.tools.LogService;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:com/rapidminer/extension/pythonscripting/operator/PythonForecaster.class */
/**
 * Python scripting operator of type {@code forecaster}.
 *
 * <p>It reads a single time series from the {@code example set} input port. It then runs the user
 * script followed by the bundled training entry point ({@code scripts/forecaster_train_entrypoint.py}).
 * Finally it delivers a {@link PythonForecastModel} on the {@code model} port and the unchanged
 * input on the {@code original} port. Any additional (dynamic) ports are passed through to and from
 * the script.
 */
public class PythonForecaster extends PythonOperator {
    public static final String TYPE = "forecaster";
    public static final String INPUT_PORT_EXAMPLE_SET = "example set";
    public static final String OUTPUT_PORT_MODEL = "model";
    public static final String OUTPUT_PORT_ORIGINAL = "original";
    private static final String PARAMETER_INDEX_NAME = "index_name";
    private static final String PARAMETER_SERIES_NAME = "series_name";
    // NOTE(review): this constant is referenced from an unrelated class; a decompiler probably
    // substituted it for an equal string literal. Confirm its value is the expected column name.
    private static final String DESCRIPTION_COLUMN = MarketplaceGlobalSearchManager.FIELD_DESCRIPTION;
    private static final String ENTRYPOINT = loadEntrypointScript();
    private static final String DECLARATION_TEMPLATE = loadDeclarationTemplate();
    private static final String SCRIPT_TEMPLATE = loadScriptTemplate();
    private final OutputPort modelOutput;
    private final OutputPort dataOut;
    private final Context context;
    private TableTimeSeriesHelper<Operator, ISeries<Double, Double>> tableTimeSeriesHelper;

    /**
     * Creates the operator and tries to reload its declaration.
     *
     * @param operatorDescription the operator description
     * @throws WrongConfiguredHelperException propagated from the time series helper setup
     */
    public PythonForecaster(OperatorDescription operatorDescription) throws WrongConfiguredHelperException {
        this(operatorDescription, true);
    }

    /**
     * Creates the operator, its ports, the time series helper and the meta data rules.
     *
     * @param operatorDescription the operator description
     * @param reload whether {@code tryReloadDeclaration()} is invoked at the end of construction
     * @throws WrongConfiguredHelperException propagated from the time series helper setup
     */
    public PythonForecaster(OperatorDescription operatorDescription, boolean reload) throws WrongConfiguredHelperException {
        super(operatorDescription, true, getStaticInputPorts().size(), getStaticOutputPorts().size());
        this.modelOutput = getOutputPorts().createPort(OUTPUT_PORT_MODEL);
        this.dataOut = getOutputPorts().createPort(OUTPUT_PORT_ORIGINAL);
        this.context = ContextAdapter.adapt(new SequentialConcurrencyContext());
        this.tableTimeSeriesHelper = initializeTimeSeriesHelper(operatorDescription);
        setupTransformer();
        if (reload) {
            tryReloadDeclaration();
        }
    }

    /**
     * Creates the operator without reloading the declaration, then passes both arguments to
     * {@code setImmutable}. NOTE(review): the meaning of {@code str}/{@code str2} is defined by
     * {@code PythonOperator#setImmutable}; presumably declaration and script, so confirm there.
     *
     * @throws WrongConfiguredHelperException propagated from the time series helper setup
     */
    protected PythonForecaster(OperatorDescription operatorDescription, String str, String str2) throws WrongConfiguredHelperException {
        this(operatorDescription, false);
        setImmutable(str, str2);
    }

    /**
     * Adds the HDF5 date-time fix and the Arrow serialization switch to the inherited
     * incompatible version changes.
     */
    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        List<OperatorVersion> versions = new ArrayList<>(Arrays.asList(super.getIncompatibleVersionChanges()));
        versions.add(PythonScriptingOperator.VERSION_HDF5_DATE_TIME_BUG);
        versions.add(PythonScriptingOperator.VERSION_ARROW_SERIALIZATION);
        return versions.toArray(new OperatorVersion[0]);
    }

    /**
     * Returns the configured indices attribute, or an empty string when {@code has_indices} is off.
     *
     * @throws UndefinedParameterError if a required parameter is not set
     */
    public String getIndicesParameter() throws UndefinedParameterError {
        return getParameterAsBoolean("has_indices") ? getParameterAsString("indices_attribute") : "";
    }

    /**
     * Reads the input series, passes it through on the {@code original} port, runs the training
     * script and delivers the resulting model plus any dynamic outputs.
     */
    @Override
    public void doWork() throws OperatorException {
        try {
            reloadDeclaration();
            TableTimeSeriesHelper<Operator, ISeries<Double, Double>> helper = this.tableTimeSeriesHelper;
            helper.resetHelper();
            helper.readInputData(helper.getTableInputPort());
            IOTable inputIoTable = helper.getInputIoTable();
            this.dataOut.deliver(inputIoTable);

            ISeries<Double, Double> series = helper.getInputISeriesFromPort();
            Table inputTable = inputIoTable.getTable();
            Table valueTable = extractValueTable(inputTable, series);
            Table indexTable = extractIndexTable(inputTable, series);
            BufferedFileObject parameters = compileParameters(series);

            String trainingScript = getTrainingScript();
            // Same evaluation order as before: the runner is created before the inputs are prepared.
            ScriptRunner runner = createScriptRunner(trainingScript);
            List<IOObject> scriptInputs = prepareScriptInputs(indexTable, valueTable, parameters);
            processScriptResults(executeScript(runner, scriptInputs), series);
        } catch (IOException e) {
            throw new UserError(this, e, "python_scripting.loading_json_failed");
        }
    }

    /**
     * Returns the table's columns without the index column, with the series column moved to the
     * front. Other columns keep their original order.
     */
    private Table extractValueTable(Table table, ISeries<Double, Double> iSeries) {
        String indexName = iSeries.getIndexName();
        String seriesName = iSeries.getSeriesNames()[0];
        // toCollection guarantees a mutable list; Collectors.toList() does not, yet we add(0, ...).
        List<String> labels = table.labels().stream()
                .filter(label -> !label.equals(indexName) && !label.equals(seriesName))
                .collect(Collectors.toCollection(ArrayList::new));
        labels.add(0, seriesName);
        return table.columns(labels);
    }

    /**
     * Returns the index column from the table when {@code has_indices} is set. Otherwise it builds
     * an index table from the series' index dimension.
     */
    private Table extractIndexTable(Table table, ISeries<Double, Double> iSeries) {
        return getParameterAsBoolean("has_indices") ? table.columns(Collections.singletonList(iSeries.getIndexName())) : createIndexTable(iSeries);
    }

    /** Builds a one-column real-valued table from the series' index dimension values. */
    private Table createIndexTable(ISeries<Double, Double> iSeries) {
        IndexDimension indexDimension = iSeries.getIndexDimension();
        int length = iSeries.getLength();
        NumericBuffer indexValues = Buffers.realBuffer(length);
        for (int row = 0; row < length; row++) {
            indexValues.set(row, indexDimension.getIndexValueAsDouble(row));
        }
        return Builders.newTableBuilder(length).add(indexDimension.getName(), indexValues.toColumn()).build(this.context);
    }

    /** Serializes the index and series names (plus inherited parameters) as a JSON file object. */
    private BufferedFileObject compileParameters(ISeries<Double, Double> iSeries) throws OperatorException {
        return new BufferedFileObject(compileParametersAsJson(new PythonOperator.ParameterKeyValue(PARAMETER_INDEX_NAME, iSeries.getIndexName()), new PythonOperator.ParameterKeyValue(PARAMETER_SERIES_NAME, iSeries.getSeriesNames()[0])));
    }

    /**
     * Assembles the script inputs in this order: index table, value table, checked dynamic
     * inputs, parameters file.
     */
    private List<IOObject> prepareScriptInputs(Table indexTable, Table valueTable, BufferedFileObject parameters) throws UserError {
        List<IOObject> inputs = new ArrayList<>();
        inputs.add(BeltConverter.convertSequentially(new IOTable(indexTable)));
        inputs.add(BeltConverter.convertSequentially(new IOTable(valueTable)));
        // Skip the static input ports; only the dynamic inputs are appended here.
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<IOObject> dynamicInputs = (List) checkInputTypes(createScriptRunner(getTrainingScript())).stream().skip(getStaticInputPorts().size()).collect(Collectors.toList());
        checkExampleSet(dynamicInputs);
        inputs.addAll(dynamicInputs);
        inputs.add(parameters);
        return inputs;
    }

    /** Creates a script runner for {@code str} and attaches the root logger to it. */
    private ScriptRunner createScriptRunner(String str) throws UserError {
        ScriptRunner scriptRunner = EnvironmentTools.getScriptRunner(this, str, null, false, getCompatibilityLevel().isAbove(PythonScriptingOperator.VERSION_ARROW_SERIALIZATION));
        scriptRunner.registerLogger(LogService.getRoot());
        return scriptRunner;
    }

    /**
     * Runs the script and expects one result per output port. A sentinel that may cancel the
     * runner is scheduled for the duration of the run and is closed on both success and failure.
     *
     * @throws UserError wrapping any failure of the run, after checking for a user stop
     */
    private List<IOObject> executeScript(ScriptRunner scriptRunner, List<IOObject> inputs) throws OperatorException {
        int numberOfPorts = getOutputPorts().getNumberOfPorts();
        try {
            Objects.requireNonNull(scriptRunner, "scriptRunner");
            OperatorSentinel.Sentinel sentinel = OperatorSentinel.scheduleSentinel(this, scriptRunner::cancel);
            try {
                return scriptRunner.run(inputs, numberOfPorts);
            } finally {
                // Previously only closed on success, which leaked the sentinel when run() threw.
                if (sentinel != null) {
                    sentinel.close();
                }
            }
        } catch (Exception e) {
            checkForStop();
            throw new UserError(this, e, "python_scripting.python_runner_error");
        }
    }

    /**
     * Builds the forecast model from the script results and delivers all outputs.
     * Result 0 is the serialized model file, result 1 carries the description, and the rest map
     * onto the dynamic output ports.
     */
    private void processScriptResults(List<IOObject> results, ISeries<Double, Double> iSeries) throws OperatorException {
        int expectedOutputs = getOutputPorts().getNumberOfPorts();
        if (results.size() != expectedOutputs) {
            throw new UserError(this, "python_scripting.mismatching_outputs", new Object[]{expectedOutputs, results.size()});
        }
        try {
            // The model comes back as a file object; the cast (lost in decompilation) is required for getFile().
            PythonForecast pythonForecast = new PythonForecast(getDeclaration().getName(), extractDescriptionText(results.get(1)), getScript(), Files.readAllBytes(((FileObject) results.get(0)).getFile().toPath()), this);
            setWrapper(pythonForecast);
            this.modelOutput.deliver(new PythonForecastModel(pythonForecast, iSeries, getIndicesParameter()));
            deliverDynamicOutputs(results);
        } catch (IOException | ClassCastException | IllegalArgumentException e) {
            // IllegalArgumentException covers an unusable description result (see extractDescriptionText).
            throw new UserError(this, e, "python_scripting.invalid_training_result_type");
        }
    }

    /**
     * Reads the description text from the first value of the description column.
     *
     * @throws IllegalArgumentException if the result has an unsupported type or (for example sets)
     *     lacks the description attribute
     */
    private String extractDescriptionText(IOObject result) {
        if (result instanceof ExampleSet) {
            Attribute descriptionAttribute = ((ExampleSet) result).getAttributes().get(DESCRIPTION_COLUMN);
            if (descriptionAttribute == null) {
                throw new IllegalArgumentException("Training result has no '" + DESCRIPTION_COLUMN + "' attribute");
            }
            return descriptionAttribute.getMapping().mapIndex(0);
        }
        if (result instanceof IOTable) {
            // Dictionary index 1 is used here (index 0 presumably reserved for missing — confirm).
            return ((IOTable) result).getTable().column(DESCRIPTION_COLUMN).getDictionary().get(1);
        }
        throw new IllegalArgumentException("Unsupported type for description: " + (result == null ? "null" : result.getClass()));
    }

    /** Delivers each result at index >= number of static outputs to the output port with the same index. */
    private void deliverDynamicOutputs(List<IOObject> results) {
        OutputPorts outputPorts = getOutputPorts();
        for (int port = getStaticOutputPorts().size(); port < outputPorts.getNumberOfPorts(); port++) {
            outputPorts.getPortByIndex(port).deliver(results.get(port));
        }
    }

    /** Returns the time series helper parameters followed by the inherited Python operator parameters. */
    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = this.tableTimeSeriesHelper.getParameterTypes(new ArrayList<>());
        parameterTypes.addAll(super.getParameterTypes());
        return parameterTypes;
    }

    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getDefaultDeclaration() {
        return DECLARATION_TEMPLATE;
    }

    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getDefaultScript() {
        return SCRIPT_TEMPLATE;
    }

    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getType() {
        return TYPE;
    }

    /** Returns the user script with the training entry point appended. */
    private String getTrainingScript() throws UndefinedParameterError {
        return getScript() + ENTRYPOINT;
    }

    /**
     * Obtains a table time series helper by instantiating an anonymous {@link TableTimeSeriesOperator}.
     * The helper is bound to this operator's {@code example set} input port, uses optional indices
     * and allows mixed value types.
     */
    private TableTimeSeriesHelper<Operator, ISeries<Double, Double>> initializeTimeSeriesHelper(OperatorDescription operatorDescription) throws WrongConfiguredHelperException {
        return new TableTimeSeriesOperator<ISeries<Double, Double>>(operatorDescription) {
            // Refers to the helper field inherited by this anonymous operator, not the outer field.
            public TableTimeSeriesHelper<Operator, ISeries<Double, Double>> getTableTimeSeriesHelper() {
                return this.tableTimeSeriesHelper;
            }

            protected TableTimeSeriesHelper<Operator, ISeries<Double, Double>> initTableTimeSeriesOperator() throws WrongConfiguredHelperException {
                return new TimeSeriesHelperBuilder(PythonForecaster.this).asInputPortOperator(PythonForecaster.INPUT_PORT_EXAMPLE_SET).setIndiceHandling(getIndicesHandling()).setValuesType(getAllowedValuesType()).build();
            }

            protected TableTimeSeriesHelper.IndicesHandling getIndicesHandling() {
                return TableTimeSeriesHelper.IndicesHandling.OPTIONAL_INDICES;
            }

            protected SeriesBuilder.ValuesType getAllowedValuesType() {
                return SeriesBuilder.ValuesType.MIXED;
            }
        }.getTableTimeSeriesHelper();
    }

    /** Registers the empty-input check, model generation and original pass-through meta data rules. */
    private void setupTransformer() throws WrongConfiguredHelperException {
        InputPort tableInputPort = this.tableTimeSeriesHelper.getTableInputPort();
        this.tableTimeSeriesHelper.addEmptyInputCheck();
        getTransformer().addGenerationRule(this.modelOutput, PythonForecastModel.class);
        getTransformer().addRule(new PassThroughRule(tableInputPort, this.dataOut, false));
    }

    private static String loadEntrypointScript() {
        return ConfigurationTools.loadTextFile("scripts/forecaster_train_entrypoint.py");
    }

    private static String loadDeclarationTemplate() {
        return ConfigurationTools.loadTextFile("scripts/forecaster_declaration_template.json");
    }

    private static String loadScriptTemplate() {
        return ConfigurationTools.loadTextFile("scripts/forecaster_definition_template.py");
    }

    /** Returns a new mutable set: the inherited static parameters plus the time series parameters. */
    public static Set<String> getStaticParameters() {
        Set<String> parameters = new HashSet<>(PythonOperator.getStaticParameters());
        parameters.addAll(Arrays.asList("time_series_attribute", "has_indices", "indices_attribute", PARAMETER_INDEX_NAME, PARAMETER_SERIES_NAME, "sort_time_series"));
        return parameters;
    }

    /** Returns a new mutable set containing the static input port name. */
    public static Set<String> getStaticInputPorts() {
        return new HashSet<>(List.of(INPUT_PORT_EXAMPLE_SET));
    }

    /** Returns a new mutable set containing the static output port names. */
    public static Set<String> getStaticOutputPorts() {
        return new HashSet<>(Arrays.asList(OUTPUT_PORT_MODEL, OUTPUT_PORT_ORIGINAL));
    }
}
