package game.models;

import configuration.CfgTemplate;
import configuration.models.ConnectableModelConfig;
import configuration.models.ModelConfig;
import game.cSerialization.CCodeMainGenerator;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.evolution.treeEvolution.FitnessNode;
import game.evolution.treeEvolution.evolutionControl.EvolutionUtils;
import game.preprocessing.ModelLearnProcessing;
import game.preprocessing.ModelOutputProcessing;
import game.preprocessing.ModelQueryProcessing;
import game.preprocessing.PreprocessingAlgorithm;
import game.preprocessing.SimpleData;
import game.preprocessing.WeightBalancePreprocessing;
import game.utils.Utils;
import java.util.ArrayList;
import java.util.List;
import org.apache.log4j.Logger;
import org.ytoh.configurations.ui.SelectionSetModel;

/* loaded from: input_file:game/models/ConnectableModel.class */
/**
 * Model wrapper that sits in front of another {@link ModelLearnable}.
 *
 * <p>It forwards to the wrapped model only the inputs enabled in {@link #selectedInputs} and runs
 * the configured {@link PreprocessingAlgorithm} steps on the learning data
 * ({@link ModelLearnProcessing}), on every query ({@link ModelQueryProcessing}) and on every
 * output of the wrapped model ({@link ModelOutputProcessing}).
 */
public class ConnectableModel extends ModelLearnableBase {
    /** Package prefix prepended to the preprocessing class names stored in the configuration. */
    private static final String PREPROCESSING_PACKAGE = "game.preprocessing.";

    /** Input mask: entry {@code i} is {@code true} when input {@code i} is passed to the wrapped model. */
    protected boolean[] selectedInputs;

    /** Number of {@code true} entries in {@link #selectedInputs}. */
    protected int activeInputs;

    /**
     * Weights collected by {@link #storeLearningVector(double[], double, double)}; {@code null}
     * until the first weighted vector is stored and reset to {@code null} by {@link #learn()}.
     */
    protected double[] weights = null;

    /** The wrapped model that does the actual learning and prediction. */
    protected ModelLearnable model;

    /* renamed from: preprocessing, reason: collision with root package name */
    /** Preprocessing steps, applied in list order. Never {@code null} once {@code init} has run. */
    protected List<PreprocessingAlgorithm> f2preprocessing;

    /** Preprocessing selection taken from the configuration; only set for a {@link ConnectableModelConfig}. */
    protected SelectionSetModel<String> enabledPreprocessing;

    /**
     * Initialises this wrapper and the wrapped model from a configuration.
     *
     * <p>If {@code modelConfig} is a {@link ConnectableModelConfig}, the input mask and the
     * preprocessing selection are read from it and the wrapped model is created from its first
     * child node. Otherwise {@code modelConfig} itself is used as the wrapped model's
     * configuration, with no preprocessing and an empty input mask (which {@link #modifyInput}
     * replaces by an all-enabled mask on first use).
     *
     * @param modelConfig configuration of this wrapper or directly of the wrapped model
     * @param list already existing preprocessing instances to reuse, or {@code null} to
     *             instantiate every enabled preprocessing step from the configuration
     */
    public void init(ModelConfig modelConfig, List<PreprocessingAlgorithm> list) {
        this.maxLearningVectors = -1;
        this.activeInputs = 0;
        // Always start from an empty pipeline: the plain-config path never fills it and a failed
        // preprocessing initialisation must not leave it null, since learn(), getOutput() and
        // toCCode() all iterate over it.
        this.f2preprocessing = new ArrayList<PreprocessingAlgorithm>();
        if (!(modelConfig instanceof ConnectableModelConfig)) {
            this.selectedInputs = new boolean[0];
            initModel(modelConfig);
            return;
        }
        ConnectableModelConfig connectableModelConfig = (ConnectableModelConfig) modelConfig;
        this.selectedInputs = connectableModelConfig.getSelectedInputs();
        for (boolean selected : this.selectedInputs) {
            if (selected) {
                this.activeInputs++;
            }
        }
        this.enabledPreprocessing = connectableModelConfig.getPreprocessingMethods();
        initPreprocessingMethods(this.enabledPreprocessing, list);
        initModel((CfgTemplate) connectableModelConfig.getNode(0));
    }

