package com.rapidminer.extension.pythonscripting.operator;

import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.extension.pythonscripting.definition.ConfigurationTools;
import com.rapidminer.extension.pythonscripting.definition.OperatorDeclaration;
import com.rapidminer.extension.pythonscripting.model.PythonModel;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonNativeObject;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ProcessStoppedException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.CapabilityCheck;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.nio.file.BufferedFileObject;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.GeneratePredictionModelTransformationRule;
import com.rapidminer.operator.ports.metadata.LearnerPrecondition;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.parameter.UndefinedParameterError;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:com/rapidminer/extension/pythonscripting/operator/PythonLearner.class */
/**
 * Learner operator whose training logic is a user-supplied Python script.
 *
 * <p>{@link #learn(ExampleSet)} appends either a supervised or an unsupervised entry point to the
 * operator's script, chosen by whether the operator declaration lists the
 * {@link OperatorCapability#NO_LABEL} capability. It runs the result through the script runner
 * obtained from {@code EnvironmentTools.getScriptRunner}. The script must return exactly two
 * objects: a {@link PythonNativeObject} whose file holds the serialized model, and an
 * {@link ExampleSet} with a {@code description} attribute whose first mapped value becomes the
 * model description. Both are wrapped into a {@link PythonModel}.
 */
public class PythonLearner extends PythonOperator implements Learner {

    /** Operator type identifier reported by {@link #getType()}. */
    public static final String TYPE = "learner";

    private static final String SUPERVISED_TRAINING_ENTRYPOINT =
            ConfigurationTools.loadTextFile("scripts/supervised_learning_entrypoint.py");
    private static final String UNSUPERVISED_TRAINING_ENTRYPOINT =
            ConfigurationTools.loadTextFile("scripts/unsupervised_learning_entrypoint.py");
    private static final String DECLARATION_TEMPLATE =
            ConfigurationTools.loadTextFile("scripts/learner_declaration_template.json");
    private static final String SCRIPT_TEMPLATE =
            ConfigurationTools.loadTextFile("scripts/learner_definition_template.py");

    /** Package-manager parameter value that selects a conda environment. */
    private static final String PACKAGE_MANAGER_CONDA = "conda (anaconda)";
    /** Package-manager parameter value that selects a virtualenvwrapper environment. */
    private static final String PACKAGE_MANAGER_VENVW = "virtualenvwrapper";
    /** Attribute of the second training result whose first mapped value is the model description. */
    private static final String DESCRIPTION_ATTRIBUTE = "description";

    private final InputPort dataInput;
    private final OutputPort modelOutput;
    private final OutputPort dataOut;

    /**
     * Creates the operator with a "training set" input and "model" / "example set" outputs.
     * The model output's metadata is a prediction model of type {@link PythonModel}; the input
     * set is passed through to "example set".
     *
     * @param operatorDescription the operator description
     */
    public PythonLearner(OperatorDescription operatorDescription) {
        super(operatorDescription, false);
        this.dataInput = getInputPorts().createPort("training set");
        this.modelOutput = getOutputPorts().createPort("model");
        this.dataOut = getOutputPorts().createPort("example set");
        this.dataInput.addPrecondition(new LearnerPrecondition(this, this.dataInput));
        getTransformer().addRule(
                new GeneratePredictionModelTransformationRule(this.dataInput, this.modelOutput, PythonModel.class));
        getTransformer().addRule(new PassThroughRule(this.dataInput, this.dataOut, false));
    }

    /**
     * Creates the operator and forwards the two strings to {@code setImmutable}.
     * NOTE(review): presumably these fix the declaration/script so users cannot edit them —
     * confirm against {@code PythonOperator#setImmutable}.
     */
    protected PythonLearner(OperatorDescription operatorDescription, String str, String str2) {
        this(operatorDescription);
        setImmutable(str, str2);
    }

    @Override
    public String getDefaultDeclaration() {
        return DECLARATION_TEMPLATE;
    }

