package com.rapidminer.extension.pythonscripting.operator;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.extension.pythonscripting.definition.ConfigurationTools;
import com.rapidminer.extension.pythonscripting.definition.OperatorDeclaration;
import com.rapidminer.extension.pythonscripting.model.PythonModel;
import com.rapidminer.extension.pythonscripting.operator.PythonOperator;
import com.rapidminer.extension.pythonscripting.operator.scripting.ScriptRunner;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonNativeObject;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonScriptingOperator;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.CapabilityCheck;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.nio.file.BufferedFileObject;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPorts;
import com.rapidminer.operator.ports.metadata.GeneratePredictionModelTransformationRule;
import com.rapidminer.operator.ports.metadata.LearnerPrecondition;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.search.MarketplaceGlobalSearchManager;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:com/rapidminer/extension/pythonscripting/operator/PythonLearner.class */
/**
 * Learner operator whose training logic is provided by a user-defined Python script.
 *
 * <p>On execution the operator validates the training {@link ExampleSet}. It then runs the
 * operator's script followed by either the supervised or the unsupervised entrypoint script. The
 * unsupervised one is chosen when the declaration lists the {@link OperatorCapability#NO_LABEL}
 * capability. The script outputs are wrapped into a {@link PythonModel}.
 *
 * <p>Script output layout, as consumed by this class:
 * <ul>
 *   <li>index 0 – a {@link PythonNativeObject} whose file holds the serialized model;</li>
 *   <li>index 1 – an {@link ExampleSet} or {@link IOTable} carrying the model description;</li>
 *   <li>index {@code >= NUMBER_OF_STATIC_OUTPUT_PORTS} – delivered to the dynamic output ports.</li>
 * </ul>
 */
public class PythonLearner extends PythonOperator implements Learner {

    /** Operator type identifier returned by {@link #getType()}. */
    public static final String TYPE = "learner";

    private static final String INPUT_TRAINING_SET_NAME = "training set";
    private static final String OUTPUT_MODEL_NAME = "model";
    private static final String OUTPUT_EXAMPLE_SET_NAME = "example set";

    private static final String SUPERVISED_TRAINING_ENTRYPOINT = loadEntrypointScript("scripts/supervised_learning_entrypoint.py");
    private static final String UNSUPERVISED_TRAINING_ENTRYPOINT = loadEntrypointScript("scripts/unsupervised_learning_entrypoint.py");
    private static final String DECLARATION_TEMPLATE = loadTemplate("scripts/learner_declaration_template.json");
    private static final String SCRIPT_TEMPLATE = loadTemplate("scripts/learner_definition_template.py");

    /** Names of the input ports every learner has regardless of its declaration (immutable). */
    private static final Set<String> STATIC_INPUT_PORTS = Collections.singleton(INPUT_TRAINING_SET_NAME);

    /**
     * Names of the output ports every learner has regardless of its declaration. Wrapped as
     * unmodifiable because it is handed out by {@link #getStaticOutputPorts()} and must not be
     * altered by callers.
     */
    private static final Set<String> STATIC_OUTPUT_PORTS = Collections.unmodifiableSet(
            new HashSet<>(Arrays.asList(OUTPUT_MODEL_NAME, OUTPUT_EXAMPLE_SET_NAME)));

    private static final int NUMBER_OF_STATIC_INPUT_PORTS = STATIC_INPUT_PORTS.size();
    private static final int NUMBER_OF_STATIC_OUTPUT_PORTS = STATIC_OUTPUT_PORTS.size();

    private final InputPort dataInput;
    private final OutputPort modelOutput;
    private final OutputPort dataOutput;

    /**
     * Creates the operator and reloads its declaration immediately.
     *
     * @param operatorDescription the operator description
     */
    public PythonLearner(OperatorDescription operatorDescription) {
        this(operatorDescription, true);
    }

