/*
 * Decompiled with CFR 0.152.
 */
package com.rapidminer.extension.pythonscripting.operator;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.AttributeRole;
import com.rapidminer.example.AttributeWeights;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.extension.pythonscripting.definition.ConfigurationTools;
import com.rapidminer.extension.pythonscripting.definition.OperatorDeclaration;
import com.rapidminer.extension.pythonscripting.model.PythonModel;
import com.rapidminer.extension.pythonscripting.operator.EnvironmentTools;
import com.rapidminer.extension.pythonscripting.operator.PythonOperator;
import com.rapidminer.extension.pythonscripting.operator.scripting.ScriptRunner;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonNativeObject;
import com.rapidminer.extension.pythonscripting.operator.scripting.python.PythonScriptingOperator;
import com.rapidminer.gui.tools.VersionNumber;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorCapability;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.OperatorVersion;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.CapabilityCheck;
import com.rapidminer.operator.learner.CapabilityProvider;
import com.rapidminer.operator.learner.Learner;
import com.rapidminer.operator.nio.file.BufferedFileObject;
import com.rapidminer.operator.performance.PerformanceVector;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.OutputPorts;
import com.rapidminer.operator.ports.metadata.GeneratePredictionModelTransformationRule;
import com.rapidminer.operator.ports.metadata.LearnerPrecondition;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.ports.metadata.PassThroughRule;
import com.rapidminer.operator.ports.metadata.Precondition;
import com.rapidminer.parameter.UndefinedParameterError;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Learner operator that trains a model by executing a user-provided Python script.
 *
 * <p>The user script ({@code getScript()}) is concatenated with a bundled supervised or
 * unsupervised training entry point. The choice depends on whether the operator declaration lists
 * the {@link OperatorCapability#NO_LABEL} capability. The script's first output must be a
 * {@link PythonNativeObject} whose file holds the serialized model. Its second output must carry
 * a {@code description} column/attribute. Any further outputs are forwarded to the dynamic output
 * ports.
 */
public class PythonLearner
extends PythonOperator
implements Learner {
    public static final String TYPE = "learner";
    private static final String SUPERVISED_TRAINING_ENTRYPOINT = PythonLearner.loadTextResource("scripts/supervised_learning_entrypoint.py");
    private static final String UNSUPERVISED_TRAINING_ENTRYPOINT = PythonLearner.loadTextResource("scripts/unsupervised_learning_entrypoint.py");
    private static final String DECLARATION_TEMPLATE = PythonLearner.loadTextResource("scripts/learner_declaration_template.json");
    private static final String SCRIPT_TEMPLATE = PythonLearner.loadTextResource("scripts/learner_definition_template.py");
    private static final String INPUT_TRAINING_SET_NAME = "training set";
    private static final String OUTPUT_MODEL_NAME = "model";
    private static final String OUTPUT_EXAMPLE_SET_NAME = "example set";
    // Both sets are unmodifiable: they are shared static state handed out by public getters.
    private static final Set<String> STATIC_INPUT_PORTS = Collections.singleton(INPUT_TRAINING_SET_NAME);
    private static final Set<String> STATIC_OUTPUT_PORTS = Collections.unmodifiableSet(
            new HashSet<>(Arrays.asList(OUTPUT_MODEL_NAME, OUTPUT_EXAMPLE_SET_NAME)));
    private static final int NUMBER_OF_STATIC_INPUT_PORTS = STATIC_INPUT_PORTS.size();
    private static final int NUMBER_OF_STATIC_OUTPUT_PORTS = STATIC_OUTPUT_PORTS.size();
    private final InputPort dataInput = (InputPort)this.getInputPorts().createPort(INPUT_TRAINING_SET_NAME);
    private final OutputPort modelOutput = (OutputPort)this.getOutputPorts().createPort(OUTPUT_MODEL_NAME);
    private final OutputPort dataOutput = (OutputPort)this.getOutputPorts().createPort(OUTPUT_EXAMPLE_SET_NAME);

    /**
     * Creates the operator and tries to reload its declaration.
     *
     * @param description the operator description
     */
    public PythonLearner(OperatorDescription description) {
        this(description, true);
    }

    /**
     * Creates the operator, registering its precondition and metadata rules.
     *
     * @param description       the operator description
     * @param reloadDeclaration whether to call {@code tryReloadDeclaration()} during construction
     */
    public PythonLearner(OperatorDescription description, boolean reloadDeclaration) {
        super(description, true, NUMBER_OF_STATIC_INPUT_PORTS, NUMBER_OF_STATIC_OUTPUT_PORTS);
        this.initializeOperator(reloadDeclaration);
    }

    /**
     * Creates the operator from a fixed declaration and script definition.
     * NOTE(review): presumably {@code setImmutable} locks these so users cannot edit them — confirm.
     *
     * @param description the operator description
     * @param declaration the operator declaration (JSON text, judging by the template it replaces)
     * @param definition  the Python script definition
     */
    protected PythonLearner(OperatorDescription description, String declaration, String definition) {
        this(description, false);
        this.setImmutable(declaration, definition);
    }

    /**
     * Extends the inherited incompatible version changes with the HDF5 date-time fix and the
     * switch to Arrow serialization.
     */
    @Override
    public OperatorVersion[] getIncompatibleVersionChanges() {
        List<OperatorVersion> versions = new ArrayList<>(Arrays.asList(super.getIncompatibleVersionChanges()));
        versions.add(PythonScriptingOperator.VERSION_HDF5_DATE_TIME_BUG);
        versions.add(PythonScriptingOperator.VERSION_ARROW_SERIALIZATION);
        return versions.toArray(new OperatorVersion[0]);
    }

    @Override
    public String getDefaultDeclaration() {
        return DECLARATION_TEMPLATE;
    }

    @Override
    public String getDefaultScript() {
        return SCRIPT_TEMPLATE;
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /** @return the unmodifiable set of static input port names */
    public static Set<String> getStaticInputPorts() {
        return STATIC_INPUT_PORTS;
    }

    /** @return the unmodifiable set of static output port names */
    public static Set<String> getStaticOutputPorts() {
        return STATIC_OUTPUT_PORTS;
    }

    /**
     * Reloads the declaration, validates the training data and trains the model. It then delivers
     * the model and the unchanged input data.
     *
     * @throws OperatorException if validation or training fails, or the declaration cannot be loaded
     */
    @Override
    public void doWork() throws OperatorException {
        ExampleSet data = (ExampleSet)this.dataInput.getData(ExampleSet.class);
        try {
            this.reloadDeclaration();
            this.validateInputData(data);
            Model model = this.learn(data);
            this.modelOutput.deliver(model);
            this.dataOutput.deliver(data);
        }
        catch (IOException e) {
            throw new UserError(this, e, "python_scripting.loading_json_failed");
        }
    }

    /**
     * Runs the Python training script on {@code data} and wraps the result as a {@link PythonModel}.
     * Dynamic script outputs are delivered to the dynamic output ports as a side effect.
     *
     * @param data the training data
     * @return the wrapped Python model
     * @throws OperatorException if the script fails or its outputs cannot be wrapped
     */
    @Override
    public Model learn(ExampleSet data) throws OperatorException {
        boolean unsupervised = this.isUnsupervisedLearning();
        ExampleSet featureSet = this.prepareFeatureSet(data, unsupervised);
        ExampleSet labelSet = unsupervised ? null : this.prepareLabelSet(data);
        ExampleSet trainingSet = this.prepareTrainingSet(featureSet, labelSet, unsupervised);
        BufferedFileObject parameters = new BufferedFileObject(this.compileParametersAsJson(new PythonOperator.ParameterKeyValue[0]));
        ScriptRunner runner = this.createScriptRunner(unsupervised);
        List<IOObject> dynamicInputs = this.prepareDynamicInputs(runner);
        List<IOObject> inputs = this.prepareScriptInputs(featureSet, labelSet, dynamicInputs, parameters, unsupervised);
        List<IOObject> results = this.executeScript(runner, inputs);
        return this.wrapModel(trainingSet, results);
    }

    /** Checks that the data is labelled (supervised only), non-empty, has regular attributes and fits the declared capabilities. */
    private void validateInputData(ExampleSet data) throws OperatorException {
        if (!this.isUnsupervisedLearning()) {
            Tools.isLabelled(data);
        }
        Tools.isNonEmpty(data);
        Tools.hasRegularAttributes(data);
        CapabilityCheck check = new CapabilityCheck(this, false);
        check.checkLearnerCapabilities(this, data);
    }

    /** Unsupervised mode is selected when the declaration lists the NO_LABEL capability. */
    private boolean isUnsupervisedLearning() {
        return this.supportsCapability(OperatorCapability.NO_LABEL);
    }

    /**
     * Builds a clone of {@code data} containing the features sent to the script.
     * <ul>
     *   <li>Supervised: the label is always removed.</li>
     *   <li>If the declaration asks to drop special attributes, all remaining specials are removed too.</li>
     * </ul>
     */
    private ExampleSet prepareFeatureSet(ExampleSet data, boolean unsupervised) {
        ExampleSet featureSet = (ExampleSet)data.clone();
        Attributes featureAttributes = featureSet.getAttributes();
        OperatorDeclaration declaration = this.getDeclaration();
        if (unsupervised) {
            if (declaration.isDropSpecial()) {
                featureAttributes.clearSpecial();
            }
        } else {
            Attribute labelAttribute = featureAttributes.getLabel();
            if (labelAttribute != null) {
                featureAttributes.remove(labelAttribute);
            }
            if (declaration.isDropSpecial()) {
                // Collect first, then remove: the attribute set must not change while iterating.
                List<Attribute> attributesToRemove = new ArrayList<>();
                Iterator<AttributeRole> roles = featureAttributes.allAttributeRoles();
                while (roles.hasNext()) {
                    AttributeRole role = roles.next();
                    if (role.isSpecial() && !Attributes.LABEL_NAME.equals(role.getSpecialName())) {
                        attributesToRemove.add(role.getAttribute());
                    }
                }
                for (Attribute attribute : attributesToRemove) {
                    featureAttributes.remove(attribute);
                }
            }
        }
        return featureSet;
    }

    /** Builds a clone of {@code data} that keeps only the label attribute. */
    private ExampleSet prepareLabelSet(ExampleSet data) {
        ExampleSet labelSet = (ExampleSet)data.clone();
        Attributes labelAttributes = labelSet.getAttributes();
        // Collect first, then remove: the attribute set must not change while iterating.
        List<Attribute> attributesToRemove = new ArrayList<>();
        Iterator<AttributeRole> roles = labelAttributes.allAttributeRoles();
        while (roles.hasNext()) {
            AttributeRole role = roles.next();
            if (!Attributes.LABEL_NAME.equals(role.getSpecialName())) {
                attributesToRemove.add(role.getAttribute());
            }
        }
        for (Attribute attribute : attributesToRemove) {
            labelAttributes.remove(attribute);
        }
        return labelSet;
    }

    /**
     * Returns the example set stored with the model.
     * <ul>
     *   <li>Unsupervised: the feature set itself.</li>
     *   <li>Supervised: a clone of the feature set with the label role of {@code labelSet} added back.</li>
     * </ul>
     */
    private ExampleSet prepareTrainingSet(ExampleSet featureSet, ExampleSet labelSet, boolean unsupervised) {
        if (unsupervised) {
            return featureSet;
        }
        ExampleSet trainingSet = (ExampleSet)featureSet.clone();
        Attributes labelAttributes = labelSet.getAttributes();
        trainingSet.getAttributes().add(labelAttributes.getRole(labelAttributes.getLabel()));
        return trainingSet;
    }

    /**
     * Creates a runner for the user script plus the matching entry point. Arrow serialization is
     * used only when the compatibility level is above {@code VERSION_ARROW_SERIALIZATION}.
     */
    private ScriptRunner createScriptRunner(boolean unsupervised) throws UserError {
        boolean useArrowSerialization = this.getCompatibilityLevel().isAbove((VersionNumber)PythonScriptingOperator.VERSION_ARROW_SERIALIZATION);
        String trainingScript = this.getTrainingScript(unsupervised);
        return EnvironmentTools.getScriptRunner(this, trainingScript, null, false, useArrowSerialization);
    }

    /** Returns the type-checked inputs of the dynamic ports, i.e. all inputs after the static ones. */
    private List<IOObject> prepareDynamicInputs(ScriptRunner runner) throws OperatorException {
        List<IOObject> dynamicInputs = this.checkInputTypes(runner).stream().skip(NUMBER_OF_STATIC_INPUT_PORTS).collect(Collectors.toList());
        this.checkExampleSet(dynamicInputs);
        return dynamicInputs;
    }

    /**
     * Assembles the script arguments in the order expected by the entry points:
     * <ol>
     *   <li>features</li>
     *   <li>label (supervised only)</li>
     *   <li>dynamic inputs</li>
     *   <li>parameters</li>
     * </ol>
     */
    private List<IOObject> prepareScriptInputs(ExampleSet featureSet, ExampleSet labelSet, List<IOObject> dynamicInputs, BufferedFileObject parameters, boolean unsupervised) {
        List<IOObject> inputs = new ArrayList<>();
        inputs.add(featureSet);
        if (!unsupervised) {
            inputs.add(labelSet);
        }
        inputs.addAll(dynamicInputs);
        inputs.add(parameters);
        return inputs;
    }

    /**
     * Runs the script and checks that it produced one result per output port.
     *
     * @throws UserError {@code python_runner_error} if the runner fails, or
     *                   {@code mismatching_outputs} if the number of results is wrong
     */
    private List<IOObject> executeScript(ScriptRunner runner, List<IOObject> inputs) throws OperatorException {
        int nOutputs = this.getOutputPorts().getNumberOfPorts();
        List<IOObject> results;
        try {
            results = runner.run(inputs, nOutputs);
        }
        catch (Exception e) {
            // Prefer reporting a user-requested stop over the runner failure it caused.
            this.checkForStop();
            throw new UserError(this, e, "python_scripting.python_runner_error");
        }
        // Checked outside the try so this specific error is not rewrapped as a runner error.
        int nResults = results == null ? 0 : results.size();
        if (nResults != nOutputs) {
            throw new UserError(this, "python_scripting.mismatching_outputs", new Object[]{nOutputs, nResults});
        }
        return results;
    }

    /**
     * Wraps the script results as a {@link PythonModel} and delivers any dynamic outputs.
     * <ul>
     *   <li>Result 0 must be a {@link PythonNativeObject} whose file holds the model bytes.</li>
     *   <li>Result 1 carries the model description.</li>
     * </ul>
     *
     * @throws UserError {@code failed_to_wrap_model} if the results have the wrong type or content,
     *                   or the model file cannot be read
     */
    private PythonModel wrapModel(ExampleSet trainingSet, List<IOObject> results) throws OperatorException {
        try {
            PythonNativeObject model = (PythonNativeObject)results.get(0);
            String descriptionText = this.extractDescriptionText(results.get(1));
            byte[] modelBytes = Files.readAllBytes(model.getFile().toPath());
            PythonModel wrapper = new PythonModel(this.getDeclaration().getName(), descriptionText, this.getScript(), modelBytes, trainingSet);
            this.setWrapper(wrapper);
            this.deliverDynamicOutputs(results);
            return wrapper;
        }
        catch (IOException | ClassCastException | IllegalArgumentException e) {
            throw new UserError(this, e, "python_scripting.failed_to_wrap_model");
        }
    }

    /**
     * Reads the description text from the script's description output.
     *
     * @throws IllegalArgumentException if the object is {@code null}, of an unsupported type, or an
     *                                  example set without a {@code description} attribute
     */
    private String extractDescriptionText(IOObject descriptionObj) {
        if (descriptionObj instanceof ExampleSet) {
            Attribute descriptionAttribute = ((ExampleSet)descriptionObj).getAttributes().get("description");
            if (descriptionAttribute == null) {
                throw new IllegalArgumentException("Missing 'description' attribute in description output");
            }
            return descriptionAttribute.getMapping().mapIndex(0);
        }
        if (descriptionObj instanceof IOTable) {
            // NOTE(review): index 1 here vs. index 0 above presumably reflects Belt dictionaries
            // reserving index 0 for missing values — confirm.
            return ((IOTable)descriptionObj).getTable().column("description").getDictionary().get(1);
        }
        throw new IllegalArgumentException("Unsupported type for description: " + (descriptionObj == null ? "null" : descriptionObj.getClass()));
    }

    /** Delivers each result beyond the static outputs to the output port with the same index. */
    private void deliverDynamicOutputs(List<IOObject> results) {
        OutputPorts outputs = this.getOutputPorts();
        for (int i = NUMBER_OF_STATIC_OUTPUT_PORTS; i < outputs.getNumberOfPorts(); ++i) {
            OutputPort port = (OutputPort)outputs.getPortByIndex(i);
            port.deliver(results.get(i));
        }
    }

    @Override
    public boolean shouldEstimatePerformance() {
        return false;
    }

    /** Not supported; always throws. */
    @Override
    public PerformanceVector getEstimatedPerformance() throws OperatorException {
        throw new UserError(this, 912, new Object[]{this.getName(), "estimation of performance not supported."});
    }

    @Override
    public boolean shouldCalculateWeights() {
        return false;
    }

    /** Not supported; always throws. */
    @Override
    public AttributeWeights getWeights(ExampleSet exampleSet) throws OperatorException {
        throw new UserError(this, 912, new Object[]{this.getName(), "computation of weights not supported."});
    }

    /**
     * Returns whether the current declaration lists the capability's description.
     * Returns {@code false} if no declaration, capability list or capability is available.
     */
    @Override
    public boolean supportsCapability(OperatorCapability capability) {
        OperatorDeclaration declaration = this.getDeclaration();
        if (declaration != null && declaration.getCapabilities() != null && capability != null) {
            return declaration.getCapabilities().contains(capability.getDescription());
        }
        return false;
    }

    /** Returns the user script followed by the supervised or unsupervised entry point. */
    private String getTrainingScript(boolean unsupervised) throws UndefinedParameterError {
        return this.getScript() + (unsupervised ? UNSUPERVISED_TRAINING_ENTRYPOINT : SUPERVISED_TRAINING_ENTRYPOINT);
    }

    /** Registers the learner precondition and metadata rules, then optionally reloads the declaration. */
    private void initializeOperator(boolean reloadDeclaration) {
        this.dataInput.addPrecondition(new LearnerPrecondition(this, this.dataInput));
        this.getTransformer().addRule(new GeneratePredictionModelTransformationRule(this.dataInput, this.modelOutput, PythonModel.class));
        this.getTransformer().addRule(new PassThroughRule(this.dataInput, this.dataOutput, false));
        if (reloadDeclaration) {
            this.tryReloadDeclaration();
        }
    }

    /** Loads a bundled text resource (entry-point script or template) via {@link ConfigurationTools}. */
    private static String loadTextResource(String filePath) {
        return ConfigurationTools.loadTextFile(filePath);
    }
}