    @Override
    public String getDefaultScript() {
        return SCRIPT_TEMPLATE;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Reloads the declaration, validates the training set and trains the model.
     *
     * @throws UserError if the declaration JSON cannot be loaded, or the training set violates
     *     the learner's requirements (missing label, empty, no regular attributes, capabilities)
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet exampleSet = this.dataInput.getData(ExampleSet.class);
        try {
            // Capabilities come from the declaration, so it must be (re)loaded before checking them.
            reloadDeclaration();
            if (!supportsCapability(OperatorCapability.NO_LABEL)) {
                Tools.isLabelled(exampleSet);
            }
            Tools.isNonEmpty(exampleSet);
            Tools.hasRegularAttributes(exampleSet);
            new CapabilityCheck(this, false).checkLearnerCapabilities(this, exampleSet);
            this.modelOutput.deliver(learn(exampleSet));
            this.dataOut.deliver(exampleSet);
        } catch (IOException e) {
            throw new UserError(this, e, "python_scripting.loading_json_failed");
        }
    }

    /**
     * Trains a model by running the Python training script on {@code exampleSet}.
     *
     * <p>Unsupervised (declaration lists {@code NO_LABEL}): the script gets the feature set
     * (special attributes cleared if the declaration says to drop them) and the parameter JSON.
     * Supervised: the script gets the features without the label, a label-only set, and the
     * parameter JSON. The training header stored in the model is the features plus the label.
     * Attribute edits happen on clones, so {@code exampleSet}'s own attribute set is not modified.
     *
     * @param exampleSet the training data
     * @return the wrapped {@link PythonModel}
     * @throws OperatorException if the data is unlabelled in supervised mode, or if the script
     *     fails or returns an invalid result
     */
    @Override
    public Model learn(ExampleSet exampleSet) throws OperatorException {
        boolean unsupervised = supportsCapability(OperatorCapability.NO_LABEL);
        OperatorDeclaration declaration = getDeclaration();
        ExampleSet features = (ExampleSet) exampleSet.clone();
        ExampleSet labelOnly = null;
        ExampleSet trainingHeader;
        if (unsupervised) {
            if (declaration.isDropSpecial()) {
                features.getAttributes().clearSpecial();
            }
            trainingHeader = features;
        } else {
            // learn() is public (Learner API) and may be called without doWork()'s checks;
            // report a missing label as a UserError instead of an NPE further down.
            Tools.isLabelled(exampleSet);
            labelOnly = (ExampleSet) exampleSet.clone();
            Attributes featureAttributes = features.getAttributes();
            Attributes labelAttributes = labelOnly.getAttributes();
            // Iterate the original set's roles; removals only touch the two clones.
            Iterator<AttributeRole> roles = exampleSet.getAttributes().allAttributeRoles();
            while (roles.hasNext()) {
                AttributeRole role = roles.next();
                if (Attributes.LABEL_NAME.equals(role.getSpecialName())) {
                    featureAttributes.remove(role.getAttribute());
                } else {
                    labelAttributes.remove(role.getAttribute());
                    if (role.isSpecial() && declaration.isDropSpecial()) {
                        featureAttributes.remove(role.getAttribute());
                    }
                }
            }
            trainingHeader = (ExampleSet) features.clone();
            trainingHeader.getAttributes().add(labelAttributes.getRole(labelAttributes.getLabel()));
        }
        IOObject parameters = new BufferedFileObject(compileParametersAsJson());
        List<IOObject> scriptInputs = unsupervised
                ? Arrays.asList(features, parameters)
                : Arrays.asList(features, labelOnly, parameters);
        List<IOObject> results = runLearner(getTrainingScript(unsupervised), scriptInputs);
        return wrapModel(declaration.getName(), trainingHeader, results);
    }

    /**
     * Runs {@code script} on {@code inputs} and checks that exactly two results come back.
     *
     * @throws UserError the runner's own UserError unchanged, a wrapped runner failure, or an
     *     invalid-result-size error
     * @throws ProcessStoppedException if the process was stopped while the script was running
     */
    private List<IOObject> runLearner(String script, List<IOObject> inputs)
            throws UserError, ProcessStoppedException {
        List<IOObject> results;
        try {
            results = EnvironmentTools.getScriptRunner(this, script, null, false).run(inputs, 2);
        } catch (UserError e) {
            throw e;
        } catch (Exception e) {
            // A stop request surfaces as a runner failure; report the stop instead of wrapping it.
            checkForStop();
            throw new UserError(this, e, "python_scripting.python_runner_error");
        }
        if (results == null || results.size() != 2) {
            throw new UserError(this, "python_scripting.invalid_training_result_size");
        }
        return results;
    }