    /**
     * Creates the operator with its static ports and meta-data rules.
     *
     * @param operatorDescription the operator description
     * @param reloadDeclarationOnInit whether {@code tryReloadDeclaration()} is invoked at the end of construction
     */
    public PythonLearner(OperatorDescription operatorDescription, boolean reloadDeclarationOnInit) {
        super(operatorDescription, true, NUMBER_OF_STATIC_INPUT_PORTS, NUMBER_OF_STATIC_OUTPUT_PORTS);
        this.dataInput = getInputPorts().createPort(INPUT_TRAINING_SET_NAME);
        this.modelOutput = getOutputPorts().createPort(OUTPUT_MODEL_NAME);
        this.dataOutput = getOutputPorts().createPort(OUTPUT_EXAMPLE_SET_NAME);
        initializeOperator(reloadDeclarationOnInit);
    }

    /**
     * Creates the operator without reloading the declaration. It then calls
     * {@code setImmutable(str, str2)}, forwarding both arguments unchanged.
     *
     * @param operatorDescription the operator description
     * @param str first argument forwarded to {@code setImmutable}
     * @param str2 second argument forwarded to {@code setImmutable}
     */
    protected PythonLearner(OperatorDescription operatorDescription, String str, String str2) {
        this(operatorDescription, false);
        setImmutable(str, str2);
    }