    /**
     * Initialises this wrapper without any pre-existing preprocessing instances.
     *
     * @param modelConfig configuration of this wrapper or directly of the wrapped model
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        init(modelConfig, null);
    }

    /**
     * Stores a learning vector together with its weight.
     *
     * <p>The weight buffer is allocated lazily with {@code maxLearningVectors} slots. Note that
     * {@code init} sets {@code maxLearningVectors} to -1, so a non-negative capacity has to be set
     * before the first weighted vector is stored.
     *
     * @param dArr full (unfiltered) input vector
     * @param d target value
     * @param d2 weight of this vector
     */
    public void storeLearningVector(double[] dArr, double d, double d2) {
        if (this.weights == null) {
            this.weights = new double[this.maxLearningVectors];
        }
        this.weights[this.learning_vectors] = d2;
        storeLearningVector(dArr, d);
    }

    /**
     * Instantiates the wrapped model from the class referenced by {@code cfgTemplate} and
     * initialises it. Instantiation failures are logged and leave {@link #model} unchanged.
     */
    private void initModel(CfgTemplate cfgTemplate) {
        ModelConfig innerConfig = (ModelConfig) cfgTemplate;
        try {
            ModelLearnable created = (ModelLearnable) cfgTemplate.getClassRef().newInstance();
            created.init(innerConfig);
            this.model = created;
        } catch (IllegalAccessException e) {
            logException(e);
        } catch (InstantiationException e) {
            logException(e);
        }
    }

    /**
     * Builds a configuration describing the current state: a copy of the input mask, a clone of
     * the preprocessing selection and the wrapped model's configuration as child node.
     *
     * @return the configuration; it is {@code null} if the config object could not be constructed
     *         and may be only partially filled if a later step failed (the failure is logged)
     */
    @Override // game.models.ModelLearnableBase, game.models.Model
    public ModelConfig getConfig() {
        ConnectableModelConfig connectableModelConfig = null;
        try {
            connectableModelConfig = (ConnectableModelConfig) getConfigClass().getConstructor(Integer.TYPE).newInstance(Integer.valueOf(this.selectedInputs.length));
            connectableModelConfig.setClassRef(getClass());
            // Hand out a copy so that later changes of the config do not alter this model's mask.
            boolean[] maskCopy = new boolean[this.selectedInputs.length];
            System.arraycopy(this.selectedInputs, 0, maskCopy, 0, this.selectedInputs.length);
            connectableModelConfig.setSelectedInputs(maskCopy);
            connectableModelConfig.setPreprocessingMethods(EvolutionUtils.cloneSelectionSet(this.enabledPreprocessing));
            connectableModelConfig.addNode((FitnessNode) this.model.getConfig());
        } catch (Exception e) {
            logException(e, getClass().getSimpleName());
        }
        return connectableModelConfig;
    }

    /**
     * Fills {@link #f2preprocessing} from the selection, reusing the instances in {@code list}
     * when it is not {@code null}. Reflection failures are logged.
     */
    private void initPreprocessingMethods(SelectionSetModel<String> selectionSetModel, List<PreprocessingAlgorithm> list) {
        try {
            if (list == null) {
                initFromConfig(selectionSetModel);
            } else {
                initFromConfigAndReferences(selectionSetModel, list);
            }
        } catch (ClassNotFoundException e) {
            logException(e);
        } catch (IllegalAccessException e) {
            logException(e);
        } catch (InstantiationException e) {
            logException(e);
        }
    }