    /**
     * Builds a {@link PythonModel} from the script results and this operator's environment settings.
     *
     * @param name the model name (from the declaration)
     * @param trainingHeader the example set stored as the model's training header
     * @param results the two script results: serialized model and description set
     * @throws UserError if a result has the wrong type, the description attribute is missing,
     *     or the serialized model file cannot be read
     */
    private PythonModel wrapModel(String name, ExampleSet trainingHeader, List<IOObject> results)
            throws OperatorException {
        IOObject modelResult = results.get(0);
        IOObject descriptionResult = results.get(1);
        if (!(modelResult instanceof PythonNativeObject) || !(descriptionResult instanceof ExampleSet)) {
            throw new UserError(this, "python_scripting.invalid_training_result_type");
        }
        PythonNativeObject serializedModel = (PythonNativeObject) modelResult;
        Attributes resultAttributes = ((ExampleSet) descriptionResult).getAttributes();
        if (resultAttributes.get(DESCRIPTION_ATTRIBUTE) == null) {
            throw new UserError(this, "python_scripting.invalid_training_result_type");
        }
        String description = resultAttributes.get(DESCRIPTION_ATTRIBUTE).getMapping().mapIndex(0);
        String script = getScript();
        byte[] modelBytes;
        try {
            modelBytes = Files.readAllBytes(serializedModel.getFile().toPath());
        } catch (IOException e) {
            throw new UserError(this, e, "python_scripting.failed_to_read_result");
        }
        PythonModel pythonModel = new PythonModel(name, description, script, modelBytes, trainingHeader);
        pythonModel.setDefaultPython(getParameterAsBoolean(EnvironmentTools.PARAMETER_USE_DEFAULT_PYTHON));
        if (!pythonModel.isDefaultPython()) {
            String packageManager = getParameterAsString(EnvironmentTools.PARAMETER_PACKAGE_MANAGER);
            pythonModel.setPackageManager(packageManager);
            pythonModel.setEnvironment(getParameterAsString(environmentParameterFor(packageManager)));
        }
        return pythonModel;
    }

    /**
     * Returns the parameter key holding the environment for the given package manager.
     * Any value other than conda or virtualenvwrapper maps to the Python binary parameter.
     */
    private static String environmentParameterFor(String packageManager) {
        if (PACKAGE_MANAGER_CONDA.equals(packageManager)) {
            return EnvironmentTools.PARAMETER_CONDA_ENVIRONMENT;
        }
        if (PACKAGE_MANAGER_VENVW.equals(packageManager)) {
            return EnvironmentTools.PARAMETER_VENVW_ENVIRONMENT;
        }
        return EnvironmentTools.PARAMETER_PYTHON_BINARY;
    }

    /** Performance estimation is not supported by this learner. */
    @Override
    public boolean shouldEstimatePerformance() {
        return false;
    }

    /** Always fails: performance estimation is not supported. */
    @Override
    public PerformanceVector getEstimatedPerformance() throws OperatorException {
        throw new UserError(this, 912, new Object[]{getName(), "estimation of performance not supported."});
    }

    /** Weight calculation is not supported by this learner. */
    @Override
    public boolean shouldCalculateWeights() {
        return false;
    }

    /** Always fails: weight computation is not supported. */
    @Override
    public AttributeWeights getWeights(ExampleSet exampleSet) throws OperatorException {
        throw new UserError(this, 912, new Object[]{getName(), "computation of weights not supported."});
    }

    /**
     * Reports whether the current declaration lists the capability (matched by its description).
     * Returns {@code false} when there is no declaration, no capability list, or no capability.
     */
    @Override
    public boolean supportsCapability(OperatorCapability operatorCapability) {
        OperatorDeclaration declaration = getDeclaration();
        if (declaration == null || declaration.getCapabilities() == null || operatorCapability == null) {
            return false;
        }
        return declaration.getCapabilities().contains(operatorCapability.getDescription());
    }

    /** Returns the user script followed by the unsupervised or supervised training entry point. */
    private String getTrainingScript(boolean unsupervised) throws UndefinedParameterError {
        return getScript() + (unsupervised ? UNSUPERVISED_TRAINING_ENTRYPOINT : SUPERVISED_TRAINING_ENTRYPOINT);
    }
}
