package com.rapidminer.extension.mlflow.operator.modelregistry;

import com.rapidminer.adaption.belt.IOTable;
import com.rapidminer.belt.column.ColumnType;
import com.rapidminer.connection.ConnectionInformationContainerIOObject;
import com.rapidminer.extension.mlflow.utility.ArtifactHandler;
import com.rapidminer.extension.mlflow.utility.Constants;
import com.rapidminer.extension.mlflow.utility.TableConverter;
import com.rapidminer.extension.mlflow.utility.responses.FileInfo;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.IOTablePredictionModel;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.table.ColumnInfoBuilder;
import com.rapidminer.operator.ports.metadata.table.TableMetaDataBuilder;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeRepositoryLocation;
import com.rapidminer.parameter.ParameterTypeString;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.repository.RepositoryException;
import com.rapidminer.repository.RepositoryLocation;
import com.rapidminer.repository.versioned.BasicFolder;
import com.rapidminer.repository.versioned.JsonIOObjectEntry;
import com.rapidminer.repository.versioned.JsonStorableIOObject;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import okhttp3.HttpUrl;
import org.mlflow.api.proto.ModelRegistry;
import org.mlflow.tracking.MlflowClient;

/* loaded from: input_file:com/rapidminer/extension/mlflow/operator/modelregistry/RetrieveModelOperator.class */
public class RetrieveModelOperator extends Operator {
    public static final String PARAMETER_MODEL_NAME = "model_name";
    public static final String PARAMETER_REPOSITORY_ENTRY = "folder";
    public static final String PARAMETER_STAGE = "stage";
    public static final String PARAMETER_DOWNLOAD_ARTIFACTS = "download_artifacts";
    InputPort conInput;
    OutputPort modOutput;
    OutputPort exaOutput;
    OutputPort conOutput;

