package com.rapidminer.operator.io.pmml;

import com.rapidminer.example.Attribute;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.functions.neuralnet.ImprovedNeuralNetModel;
import com.rapidminer.operator.learner.functions.neuralnet.InnerNode;
import com.rapidminer.operator.learner.functions.neuralnet.InputNode;
import com.rapidminer.operator.learner.functions.neuralnet.Node;
import com.rapidminer.operator.learner.functions.neuralnet.OutputNode;
import com.rapidminer.tools.Tools;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/* loaded from: input_file:com/rapidminer/operator/io/pmml/NeuralNetModelPMMLWriter.class */
/**
 * Writes an {@link ImprovedNeuralNetModel} as a PMML {@code NeuralNetwork} model element.
 *
 * <p>The body consists of the mining schema, output and target values (written by the superclass
 * helpers), followed by {@code NeuralInputs}, one {@code NeuralLayer} per hidden layer plus the
 * output layer, and {@code NeuralOutputs}.
 */
public class NeuralNetModelPMMLWriter extends AbstractPredictionModelPMMLWriter {

    /** Layer index carried by the inner nodes that form the final (output) layer of the network. */
    private static final int OUTPUT_LAYER_INDEX = -2;

    private final ImprovedNeuralNetModel model;

    /**
     * Creates a writer for the given network.
     *
     * @param improvedNeuralNetModel the trained network to export
     */
    public NeuralNetModelPMMLWriter(ImprovedNeuralNetModel improvedNeuralNetModel) {
        super(improvedNeuralNetModel);
        this.model = improvedNeuralNetModel;
    }

    /**
     * Builds the {@code NeuralNetwork} element for the wrapped model.
     *
     * <p>The layer count is computed per call, so invoking this method repeatedly on the same
     * writer always yields the same {@code numberOfLayers}.
     *
     * @param document the document used to create elements
     * @param pMMLVersion the target PMML version (not consulted by this writer)
     * @return the populated {@code NeuralNetwork} element
     * @throws UserError if one of the superclass helpers reports a problem
     * @throws UnsupportedOperationException if an output node has a numerical label
     */
    @Override // com.rapidminer.operator.io.pmml.AbstractPMMLModelWriter
    public Element createModelBody(Document document, PMMLVersion pMMLVersion) throws UserError {
        Element network = document.createElement("NeuralNetwork");
        network.setAttribute("modelName", this.model.getName());
        boolean numericalLabel = this.model.getTrainingHeader().getAttributes().getLabel().isNumerical();
        network.setAttribute("functionName", numericalLabel ? "regression" : "classification");
        network.setAttribute("algorithmName", "NeuralNet");
        network.setAttribute("activationFunction", "logistic");
        createMiningSchema(document, network, this.model);
        createOutput(document, network, this.model);
        createTargetValues(document, network, this.model);
        createNeuralNetInputs(document, network, this.model);
        int numberOfLayers = createNeuralNetLayers(document, network, this.model);
        createNeuralNetOutputs(document, network, this.model);
        network.setAttribute("numberOfLayers", Integer.toString(numberOfLayers));
        return network;
    }

    /**
     * Appends {@code NeuralInputs} with one {@code NeuralInput} per input node of the model.
     * Inputs flagged for normalization additionally receive a {@code DerivedField}.
     */
    private void createNeuralNetInputs(Document document, Element element, ImprovedNeuralNetModel improvedNeuralNetModel) {
        Element inputsElement = createElement(document, element, "NeuralInputs");
        InputNode[] inputNodes = improvedNeuralNetModel.getInputNodes();
        inputsElement.setAttribute("numberOfInputs", Integer.toString(inputNodes.length));
        for (InputNode inputNode : inputNodes) {
            Element inputElement = createElement(document, inputsElement, "NeuralInput");
            inputElement.setAttribute("id", pmmlNodeId(inputNode));
            if (inputNode.isNormalize()) {
                createInputNormalization(document, inputElement, inputNode);
            }
        }
    }