    /**
     * Extends the inherited incompatible versions with the HDF5 date-time bug fix and the Arrow
     * serialization change.
     */
    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        List<OperatorVersion> changes = new ArrayList<>(Arrays.asList(super.getIncompatibleVersionChanges()));
        changes.add(PythonScriptingOperator.VERSION_HDF5_DATE_TIME_BUG);
        changes.add(PythonScriptingOperator.VERSION_ARROW_SERIALIZATION);
        return changes.toArray(new OperatorVersion[0]);
    }

    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getDefaultDeclaration() {
        return DECLARATION_TEMPLATE;
    }

    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getDefaultScript() {
        return SCRIPT_TEMPLATE;
    }

    @Override // com.rapidminer.extension.pythonscripting.operator.PythonOperator
    public String getType() {
        return TYPE;
    }

    /** @return the unmodifiable set of static input port names */
    public static Set<String> getStaticInputPorts() {
        return STATIC_INPUT_PORTS;
    }

    /** @return the unmodifiable set of static output port names */
    public static Set<String> getStaticOutputPorts() {
        return STATIC_OUTPUT_PORTS;
    }

    /**
     * Reloads the declaration, validates the training data, trains the model and delivers the
     * model plus the unchanged training set.
     *
     * @throws OperatorException if validation or training fails, or the declaration JSON cannot be loaded
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = (ExampleSet) dataInput.getData(ExampleSet.class);
        try {
            reloadDeclaration();
            validateInputData(exampleSet);
            modelOutput.deliver(learn(exampleSet));
            dataOutput.deliver(exampleSet);
        } catch (IOException e) {
            throw new UserError(this, e, "python_scripting.loading_json_failed");
        }
    }

    /**
     * Trains a model on the given example set by running the Python training script.
     *
     * @param exampleSet the training data
     * @return the wrapped Python model
     * @throws OperatorException if the script fails or its outputs cannot be wrapped
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        boolean unsupervised = isUnsupervisedLearning();
        ExampleSet featureSet = prepareFeatureSet(exampleSet, unsupervised);
        ExampleSet labelSet = unsupervised ? null : prepareLabelSet(exampleSet);
        ExampleSet trainingSet = prepareTrainingSet(featureSet, labelSet, unsupervised);
        BufferedFileObject parameters = new BufferedFileObject(compileParametersAsJson(new PythonOperator.ParameterKeyValue[0]));
        ScriptRunner scriptRunner = createScriptRunner(unsupervised);
        List<IOObject> dynamicInputs = prepareDynamicInputs(scriptRunner);
        List<IOObject> scriptInputs = prepareScriptInputs(featureSet, labelSet, dynamicInputs, parameters, unsupervised);
        List<IOObject> scriptOutputs = executeScript(scriptRunner, scriptInputs);
        return wrapModel(trainingSet, scriptOutputs);
    }

    /** Checks label presence (supervised only), non-emptiness, regular attributes and learner capabilities. */
    private void validateInputData(ExampleSet exampleSet) throws OperatorException {
        if (!isUnsupervisedLearning()) {
            Tools.isLabelled(exampleSet);
        }
        Tools.isNonEmpty(exampleSet);
        Tools.hasRegularAttributes(exampleSet);
        new CapabilityCheck(this, false).checkLearnerCapabilities(this, exampleSet);
    }

    /** @return true if the declaration lists the {@link OperatorCapability#NO_LABEL} capability */
    private boolean isUnsupervisedLearning() {
        return supportsCapability(OperatorCapability.NO_LABEL);
    }

    /**
     * Builds the feature view of the input.
     *
     * <p>Supervised mode: the label is removed, and other special attributes are removed too when
     * the declaration asks to drop specials. Unsupervised mode: all specials are cleared when the
     * declaration asks to drop specials.
     *
     * @param exampleSet the original input (not modified; a clone is returned)
     * @param unsupervised whether unsupervised training is performed
     * @return a clone of the input restricted to the feature attributes
     */
    private ExampleSet prepareFeatureSet(ExampleSet exampleSet, boolean unsupervised) {
        ExampleSet featureSet = (ExampleSet) exampleSet.clone();
        Attributes attributes = featureSet.getAttributes();
        OperatorDeclaration declaration = getDeclaration();
        if (unsupervised) {
            if (declaration.isDropSpecial()) {
                attributes.clearSpecial();
            }
            return featureSet;
        }
        Attribute label = attributes.getLabel();
        if (label != null) {
            attributes.remove(label);
        }
        if (declaration.isDropSpecial()) {
            // Collect first, then remove: removing while iterating the roles would invalidate the iterator.
            List<Attribute> toRemove = new ArrayList<>();
            Iterator<AttributeRole> roles = attributes.allAttributeRoles();
            while (roles.hasNext()) {
                AttributeRole role = roles.next();
                if (role.isSpecial() && !"label".equals(role.getSpecialName())) {
                    toRemove.add(role.getAttribute());
                }
            }
            for (Attribute attribute : toRemove) {
                attributes.remove(attribute);
            }
        }
        return featureSet;
    }

    /**
     * Builds a clone of the input that keeps only the attribute whose special role is "label".
     *
     * @param exampleSet the original input (not modified)
     * @return the label-only view
     */
    private ExampleSet prepareLabelSet(ExampleSet exampleSet) {
        ExampleSet labelSet = (ExampleSet) exampleSet.clone();
        Attributes attributes = labelSet.getAttributes();
        List<Attribute> toRemove = new ArrayList<>();
        Iterator<AttributeRole> roles = attributes.allAttributeRoles();
        while (roles.hasNext()) {
            AttributeRole role = roles.next();
            if (!"label".equals(role.getSpecialName())) {
                toRemove.add(role.getAttribute());
            }
        }
        for (Attribute attribute : toRemove) {
            attributes.remove(attribute);
        }
        return labelSet;
    }

    /**
     * Builds the example set stored in the model. Unsupervised mode returns the feature set itself.
     * Supervised mode returns a clone of the feature set with the label role from {@code labelSet}
     * added back.
     */
    private ExampleSet prepareTrainingSet(ExampleSet featureSet, ExampleSet labelSet, boolean unsupervised) {
        if (unsupervised) {
            return featureSet;
        }
        ExampleSet trainingSet = (ExampleSet) featureSet.clone();
        Attributes labelAttributes = labelSet.getAttributes();
        trainingSet.getAttributes().add(labelAttributes.getRole(labelAttributes.getLabel()));
        return trainingSet;
    }

    /** Creates a script runner for the user script combined with the matching entrypoint. */
    private ScriptRunner createScriptRunner(boolean unsupervised) throws UserError {
        return EnvironmentTools.getScriptRunner(this, getTrainingScript(unsupervised), null, false,
                getCompatibilityLevel().isAbove(PythonScriptingOperator.VERSION_ARROW_SERIALIZATION));
    }

    /** Returns the type-checked inputs from the dynamic ports, i.e. all but the static input ports. */
    private List<IOObject> prepareDynamicInputs(ScriptRunner scriptRunner) throws OperatorException {
        @SuppressWarnings("unchecked") // checkInputTypes yields IOObjects; the cast only restores the generic type
        List<IOObject> dynamicInputs = (List<IOObject>) checkInputTypes(scriptRunner).stream()
                .skip(NUMBER_OF_STATIC_INPUT_PORTS)
                .collect(Collectors.toList());
        checkExampleSet(dynamicInputs);
        return dynamicInputs;
    }

    /**
     * Assembles the script inputs in the order: feature set, label set (supervised only), dynamic
     * inputs, parameters file.
     */
    private List<IOObject> prepareScriptInputs(ExampleSet featureSet, ExampleSet labelSet, List<IOObject> dynamicInputs,
                                               BufferedFileObject parameters, boolean unsupervised) {
        List<IOObject> inputs = new ArrayList<>(dynamicInputs.size() + 3);
        inputs.add(featureSet);
        if (!unsupervised) {
            inputs.add(labelSet);
        }
        inputs.addAll(dynamicInputs);
        inputs.add(parameters);
        return inputs;
    }

    /**
     * Runs the script and checks that it produced one output per output port.
     *
     * @throws UserError {@code python_runner_error} if the runner fails, or
     *     {@code mismatching_outputs} if the number of outputs differs from the number of ports
     */
    private List<IOObject> executeScript(ScriptRunner scriptRunner, List<IOObject> inputs) throws OperatorException {
        int numberOfPorts = getOutputPorts().getNumberOfPorts();
        List<IOObject> outputs;
        try {
            outputs = scriptRunner.run(inputs, numberOfPorts);
        } catch (Exception e) {
            // checkForStop() runs before wrapping, so a stop request (if any) is reported instead of a runner error.
            checkForStop();
            throw new UserError(this, e, "python_scripting.python_runner_error");
        }
        // Checked outside the try: inside it, this specific error would be re-wrapped as a generic runner error.
        if (outputs.size() != numberOfPorts) {
            throw new UserError(this, "python_scripting.mismatching_outputs",
                    new Object[]{numberOfPorts, outputs.size()});
        }
        return outputs;
    }

    /**
     * Wraps the serialized model (output 0) and its description (output 1) into a
     * {@link PythonModel}, and delivers the remaining outputs to the dynamic ports.
     *
     * @throws UserError {@code failed_to_wrap_model} if the outputs have unexpected types or content,
     *     or the model file cannot be read
     */
    private PythonModel wrapModel(ExampleSet trainingSet, List<IOObject> outputs) throws OperatorException {
        try {
            PythonNativeObject serializedModel = (PythonNativeObject) outputs.get(0);
            String name = getDeclaration().getName();
            String description = extractDescriptionText(outputs.get(1));
            String script = getScript();
            byte[] modelBytes = Files.readAllBytes(serializedModel.getFile().toPath());
            PythonModel pythonModel = new PythonModel(name, description, script, modelBytes, trainingSet);
            setWrapper(pythonModel);
            deliverDynamicOutputs(outputs);
            return pythonModel;
        } catch (IOException | ClassCastException | IllegalArgumentException e) {
            throw new UserError(this, e, "python_scripting.failed_to_wrap_model");
        }
    }

    /**
     * Reads the first value of the description attribute/column from the script's description output.
     *
     * @throws IllegalArgumentException if the object is null, of an unsupported type, or (for an
     *     ExampleSet) lacks the description attribute
     */
    private String extractDescriptionText(IOObject descriptionObject) {
        if (descriptionObject instanceof ExampleSet) {
            Attribute descriptionAttribute = ((ExampleSet) descriptionObject).getAttributes()
                    .get(MarketplaceGlobalSearchManager.FIELD_DESCRIPTION);
            if (descriptionAttribute == null) {
                throw new IllegalArgumentException(
                        "Missing attribute '" + MarketplaceGlobalSearchManager.FIELD_DESCRIPTION + "' in model description");
            }
            return descriptionAttribute.getMapping().mapIndex(0);
        }
        if (descriptionObject instanceof IOTable) {
            // NOTE(review): index 1 presumably addresses the first dictionary entry (0 appears reserved) — confirm.
            return ((IOTable) descriptionObject).getTable().column(MarketplaceGlobalSearchManager.FIELD_DESCRIPTION)
                    .getDictionary().get(1);
        }
        throw new IllegalArgumentException("Unsupported type for description: "
                + (descriptionObject == null ? "null" : descriptionObject.getClass()));
    }

    /** Delivers script output {@code i} to output port {@code i} for every non-static port. */
    private void deliverDynamicOutputs(List<IOObject> outputs) {
        OutputPorts outputPorts = getOutputPorts();
        for (int i = NUMBER_OF_STATIC_OUTPUT_PORTS; i < outputPorts.getNumberOfPorts(); i++) {
            outputPorts.getPortByIndex(i).deliver(outputs.get(i));
        }
    }

    /** Performance estimation is not supported. */
    @Override
    public boolean shouldEstimatePerformance() {
        return false;
    }

    /** @throws UserError always (error 912), since performance estimation is not supported */
    @Override
    public PerformanceVector getEstimatedPerformance() throws OperatorException {
        throw new UserError(this, 912, new Object[]{getName(), "estimation of performance not supported."});
    }

    /** Weight calculation is not supported. */
    @Override
    public boolean shouldCalculateWeights() {
        return false;
    }

    /** @throws UserError always (error 912), since weight computation is not supported */
    @Override
    public AttributeWeights getWeights(ExampleSet exampleSet) throws OperatorException {
        throw new UserError(this, 912, new Object[]{getName(), "computation of weights not supported."});
    }

    /**
     * Reports whether the current declaration lists the capability's description.
     *
     * @return false if the declaration, its capability list or the argument is null
     */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        OperatorDeclaration declaration = getDeclaration();
        if (declaration == null || declaration.getCapabilities() == null || operatorCapability == null) {
            return false;
        }
        return declaration.getCapabilities().contains(operatorCapability.getDescription());
    }

    /** Returns the user script followed by the supervised or unsupervised entrypoint script. */
    private String getTrainingScript(boolean unsupervised) throws UndefinedParameterError {
        return getScript() + (unsupervised ? UNSUPERVISED_TRAINING_ENTRYPOINT : SUPERVISED_TRAINING_ENTRYPOINT);
    }

    /**
     * Registers the learner precondition and meta-data transformation rules, then optionally
     * reloads the declaration.
     */
    private void initializeOperator(boolean reloadDeclarationOnInit) {
        dataInput.addPrecondition(new LearnerPrecondition(this, dataInput));
        getTransformer().addRule(new GeneratePredictionModelTransformationRule(dataInput, modelOutput, PythonModel.class));
        getTransformer().addRule(new PassThroughRule(dataInput, dataOutput, false));
        if (reloadDeclarationOnInit) {
            tryReloadDeclaration();
        }
    }

    private static String loadEntrypointScript(String path) {
        return ConfigurationTools.loadTextFile(path);
    }

    private static String loadTemplate(String path) {
        return ConfigurationTools.loadTextFile(path);
    }
}