    public RetrieveModelOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.conInput = getInputPorts().createPort("con", ConnectionInformationContainerIOObject.class);
        this.modOutput = getOutputPorts().createPort("mod");
        this.exaOutput = getOutputPorts().createPort("information");
        this.conOutput = getOutputPorts().createPort("con");
        getTransformer().addPassThroughRule(this.conInput, this.conOutput);
        getTransformer().addGenerationRule(this.modOutput, IOTablePredictionModel.class);
        getTransformer().addRule(() -> {
            TableMetaDataBuilder tableMetaDataBuilder = new TableMetaDataBuilder(0);
            for (int i = 0; i < TableConverter.registeredModelColumnNames.size(); i++) {
                tableMetaDataBuilder.add(TableConverter.registeredModelColumnNames.get(i), new ColumnInfoBuilder(ColumnType.forId(TableConverter.registeredModelColumnTypes.get(i))).build());
            }
            this.exaOutput.deliverMD(tableMetaDataBuilder.build());
        });
    }

    public void doWork() throws UserError {
        ConnectionInformationContainerIOObject connectionInformationContainerIOObject = (ConnectionInformationContainerIOObject) this.conInput.getData(ConnectionInformationContainerIOObject.class);
        this.conOutput.deliver(connectionInformationContainerIOObject);
        try {
            MlflowClient mlflowClient = new MlflowClient(connectionInformationContainerIOObject.getConnectionInformation().getConfiguration().getParameter("mlflow." + "uri").getValue());
            try {
                ModelRegistry.RegisteredModel registeredModel = mlflowClient.getRegisteredModel(getParameterAsString("model_name"));
                ModelRegistry.ModelVersion modelVersion = null;
                for (ModelRegistry.ModelVersion modelVersion2 : registeredModel.getLatestVersionsList()) {
                    if (modelVersion2.hasCurrentStage() && modelVersion2.getCurrentStage().equals(getParameterAsString("stage"))) {
                        modelVersion = modelVersion2;
                    }
                }
                if (modelVersion == null) {
                    throw new OperatorException("Cannot find model in stage: " + getParameterAsString("stage"));
                }
                String runId = modelVersion.getRunId();
                String experimentId = mlflowClient.getRun(runId).getInfo().getExperimentId();
                if (getParameterAsBoolean(PARAMETER_DOWNLOAD_ARTIFACTS)) {
                    RepositoryLocation parameterAsRepositoryLocationFolder = getParameterAsRepositoryLocationFolder(PARAMETER_REPOSITORY_ENTRY);
                    parameterAsRepositoryLocationFolder.createFoldersRecursively();
                    BasicFolder locateFolder = parameterAsRepositoryLocationFolder.locateFolder();
                    recursivelyDownloadFolders(connectionInformationContainerIOObject, experimentId, runId, HttpUrl.FRAGMENT_ENCODE_SET, Paths.get(locateFolder.getRepositoryAdapter().getRoot().toAbsolutePath().toString(), locateFolder.getFsFolder().getPath()).toAbsolutePath().toString());
                    locateFolder.getRepositoryAdapter().refresh();
                }
                if (this.modOutput.isConnected()) {
                    JsonStorableIOObject readIOObject = JsonIOObjectEntry.readIOObject(new ArtifactHandler(experimentId, runId).getArtifactAsString(connectionInformationContainerIOObject, runId, "model.rmmodel"));
                    this.exaOutput.deliver(new IOTable(TableConverter.registeredModelToTable(registeredModel, getParameterAsString("stage"))));
                    this.modOutput.deliver(readIOObject);
                }
                mlflowClient.close();
            } catch (Throwable th) {
                try {
                    mlflowClient.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
                throw th;
            }
        } catch (RepositoryException e) {
            throw new RuntimeException((Throwable) e);
        } catch (IOException e2) {
            throw new RuntimeException(e2);
        } catch (OperatorException e3) {
            throw new RuntimeException((Throwable) e3);
        }
    }

    private void recursivelyDownloadFolders(ConnectionInformationContainerIOObject connectionInformationContainerIOObject, String str, String str2, String str3, String str4) throws IOException, OperatorException {
        ArtifactHandler artifactHandler = new ArtifactHandler(str, str2);
        for (FileInfo fileInfo : artifactHandler.getArtifactList(connectionInformationContainerIOObject, str2, str3)) {
            if (fileInfo.getIs_dir().booleanValue()) {
                new File(str4 + "/" + URLDecoder.decode(fileInfo.getPath(), StandardCharsets.UTF_8)).mkdirs();
                recursivelyDownloadFolders(connectionInformationContainerIOObject, str, str2, fileInfo.getPath(), str4);
            } else {
                FileWriter fileWriter = new FileWriter(str4 + "/" + URLDecoder.decode(fileInfo.getPath(), StandardCharsets.UTF_8));
                try {
                    fileWriter.write(artifactHandler.getArtifactAsString(connectionInformationContainerIOObject, str2, fileInfo.getPath()));
                    fileWriter.close();
                } catch (Throwable th) {
                    try {
                        fileWriter.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                    throw th;
                }
            }
        }
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeString("model_name", "model name", false));
        parameterTypes.add(new ParameterTypeCategory("stage", "Transition model_version to new stage.", Constants.STAGES, 0));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_DOWNLOAD_ARTIFACTS, "Wether or not to also download the artifacts", false));
        ParameterTypeRepositoryLocation parameterTypeRepositoryLocation = new ParameterTypeRepositoryLocation(PARAMETER_REPOSITORY_ENTRY, "Repository entry.", false, true, false, false, true, true);
        parameterTypeRepositoryLocation.registerDependencyCondition(new BooleanParameterCondition(this, PARAMETER_DOWNLOAD_ARTIFACTS, true, true));
        parameterTypeRepositoryLocation.setExpert(false);
        parameterTypeRepositoryLocation.setPrimary(true);
        parameterTypes.add(parameterTypeRepositoryLocation);
        return parameterTypes;
    }
}