    /**
     * Appends the {@code DerivedField} describing how an input is normalized.
     *
     * <p>Numerical attributes are mapped linearly so that {@code base - range} becomes -1 and
     * {@code base + range} becomes 1. Nominal attributes get a {@code NormDiscrete} element.
     */
    private void createInputNormalization(Document document, Element inputElement, InputNode inputNode) {
        Attribute attribute = inputNode.getAttribute();
        Element derivedField = createElement(document, inputElement, "DerivedField");
        derivedField.setAttribute("datatype", PMMLTranslation.getValueType(attribute));
        derivedField.setAttribute("optype", PMMLTranslation.getOpType(attribute));
        if (attribute.isNumerical()) {
            Element normContinuous = createElement(document, derivedField, "NormContinuous");
            normContinuous.setAttribute("field", attribute.getName());
            double base = inputNode.getAttributeBase();
            double range = inputNode.getAttributeRange();
            // Written with full double precision, like weights and biases, rather than a display formatter.
            appendLinearNorm(document, normContinuous, base - range, "-1");
            appendLinearNorm(document, normContinuous, base + range, "1");
        } else if (attribute.isNominal()) {
            Element normDiscrete = createElement(document, derivedField, "NormDiscrete");
            normDiscrete.setAttribute("field", attribute.getName());
            // NOTE(review): the value written is getCurrentValue() stringified, not a category name
            // from the attribute mapping -- confirm this is what PMML consumers expect.
            normDiscrete.setAttribute("value", inputNode.getCurrentValue() + "");
        }
    }

    /** Appends one {@code LinearNorm} point mapping {@code orig} to {@code norm}. */
    private void appendLinearNorm(Document document, Element normContinuous, double orig, String norm) {
        Element linearNorm = createElement(document, normContinuous, "LinearNorm");
        linearNorm.setAttribute("orig", Double.toString(orig));
        linearNorm.setAttribute("norm", norm);
    }

    /**
     * Appends one {@code NeuralLayer} per hidden layer, in ascending layer-index order, followed
     * by the output layer (inner nodes with {@link #OUTPUT_LAYER_INDEX}).
     *
     * @return the number of {@code NeuralLayer} elements written, including the output layer
     */
    private int createNeuralNetLayers(Document document, Element element, ImprovedNeuralNetModel improvedNeuralNetModel) {
        // TreeMap keeps hidden layers ordered by index, so each layer only references earlier ones.
        Map<Integer, List<InnerNode>> hiddenLayers = new TreeMap<Integer, List<InnerNode>>();
        List<InnerNode> outputLayer = new ArrayList<InnerNode>();
        for (InnerNode innerNode : improvedNeuralNetModel.getInnerNodes()) {
            int layerIndex = innerNode.getLayerIndex();
            if (layerIndex == OUTPUT_LAYER_INDEX) {
                outputLayer.add(innerNode);
            } else {
                List<InnerNode> layer = hiddenLayers.get(layerIndex);
                if (layer == null) {
                    layer = new ArrayList<InnerNode>();
                    hiddenLayers.put(layerIndex, layer);
                }
                layer.add(innerNode);
            }
        }

        int layerCount = 0;
        for (List<InnerNode> layer : hiddenLayers.values()) {
            createNeuralNetLayerElements(document, element, layer);
            layerCount++;
        }
        // An empty output layer would produce an invalid NeuralLayer, so it is neither written nor counted.
        if (!outputLayer.isEmpty()) {
            createNeuralNetLayerElements(document, element, outputLayer);
            layerCount++;
        }
        return layerCount;
    }

