package game.models.single.rapidMiner;

import com.rapidminer.operator.OperatorCreationException;
import com.rapidminer.operator.learner.functions.neuralnet.ActivationFunction;
import com.rapidminer.operator.learner.functions.neuralnet.ImprovedNeuralNetLearner;
import com.rapidminer.operator.learner.functions.neuralnet.InnerNode;
import com.rapidminer.operator.learner.functions.neuralnet.InputNode;
import com.rapidminer.operator.learner.functions.neuralnet.LinearFunction;
import com.rapidminer.operator.learner.functions.neuralnet.Node;
import com.rapidminer.operator.learner.functions.neuralnet.OutputNode;
import com.rapidminer.tools.OperatorService;
import configuration.models.ModelConfig;
import configuration.models.single.rapidMiner.RapidNeuralNetConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import java.util.ArrayList;
import java.util.Locale;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;

/* loaded from: input_file:game/models/single/rapidMiner/RapidNeuralNetModel.class */
/**
 * Model backed by RapidMiner's {@link ImprovedNeuralNetLearner}.
 *
 * <p>Hidden-layer sizes and training parameters (error epsilon, learning rate, momentum) are read
 * from a {@link RapidNeuralNetConfig}. After training, the network can be exported as a single C
 * expression via {@link #toCCode(StringBuilder, StringBuilder)}.
 */
public class RapidNeuralNetModel extends RapidMinerModel {
    /** Value passed as the learner's {@code training_cycles} parameter; not taken from the config. */
    private static final String TRAINING_CYCLES = "100";

    /** Number of neurons in the first hidden layer. */
    protected int firstLayerNeurons;
    /** Number of neurons in the second hidden layer; {@code 0} means no second layer is added. */
    protected int secondLayerNeurons;
    /** Value passed to the learner's {@code error_epsilon} parameter. */
    protected double errorEpsilon;
    /** Value passed to the learner's {@code learning_rate} parameter. */
    protected double learningRate;
    /** Value passed to the learner's {@code momentum} parameter. */
    protected double momentum;

    /**
     * Copies the network parameters from {@code modelConfig} and creates a configured
     * {@link ImprovedNeuralNetLearner} stored in {@code this.learner}.
     *
     * <p>If the operator cannot be created, the error message is printed to {@code System.err}
     * and {@code this.learner} is not reassigned.
     *
     * @param modelConfig the model configuration; must be a {@link RapidNeuralNetConfig}
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        RapidNeuralNetConfig rapidNeuralNetConfig = (RapidNeuralNetConfig) modelConfig;
        this.firstLayerNeurons = rapidNeuralNetConfig.getFirstLayerNeurons();
        this.secondLayerNeurons = rapidNeuralNetConfig.getSecondLayerNeurons();
        this.errorEpsilon = rapidNeuralNetConfig.getErrorEpsilon();
        this.learningRate = rapidNeuralNetConfig.getLearningRate();
        this.momentum = rapidNeuralNetConfig.getMomentum();
        try {
            // Keep a typed local reference so the learner-specific setters do not depend on the
            // declared type of the inherited 'learner' field.
            ImprovedNeuralNetLearner neuralNetLearner =
                    OperatorService.createOperator(ImprovedNeuralNetLearner.class);
            this.learner = neuralNetLearner;
            neuralNetLearner.setParameter("error_epsilon", Double.toString(this.errorEpsilon));
            neuralNetLearner.setParameter("learning_rate", Double.toString(this.learningRate));
            neuralNetLearner.setParameter("momentum", Double.toString(this.momentum));

            // Each entry is {layerName, neuronCount}.
            ArrayList<String[]> hiddenLayers = new ArrayList<String[]>();
            hiddenLayers.add(new String[] {"layer1", Integer.toString(this.firstLayerNeurons)});
            if (this.secondLayerNeurons != 0) {
                hiddenLayers.add(new String[] {"layer2", Integer.toString(this.secondLayerNeurons)});
            }
            neuralNetLearner.setListParameter("hidden_layers", hiddenLayers);
            neuralNetLearner.setParameter("training_cycles", TRAINING_CYCLES);
        } catch (OperatorCreationException e) {
            System.err.println("Cannot create operator:" + e.getMessage());
        }
    }

    /**
     * Returns the superclass configuration with this model's network parameters written into it.
     *
     * @return the configuration, as a {@link RapidNeuralNetConfig}
     */
    @Override // game.models.ModelLearnableBase, game.models.Model
    public ModelConfig getConfig() {
        RapidNeuralNetConfig rapidNeuralNetConfig = (RapidNeuralNetConfig) super.getConfig();
        rapidNeuralNetConfig.setErrorEpsilon(this.errorEpsilon);
        rapidNeuralNetConfig.setFirstLayerNeurons(this.firstLayerNeurons);
        rapidNeuralNetConfig.setSecondLayerNeurons(this.secondLayerNeurons);
        rapidNeuralNetConfig.setLearningRate(this.learningRate);
        rapidNeuralNetConfig.setMomentum(this.momentum);
        return rapidNeuralNetConfig;
    }

