package org.fabi.visualizations.evolution.scatterplot.modelling.evolution;

import configuration.CfgTemplate;
import configuration.classifiers.ClassifierConfig;
import configuration.classifiers.ConnectableClassifierConfig;
import configuration.models.ConnectableModelConfig;
import configuration.models.ModelConfig;
import game.classifiers.ConnectableClassifier;
import game.data.ArrayGameData;
import game.models.ConnectableModel;
import org.fabi.visualizations.scatter.sources.DataSource;
import org.fabi.visualizations.scatter.sources.ModelSource;

/* loaded from: input_file:org/fabi/visualizations/evolution/scatterplot/modelling/evolution/ModGenTools.class */
/**
 * Static helpers that train GAME models on the data exposed by a {@link DataSource}
 * and wrap the trained model in a {@link ModelSource}.
 */
public class ModGenTools {

    /**
     * Adapter that stores a plain {@link ClassifierConfig} inside a
     * {@link ConnectableClassifierConfig}, so it can initialise a {@link ConnectableClassifier}.
     */
    protected static class ExtConnectableClassifierConfig extends ConnectableClassifierConfig {
        private static final long serialVersionUID = -7309627478458456208L;

        /**
         * @param inputCount value forwarded to the superclass constructor; callers in this class
         *                   pass {@code ArrayGameData.getINumber()} (presumably the number of
         *                   input attributes — confirm against ArrayGameData)
         * @param wrapped    configuration assigned to the inherited {@code config} field
         */
        public ExtConnectableClassifierConfig(int inputCount, ClassifierConfig wrapped) {
            super(inputCount);
            this.config = wrapped;
        }
    }

    /**
     * Adapter that stores a plain {@link ModelConfig} inside a
     * {@link ConnectableModelConfig}, so it can initialise a {@link ConnectableModel}.
     */
    protected static class ExtConnectableModelConfig extends ConnectableModelConfig {
        private static final long serialVersionUID = -7309627478458456208L;

        /**
         * @param inputCount value forwarded to the superclass constructor; callers in this class
         *                   pass {@code ArrayGameData.getINumber()} (presumably the number of
         *                   input attributes — confirm against ArrayGameData)
         * @param wrapped    configuration assigned to the inherited {@code config} field
         */
        public ExtConnectableModelConfig(int inputCount, ModelConfig wrapped) {
            super(inputCount);
            this.config = wrapped;
        }
    }

    /**
     * Trains a {@link ConnectableModel} on the data source and wraps it for visualisation.
     *
     * <p>A {@link ModelConfig} is wrapped in an {@link ExtConnectableModelConfig}. A
     * {@link ConnectableModelConfig} is used as-is. Only the first output attribute of each
     * instance is used as the learning target.
     *
     * @param cfgTemplate model configuration; must be a {@link ModelConfig} or a
     *                    {@link ConnectableModelConfig}
     * @param dataSource  supplies the input and output vectors used for training
     * @return a {@link GameRegressionModelSource} wrapping the trained model
     * @throws IllegalArgumentException if {@code cfgTemplate} is of any other type
     */
    public static ModelSource learnRegressionModel(CfgTemplate cfgTemplate, DataSource dataSource) {
        ArrayGameData data = new ArrayGameData(dataSource.getInputDataVectors(), dataSource.getOutputDataVectors());

        // A plain ModelConfig takes priority; it is adapted into the connectable form.
        ConnectableModelConfig config;
        if (cfgTemplate instanceof ModelConfig) {
            config = new ExtConnectableModelConfig(data.getINumber(), (ModelConfig) cfgTemplate);
        } else if (cfgTemplate instanceof ConnectableModelConfig) {
            config = (ConnectableModelConfig) cfgTemplate;
        } else {
            throw notAllowed(cfgTemplate);
        }

        ConnectableModel model = new ConnectableModel();
        model.init(config);
        model.setMaxLearningVectors(data.getInstanceNumber());

        double[][] inputs = data.getInputVectors();
        double[][] targets = data.getOutputAttrs();
        int row = 0;
        while (row < data.getInstanceNumber()) {
            // Regression learns a single scalar target: the first output attribute.
            model.storeLearningVector(inputs[row], targets[row][0]);
            row++;
        }

        model.learn();
        return new GameRegressionModelSource(model);
    }

    /**
     * Trains a {@link ConnectableClassifier} on the data source and wraps it for visualisation.
     *
     * <p>A {@link ClassifierConfig} is wrapped in an {@link ExtConnectableClassifierConfig}. A
     * {@link ConnectableClassifierConfig} is used as-is. The full output row of each instance
     * is passed as the learning target.
     *
     * @param cfgTemplate classifier configuration; must be a {@link ClassifierConfig} or a
     *                    {@link ConnectableClassifierConfig}
     * @param dataSource  supplies the input and output vectors used for training
     * @return a {@link GameClassifierModelSource} wrapping the trained classifier
     * @throws IllegalArgumentException if {@code cfgTemplate} is of any other type
     */
    public static ModelSource learnClassifier(CfgTemplate cfgTemplate, DataSource dataSource) {
        ArrayGameData data = new ArrayGameData(dataSource.getInputDataVectors(), dataSource.getOutputDataVectors());

        // A plain ClassifierConfig takes priority; it is adapted into the connectable form.
        ConnectableClassifierConfig config;
        if (cfgTemplate instanceof ClassifierConfig) {
            config = new ExtConnectableClassifierConfig(data.getINumber(), (ClassifierConfig) cfgTemplate);
        } else if (cfgTemplate instanceof ConnectableClassifierConfig) {
            config = (ConnectableClassifierConfig) cfgTemplate;
        } else {
            throw notAllowed(cfgTemplate);
        }

        ConnectableClassifier classifier = new ConnectableClassifier();
        classifier.init(config);
        classifier.setMaxLearningVectors(data.getInstanceNumber());

        double[][] inputs = data.getInputVectors();
        double[][] targets = data.getOutputAttrs();
        int row = 0;
        while (row < data.getInstanceNumber()) {
            classifier.storeLearningVector(inputs[row], targets[row]);
            row++;
        }

        classifier.learn();
        return new GameClassifierModelSource(classifier);
    }

    /**
     * Builds the exception reported when a configuration of an unsupported type is supplied.
     *
     * @param cfgTemplate the rejected configuration; its simple class name goes into the message
     * @return the exception to throw
     */
    private static IllegalArgumentException notAllowed(CfgTemplate cfgTemplate) {
        return new IllegalArgumentException(cfgTemplate.getClass().getSimpleName() + " not allowed.");
    }
}