    /**
     * Appends a single {@code NeuralLayer} containing one {@code Neuron} per node in {@code list}.
     * The first node's activation function decides whether the layer is marked "logistic".
     * A node with a negative layer index marks the layer for "simplemax" output normalization.
     */
    private void createNeuralNetLayerElements(Document document, Element element, List<InnerNode> list) {
        Element layerElement = createElement(document, element, "NeuralLayer");
        layerElement.setAttribute("numberOfNeurons", Integer.toString(list.size()));
        if (!list.isEmpty() && "Sigmoid".equals(list.get(0).getActivationFunction().getTypeName())) {
            layerElement.setAttribute("activationFunction", "logistic");
        }
        for (InnerNode innerNode : list) {
            if (innerNode.getLayerIndex() < 0) {
                layerElement.setAttribute("normalizationMethod", "simplemax");
            }
            Element neuron = createElement(document, layerElement, "Neuron");
            neuron.setAttribute("id", pmmlNodeId(innerNode));

            // weights[0] is written as the bias; weights[i + 1] belongs to the i-th input connection.
            double[] weights = innerNode.getWeights();
            neuron.setAttribute("bias", Double.toString(weights[0]));
            Node[] sources = innerNode.getInputNodes();
            for (int i = 0; i < sources.length; i++) {
                Element connection = createElement(document, neuron, "Con");
                connection.setAttribute("from", pmmlNodeId(sources[i]));
                connection.setAttribute("weight", Double.toString(weights[i + 1]));
            }
        }
    }

    /**
     * Appends {@code NeuralOutputs} with one {@code NeuralOutput} per output node. Each output
     * references the first input neuron of its output node and carries a {@code NormDiscrete}
     * for nominal labels.
     *
     * @throws UnsupportedOperationException if a label is numerical
     */
    private void createNeuralNetOutputs(Document document, Element element, ImprovedNeuralNetModel improvedNeuralNetModel) throws UserError {
        Element outputsElement = createElement(document, element, "NeuralOutputs");
        OutputNode[] outputNodes = improvedNeuralNetModel.getOutputNodes();
        outputsElement.setAttribute("numberOfOutputs", Integer.toString(outputNodes.length));
        for (OutputNode outputNode : outputNodes) {
            Attribute label = outputNode.getLabel();
            Element outputElement = createElement(document, outputsElement, "NeuralOutput");
            outputElement.setAttribute("outputNeuron", pmmlNodeId(outputNode.getInputNodes()[0]));
            Element derivedField = createElement(document, outputElement, "DerivedField");
            derivedField.setAttribute("datatype", PMMLTranslation.getValueType(label));
            derivedField.setAttribute("optype", PMMLTranslation.getOpType(label));
            if (label.isNumerical()) {
                throw new UnsupportedOperationException("A numerical output is not supported for converting a NeuralNet Model to PMML.");
            }
            if (label.isNominal()) {
                Element normDiscrete = createElement(document, derivedField, "NormDiscrete");
                normDiscrete.setAttribute("field", label.getName());
                normDiscrete.setAttribute("value", label.getMapping().mapIndex(outputNode.getClassIndex()));
            }
        }
    }

    /**
     * Returns the PMML id used both for a node's own element and for every reference to it.
     * Input nodes and output-layer nodes (negative layer index) use their plain name. Hidden-layer
     * nodes are prefixed with their layer index so names stay unique across layers.
     */
    private static String pmmlNodeId(Node node) {
        if (node instanceof InputNode || node.getLayerIndex() < 0) {
            return node.getNodeName() + "";
        }
        return node.getLayerIndex() + "_" + node.getNodeName();
    }

    /**
     * Reports compatibility problems for this model.
     *
     * @return always {@code null}; presumably interpreted by callers as "no issues" -- confirm.
     *     Note that numerical labels are only rejected later, during {@link #createModelBody}.
     */
    @Override // com.rapidminer.operator.io.pmml.PMMLObjectWriter
    public Collection<String> checkCompatibility() {
        return null;
    }

    /** Returns the missing value treatment written for this model: {@code "asValue"}. */
    @Override // com.rapidminer.operator.io.pmml.AbstractPMMLObjectWriter
    protected String getMissingValueTreatment() {
        return "asValue";
    }

    /** Returns {@code null}: no missing value replacement is supplied for any attribute. */
    @Override // com.rapidminer.operator.io.pmml.AbstractPMMLObjectWriter
    protected String getMissingValueReplacement(Attribute attribute) {
        return null;
    }
}