    /**
     * Returns the configuration class used by this model.
     *
     * @return {@code RapidNeuralNetConfig.class}
     */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return RapidNeuralNetConfig.class;
    }

    /**
     * Emits C source that evaluates the trained network.
     *
     * <p>Appends {@code #include "NeuralNet.h"}, a regression function header from
     * {@link CCodeUtils#getCRegressionHeader}, and a body of the form {@code return <expr>;}
     * followed by a closing brace. It then passes {@code sb2} to {@link XMLBuildUtils#outputXML}.
     *
     * @param sb receives the generated C code
     * @param sb2 passed to {@code XMLBuildUtils.outputXML} together with this model
     * @return the generated (unique) C function name
     */
    @Override // game.models.single.rapidMiner.RapidMinerModel, game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append("#include \"NeuralNet.h\"\n");
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCRegressionHeader(uniqueFunctionName, this.inputsNumber, sb);
        sb.append("return ");
        neuralNetToString(sb);
        sb.append(";\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, uniqueFunctionName);
        return uniqueFunctionName;
    }

    /**
     * Appends the C expression for the first output node of the trained model to {@code sb}.
     * Only output node 0 is exported.
     */
    private void neuralNetToString(StringBuilder sb) {
        nodeToEquation(this.rapidModel.getOutputNodes()[0], sb);
    }

    /**
     * Recursively appends the C expression computing {@code node}'s value to {@code sb}.
     *
     * <p>Each input of {@code node} contributes one term, and the terms are joined with {@code +}:
     * <ul>
     *   <li>An {@link InputNode} contributes {@code weight*input[i]}. When the input is normalized,
     *       it contributes {@code weight*(input[i]-base)/range} instead; the division is omitted
     *       when the range is zero.</li>
     *   <li>Any other input contributes {@code factor*fn(sub)}. Here {@code factor} is the label
     *       range when {@code node} is an {@link OutputNode} and the connection weight otherwise.
     *       {@code fn} is the lower-cased activation type name suffixed with {@code Neural}, and is
     *       omitted for a {@link LinearFunction}. {@code sub} is the recursive expression.</li>
     * </ul>
     * <p>After the terms come the label base (for an output node) or {@code getWeight(-1)} (for
     * other nodes).
     */
    private void nodeToEquation(Node node, StringBuilder sb) {
        // Inputs may be InputNodes or InnerNodes (see the two branches), so the array must be
        // typed as Node[]; an InputNode[] type made the InnerNode branch unreachable.
        Node[] inputNodes = node.getInputNodes();
        for (int i = 0; i < inputNodes.length; i++) {
            if (inputNodes[i] instanceof InputNode) {
                InputNode inputNode = (InputNode) inputNodes[i];
                sb.append(node.getWeight(i));
                if (!inputNode.isNormalize()) {
                    sb.append("*input[").append(i).append("]");
                } else {
                    // A negative base must be parenthesized: "input[0]--0.5" would be lexed
                    // by C as the '--' decrement operator and fail to compile.
                    sb.append("*(input[").append(i).append("]-")
                            .append(toCOperand(inputNode.getAttributeBase())).append(")");
                    if (inputNode.getAttributeRange() != 0.0d) {
                        sb.append("/").append(inputNode.getAttributeRange());
                    }
                }
                sb.append("+");
            } else {
                if (node instanceof OutputNode) {
                    // Output nodes de-normalize the prediction: labelRange * value + labelBase.
                    sb.append(((OutputNode) node).getLabelRange()).append("*");
                } else {
                    sb.append(node.getWeight(i)).append("*");
                }
                ActivationFunction activationFunction =
                        ((InnerNode) inputNodes[i]).getActivationFunction();
                if (!(activationFunction instanceof LinearFunction)) {
                    // Locale.ROOT keeps the generated C identifier independent of the JVM locale.
                    sb.append(activationFunction.getTypeName().toLowerCase(Locale.ROOT))
                            .append("Neural");
                }
                sb.append("(");
                nodeToEquation(inputNodes[i], sb);
                sb.append(")+");
            }
        }
        if (node instanceof OutputNode) {
            sb.append(((OutputNode) node).getLabelBase());
        } else {
            sb.append(node.getWeight(-1));
        }
    }

    /**
     * Formats {@code value} as {@link Double#toString(double)} does. The result is wrapped in
     * parentheses when it starts with a minus sign, so it can safely follow a binary {@code -}
     * in C.
     */
    private static String toCOperand(double value) {
        String literal = Double.toString(value);
        return literal.startsWith("-") ? "(" + literal + ")" : literal;
    }
}