    /**
     * Builds the preprocessing pipeline in the order of the selection set's elements. For every
     * element whose class matches one of the supplied instances, that instance is reused and the
     * element is marked as enabled; otherwise a new instance is created if the element is enabled.
     */
    private void initFromConfigAndReferences(SelectionSetModel<String> selectionSetModel, List<PreprocessingAlgorithm> list) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
        String[] allElements = selectionSetModel.getAllElements();
        boolean[] stateOfElements = selectionSetModel.getStateOfElements();
        List<PreprocessingAlgorithm> pipeline = new ArrayList<PreprocessingAlgorithm>();
        Class<?>[] suppliedClasses = new Class<?>[list.size()];
        for (int i = 0; i < suppliedClasses.length; i++) {
            suppliedClasses[i] = list.get(i).getClass();
        }
        for (int i = 0; i < allElements.length; i++) {
            Class<?> cls = Class.forName(PREPROCESSING_PACKAGE + allElements[i]);
            int suppliedIndex = Utils.indexOf(cls, suppliedClasses);
            if (suppliedIndex != -1) {
                pipeline.add(list.get(suppliedIndex));
                this.enabledPreprocessing.enableElement(i);
            } else if (stateOfElements[i]) {
                pipeline.add((PreprocessingAlgorithm) cls.newInstance());
            }
        }
        this.f2preprocessing = pipeline;
    }

    /** Builds the preprocessing pipeline by instantiating every enabled element of the selection. */
    private void initFromConfig(SelectionSetModel<String> selectionSetModel) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
        String[] enabledElements = selectionSetModel.getEnabledElements(String.class);
        this.f2preprocessing = new ArrayList<PreprocessingAlgorithm>(enabledElements.length);
        for (String name : enabledElements) {
            this.f2preprocessing.add((PreprocessingAlgorithm) Class.forName(PREPROCESSING_PACKAGE + name).newInstance());
        }
    }

    /** Logs {@code exc} together with the textual form of the current configuration. */
    private void logException(Exception exc) {
        // getConfig() yields null when the config object cannot be built; String.valueOf keeps
        // that from raising a NullPointerException that would hide the exception being reported.
        logException(exc, String.valueOf(getConfig()));
    }

    /**
     * Logs {@code exc} at error level. The message names the exception, its first stack frame
     * (when available) and {@code str} as the configuration context; the exception itself is
     * handed to the logger as well so that the complete stack trace is not lost.
     */
    private void logException(Exception exc, String str) {
        Logger logger = Logger.getLogger(getClass());
        StackTraceElement[] stackTrace = exc.getStackTrace();
        String message = "learning exception " + exc.toString();
        if (stackTrace != null && stackTrace.length > 0) {
            message = message + " in " + stackTrace[0].toString();
        }
        logger.error(message + " in config " + str, exc);
    }

    /** Deletes the learning vectors of this wrapper and of the wrapped model. */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void deleteLearningVectors() {
        super.deleteLearningVectors();
        this.model.deleteLearningVectors();
    }

    /** Resets the learning data of this wrapper and of the wrapped model. */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void resetLearningData() {
        super.resetLearningData();
        this.model.resetLearningData();
    }

    /** @return {@link ConnectableModelConfig}, the configuration class of this model */
    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return ConnectableModelConfig.class;
    }

    /**
     * Stores a learning vector after reducing it to the selected inputs.
     *
     * @param dArr full (unfiltered) input vector
     * @param d target value
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void storeLearningVector(double[] dArr, double d) {
        super.storeLearningVector(modifyInput(dArr), d);
    }

    /**
     * Runs the learn-time preprocessing on the stored data, hands the resulting vectors to the
     * wrapped model and lets it learn.
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        if (this.weights != null) {
            // Collected weights are consumed here: they become a weight-balancing step placed at
            // the front of the pipeline, and the buffer is cleared.
            WeightBalancePreprocessing weightBalance = new WeightBalancePreprocessing();
            weightBalance.init(this.weights);
            this.f2preprocessing.add(0, weightBalance);
            this.weights = null;
        }
        performLearnProcessing();
        this.model.setMaxLearningVectors(this.maxLearningVectors);
        for (int i = 0; i < this.learning_vectors; i++) {
            this.model.storeLearningVector(this.inputVect[i], this.target[i]);
        }
        this.model.learn();
        postLearnActions();
    }

    /**
     * Applies every {@link ModelLearnProcessing} step, in pipeline order, to the stored learning
     * data and replaces the stored data and its counters with each step's result.
     */
    private void performLearnProcessing() {
        for (PreprocessingAlgorithm step : this.f2preprocessing) {
            if (step instanceof ModelLearnProcessing) {
                SimpleData processed = ((ModelLearnProcessing) step).learnProcessing(this.inputVect, this.target);
                this.inputVect = processed.inputData;
                // Only the first row of the returned output data is used as target vector.
                this.target = processed.outputData[0];
                this.learning_vectors = this.inputVect.length;
                this.maxLearningVectors = Math.max(this.maxLearningVectors, this.inputVect.length);
            }
        }
    }

    /** Learns again; simply delegates to {@link #learn()}. */
    public void relearn() {
        learn();
    }

    /**
     * Computes the model output for one input vector: input selection, query preprocessing,
     * wrapped model, output preprocessing.
     *
     * @param dArr full (unfiltered) input vector
     * @return the post-processed output of the wrapped model
     */
    @Override // game.models.Model
    public double getOutput(double[] dArr) {
        double[] query = modifyInput(dArr);
        if (query.length == dArr.length) {
            // No input was filtered out, so modifyInput may have returned the caller's own array;
            // give the query processing steps a private copy to work on.
            query = dArr.clone();
        }
        return performOutputProcessing(this.model.getOutput(performQueryProcessing(query)));
    }

    /**
     * Writes C code for this model, including a main method, to {@code sb}.
     *
     * @return the name of the generated entry function, always {@code "main"}
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append(CCodeUtils.getGlobalIncludes());
        CCodeMainGenerator.writeMainMethod(getModelCCode(sb, sb2), this.inputsNumber, 1, sb);
        return "main";
    }

    /**
     * Emits the C code of the wrapped model and of all query/output preprocessing steps, followed
     * by a function that wires them together through {@code connectableModelOutput}.
     *
     * @return the name of the generated C function
     */
    private String getModelCCode(StringBuilder sb, StringBuilder sb2) {
        XMLBuildUtils.outputXMLStart(sb2, this);
        String uniqueFunctionName = CCodeUtils.getUniqueFunctionName(getClass());
        String modelFunctionName = this.model.toCCode(sb, sb2);
        XMLBuildUtils.outputXMLEnd(sb2, this, uniqueFunctionName);

        List<ModelOutputProcessing> outputSteps = new ArrayList<ModelOutputProcessing>(this.f2preprocessing.size());
        List<ModelQueryProcessing> querySteps = new ArrayList<ModelQueryProcessing>(this.f2preprocessing.size());
        for (PreprocessingAlgorithm step : this.f2preprocessing) {
            if (step instanceof ModelOutputProcessing) {
                outputSteps.add((ModelOutputProcessing) step);
            }
            if (step instanceof ModelQueryProcessing) {
                querySteps.add((ModelQueryProcessing) step);
            }
        }

        // Code of all output steps is emitted before the code of the query steps.
        String[] outputFunctions = new String[outputSteps.size()];
        for (int i = 0; i < outputFunctions.length; i++) {
            outputFunctions[i] = outputSteps.get(i).outputToCCode(sb);
        }
        String[] queryFunctions = new String[querySteps.size()];
        for (int i = 0; i < queryFunctions.length; i++) {
            queryFunctions[i] = querySteps.get(i).queryToCCode(sb);
        }

        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("ConnectableModel.h\"\n");
        sb.append("\ndouble ").append(uniqueFunctionName).append("(double input[").append(this.inputsNumber).append("]) {\n");
        sb.append("static double (*model)(const double*) = {&").append(modelFunctionName).append("};\n");
        CCodeUtils.getCFunctionsArray(queryFunctions, "static double* (*inputProcessingMethods[" + queryFunctions.length + "])(double*)", sb);
        CCodeUtils.getCFunctionsArray(outputFunctions, "static double (*outputProcessingMethods[" + outputFunctions.length + "])(const double)", sb);
        CCodeUtils.convertArray(this.selectedInputs, "enabledInputs", sb);
        sb.append("return connectableModelOutput<").append(this.activeInputs).append(",").append(this.inputsNumber).append(",").append(queryFunctions.length).append(",").append(outputFunctions.length).append(">(input,model,enabledInputs,inputProcessingMethods,outputProcessingMethods);\n");
        sb.append("}\n");
        return uniqueFunctionName;
    }

    /** Passes {@code d} through every {@link ModelOutputProcessing} step in pipeline order. */
    private double performOutputProcessing(double d) {
        double result = d;
        for (PreprocessingAlgorithm step : this.f2preprocessing) {
            if (step instanceof ModelOutputProcessing) {
                result = ((ModelOutputProcessing) step).outputProcessing(result);
            }
        }
        return result;
    }

    /** Passes {@code dArr} through every {@link ModelQueryProcessing} step in pipeline order. */
    private double[] performQueryProcessing(double[] dArr) {
        double[] result = dArr;
        for (PreprocessingAlgorithm step : this.f2preprocessing) {
            if (step instanceof ModelQueryProcessing) {
                result = ((ModelQueryProcessing) step).queryProcessing(result);
            }
        }
        return result;
    }

    /**
     * Reduces an input vector to the selected inputs.
     *
     * <p>If the mask length does not match the vector length, the mask is replaced by an
     * all-enabled mask of the vector's length. When every input is enabled the given array itself
     * is returned, not a copy.
     *
     * @param dArr full (unfiltered) input vector
     * @return {@code dArr} itself, or a new array holding only the selected inputs
     */
    protected double[] modifyInput(double[] dArr) {
        if (this.selectedInputs.length != dArr.length) {
            this.selectedInputs = new boolean[dArr.length];
            for (int i = 0; i < this.selectedInputs.length; i++) {
                this.selectedInputs[i] = true;
            }
            this.activeInputs = dArr.length;
        }
        if (this.activeInputs == dArr.length) {
            return dArr;
        }
        double[] filtered = new double[this.activeInputs];
        int next = 0;
        for (int i = 0; i < this.selectedInputs.length; i++) {
            if (this.selectedInputs[i]) {
                filtered[next] = dArr[i];
                next++;
            }
        }
        return filtered;
    }

    /** @return the wrapped model */
    public ModelLearnable getModel() {
        return this.model;
    }

    /** @return the length of the input mask, i.e. the number of inputs before selection */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable, game.models.Model
    public int getInputsNumber() {
        return this.selectedInputs.length;
    }
}
