package game.evolution.treeEvolution.run;

import configuration.ConfigurationFactory;
import configuration.classifiers.ClassifierConfig;
import configuration.evolution.MainConfig;
import configuration.models.ModelConfig;
import game.classifiers.ConnectableClassifier;
import game.data.AbstractGameData;
import game.data.ArrayGameData;
import game.data.MiningType;
import game.evolution.treeEvolution.FitnessNode;
import game.evolution.treeEvolution.context.MultiCVClassifierContext;
import game.evolution.treeEvolution.context.MultiCVModelContext;
import game.evolution.treeEvolution.context.evaluators.RMSEModelEvaluator;
import game.evolution.treeEvolution.evolutionControl.EvolutionUtils;
import game.models.ConnectableModel;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.log4j.Logger;
import org.apache.log4j.PropertyConfigurator;

/* loaded from: input_file:game/evolution/treeEvolution/run/LabelMaker.class */
/**
 * Command-line tool that trains the classifier or model described by a template configuration
 * and writes its responses to a semicolon-separated text file.
 *
 * <p>Recognised flags (each followed by a value): {@code -template}, {@code -data},
 * {@code -testData}, {@code -outputPath}, {@code -log}, {@code -folds} and {@code -idIndex}.
 *
 * <p>When a test dataset is available the learner is trained on the whole learning dataset and
 * the test dataset is labelled; otherwise the learning dataset is labelled fold by fold, each fold
 * being labelled by a learner trained on the remaining folds.
 */
public class LabelMaker {
    private static final String DEFAULT_OUTPUT_PATH = "./evolution/labels/";
    private static final String DEFAULT_TEMPLATE = "./evolution/rsj_973_1353508765275_1.txt";
    private static final String DEFAULT_DATA = "./data/rsj.txt";
    private static final String DEFAULT_TEST_DATA = "./data/rsjTest.txt";
    private static final String DEFAULT_LOG_LEVEL = "info";
    private static final String DEFAULT_FOLDS = "5";

    /** Value of {@code -idIndex} meaning "no ID column; use the row number as the ID". */
    private static final int NO_ID_COLUMN = -1;

    /** Assigned in {@link #main(String[])} before any other method of this class runs. */
    private static Logger log;

    /** One response vector per labelled instance, in the order the instances were labelled. */
    private static List<double[]> finalLabels;

    /** Target value per labelled instance; stays empty when no targets were recorded. */
    private static List<Double> finalTargetSet;

    /** ID per labelled instance, parallel to {@link #finalLabels}. */
    private static List<Double> finalIdSet;

    /**
     * Entry point: parses the flags, loads the data and the template, labels the data and saves
     * the result.
     *
     * @param strArr command-line arguments
     * @throws IOException declared by the original signature; kept for callers
     * @throws NumberFormatException if {@code -folds} or {@code -idIndex} is not an integer
     * @throws IllegalArgumentException if {@code -idIndex} does not address an input column
     */
    public static void main(String[] strArr) throws IOException {
        MainConfig.setFileOutputPath(DEFAULT_OUTPUT_PATH);
        CommandParse flags = parseInput(strArr);
        String templatePath = flagOrDefault(flags, "-template", DEFAULT_TEMPLATE);
        String dataPath = flagOrDefault(flags, "-data", DEFAULT_DATA);
        // NOTE(review): because -testData always falls back to a default, testDataPath is never
        // null and the cross-validation branch below cannot be reached from the command line.
        // The default is kept so that a run without arguments behaves exactly as before.
        String testDataPath = flagOrDefault(flags, "-testData", DEFAULT_TEST_DATA);
        String outputPath = flags.getFlagValue("-outputPath");
        if (outputPath != null) {
            MainConfig.setFileOutputPath(outputPath);
        }
        String logLevel = flagOrDefault(flags, "-log", DEFAULT_LOG_LEVEL);
        int folds = Integer.parseInt(flagOrDefault(flags, "-folds", DEFAULT_FOLDS));
        int idIndex = Integer.parseInt(flagOrDefault(flags, "-idIndex", String.valueOf(NO_ID_COLUMN)));

        log = setupLogger(logLevel);
        AbstractGameData learnData = EvolutionUtils.readDataFromFile(dataPath);
        FitnessNode template = (FitnessNode) ConfigurationFactory.getConfiguration(templatePath);
        log.info("data: " + dataPath);
        log.info("config: " + template.toString());

        IdSplit learnSplit = splitOffIds(learnData, idIndex);
        learnData = learnSplit.data;
        double[] ids = learnSplit.ids;

        AbstractGameData testData = null;
        if (testDataPath != null) {
            testData = EvolutionUtils.readDataFromFile(testDataPath);
            log.info("test data: " + testDataPath);
            IdSplit testSplit = splitOffIds(testData, idIndex);
            testData = testSplit.data;
            // The IDs written to the output always belong to the dataset that gets labelled.
            ids = testSplit.ids;
        }

        AbstractGameData labelledData = testData != null ? testData : learnData;
        double[] target = getTarget(labelledData);
        int expectedRows = labelledData.getInstanceNumber();
        finalLabels = new ArrayList<double[]>(expectedRows);
        finalTargetSet = new ArrayList<Double>(expectedRows);
        finalIdSet = new ArrayList<Double>(expectedRows);

        if (testData == null) {
            labelLearnDataset(learnData, template, folds, ids, target);
        } else {
            labelTestDataset(learnData, testData, template, ids, target);
        }
        saveResultsToFile(templatePath, dataPath);
    }

    /** Returns the value given for {@code flag}, or {@code fallback} when the flag has no value. */
    private static String flagOrDefault(CommandParse flags, String flag, String fallback) {
        String value = flags.getFlagValue(flag);
        return value != null ? value : fallback;
    }

    /**
     * Determines the instance IDs of a dataset and, when an ID column is configured, removes that
     * column from the input vectors.
     *
     * @param data    dataset as read from file
     * @param idIndex index of the input column holding the ID, or {@link #NO_ID_COLUMN}
     * @return the IDs together with the dataset the learners should see
     * @throws IllegalArgumentException if {@code idIndex} is not a column of the first input vector
     */
    private static IdSplit splitOffIds(AbstractGameData data, int idIndex) {
        double[] ids = new double[data.getInstanceNumber()];
        if (idIndex == NO_ID_COLUMN) {
            for (int row = 0; row < ids.length; row++) {
                ids[row] = row;
            }
            return new IdSplit(ids, data);
        }
        double[][] inputs = data.getInputVectors();
        if (inputs.length == 0) {
            // Nothing to strip; previously this case failed on inputs[0].
            return new IdSplit(ids, data);
        }
        int width = inputs[0].length;
        if (idIndex < 0 || idIndex >= width) {
            throw new IllegalArgumentException(
                    "-idIndex " + idIndex + " does not address an input column (0.." + (width - 1) + ")");
        }
        double[][] stripped = new double[inputs.length][width - 1];
        for (int row = 0; row < inputs.length; row++) {
            int target = 0;
            for (int col = 0; col < inputs[row].length; col++) {
                if (col == idIndex) {
                    ids[row] = inputs[row][col];
                } else {
                    stripped[row][target] = inputs[row][col];
                    target++;
                }
            }
        }
        // NOTE(review): only inputs and outputs are carried over here; any instance weights of
        // the original dataset are not passed to the new ArrayGameData - confirm this is intended.
        return new IdSplit(ids, new ArrayGameData(stripped, data.getOutputAttrs()));
    }

    /**
     * Builds one target value per instance of {@code abstractGameData}.
     *
     * <p>For classification data the values come from
     * {@code EvolutionUtils.convertOutputData}; otherwise the first output attribute is used. If
     * the dataset has no instances or no output attributes the returned values are all zero.
     *
     * @param abstractGameData dataset whose targets are wanted
     * @return array of length {@code getInstanceNumber()}
     */
    private static double[] getTarget(AbstractGameData abstractGameData) {
        double[] target = new double[abstractGameData.getInstanceNumber()];
        double[][] outputs = abstractGameData.getOutputAttrs();
        if (outputs.length == 0 || outputs[0].length == 0) {
            return target;
        }
        if (abstractGameData.getDataType() == MiningType.CLASSIFICATION) {
            int[] classes = EvolutionUtils.convertOutputData(outputs);
            for (int row = 0; row < classes.length; row++) {
                target[row] = classes[row];
            }
        } else {
            for (int row = 0; row < target.length; row++) {
                target[row] = outputs[row][0];
            }
        }
        return target;
    }

    /**
     * Trains on the whole learning dataset and records the responses for every test instance.
     *
     * @param learnData   dataset used for training; its mining type selects classifier or model
     * @param testData    dataset to label
     * @param fitnessNode template; cast to {@code ClassifierConfig} or {@code ModelConfig}
     * @param ids         ID per test instance
     * @param target      target per test instance; recorded only when non-empty
     */
    private static void labelTestDataset(AbstractGameData learnData, AbstractGameData testData, FitnessNode fitnessNode, double[] ids, double[] target) {
        if (learnData.getDataType() == MiningType.CLASSIFICATION) {
            labelTestWithClassifier(learnData, testData, (ClassifierConfig) fitnessNode, ids, target);
        } else {
            labelTestWithModel(learnData, testData, (ModelConfig) fitnessNode, ids, target);
        }
    }

    /** Classification variant of {@link #labelTestDataset}: records the output probabilities. */
    private static void labelTestWithClassifier(AbstractGameData learnData, AbstractGameData testData, ClassifierConfig config, double[] ids, double[] target) {
        double[][] learnInputs = learnData.getInputVectors();
        double[][] learnOutputs = learnData.getOutputAttrs();
        double[][] testInputs = testData.getInputVectors();

        ConnectableClassifier classifier = new ConnectableClassifier();
        classifier.init(config);
        classifier.setMaxLearningVectors(learnInputs.length);
        for (int row = 0; row < learnInputs.length; row++) {
            classifier.storeLearningVector(learnInputs[row], learnOutputs[row]);
        }
        log.info("learning");
        classifier.learn();

        for (int row = 0; row < testInputs.length; row++) {
            finalLabels.add(classifier.getOutputProbabilities(testInputs[row]));
            finalIdSet.add(Double.valueOf(ids[row]));
        }
        recordTargets(target, testInputs.length);
    }

    /**
     * Regression variant of {@link #labelTestDataset}: records the model output and, when targets
     * are present, logs the negated result of {@code RMSEModelEvaluator.performTestOnData}.
     */
    private static void labelTestWithModel(AbstractGameData learnData, AbstractGameData testData, ModelConfig config, double[] ids, double[] target) {
        double[][] learnInputs = learnData.getInputVectors();
        double[][] learnOutputs = learnData.getOutputAttrs();
        double[] weights = learnData.getInstanceWeights();
        double[][] testInputs = testData.getInputVectors();

        ConnectableModel model = new ConnectableModel();
        model.init(config);
        model.setMaxLearningVectors(learnInputs.length);
        for (int row = 0; row < learnInputs.length; row++) {
            if (weights == null) {
                model.storeLearningVector(learnInputs[row], learnOutputs[row][0]);
            } else {
                model.storeLearningVector(learnInputs[row], learnOutputs[row][0], weights[row]);
            }
        }
        log.info("learning");
        model.learn();

        for (int row = 0; row < testInputs.length; row++) {
            finalLabels.add(new double[]{model.getOutput(testInputs[row])});
            finalIdSet.add(Double.valueOf(ids[row]));
        }
        if (target.length > 0) {
            recordTargets(target, testInputs.length);
            RMSEModelEvaluator evaluator = new RMSEModelEvaluator();
            int[] allTestIndexes = new int[testData.getInstanceNumber()];
            for (int row = 0; row < allTestIndexes.length; row++) {
                allTestIndexes[row] = row;
            }
            log.info("Fitness on test: " + (-evaluator.performTestOnData(model, config, allTestIndexes, testData)));
        }
    }

    /** Appends the first {@code count} targets to the result, unless no targets exist at all. */
    private static void recordTargets(double[] target, int count) {
        if (target.length == 0) {
            return;
        }
        for (int row = 0; row < count; row++) {
            finalTargetSet.add(Double.valueOf(target[row]));
        }
    }

    /**
     * Writes the collected IDs, responses and (if any) targets to
     * {@code <outputPath>labels_<template>_<data>.txt}, one instance per line, separated by
     * {@code ;}. I/O failures are logged, not rethrown.
     *
     * @param str  path of the template file; its base name becomes part of the file name
     * @param str2 path of the learning data file; its base name becomes part of the file name
     */
    private static void saveResultsToFile(String str, String str2) {
        if (finalLabels.isEmpty()) {
            // The header needs the width of the first response; without one there is nothing to write.
            log.warn("no labels were produced, nothing to save");
            return;
        }
        String fileName = MainConfig.getFileOutputPath() + "labels_" + baseName(str) + "_" + baseName(str2) + ".txt";
        log.info("saving " + fileName);

        int responseCount = finalLabels.get(0).length;
        boolean hasTargets = !finalTargetSet.isEmpty();
        BufferedWriter writer = null;
        try {
            writer = new BufferedWriter(new FileWriter(fileName));
            writer.write("ID");
            for (int col = 0; col < responseCount; col++) {
                writer.write(";ModelResponse" + col);
            }
            if (hasTargets) {
                writer.write(";Target");
            }
            writer.newLine();
            for (int row = 0; row < finalIdSet.size(); row++) {
                writer.write(Double.toString(finalIdSet.get(row).doubleValue()));
                double[] responses = finalLabels.get(row);
                for (int col = 0; col < responseCount; col++) {
                    writer.write(";" + responses[col]);
                }
                if (hasTargets) {
                    writer.write(";" + finalTargetSet.get(row));
                }
                writer.newLine();
            }
            writer.flush();
        } catch (IOException e) {
            log.error("failed to write " + fileName, e);
        } finally {
            // Close on every path so a failed write does not leak the file handle.
            if (writer != null) {
                try {
                    writer.close();
                } catch (IOException e) {
                    log.error("failed to close " + fileName, e);
                }
            }
        }
    }

    /**
     * Returns the part of {@code path} after the last {@code /} and before the last {@code .};
     * a name without a {@code .} is returned whole.
     */
    private static String baseName(String path) {
        String name = path.substring(path.lastIndexOf("/") + 1);
        int dot = name.lastIndexOf('.');
        return dot < 0 ? name : name.substring(0, dot);
    }

    /**
     * Labels the learning dataset fold by fold: every fold is labelled by a learner trained on
     * all other folds.
     *
     * @param abstractGameData dataset to split, train on and label
     * @param fitnessNode      template; cast to {@code ClassifierConfig} or {@code ModelConfig}
     * @param i                number of folds requested
     * @param dArr             ID per instance, indexed by instance index
     * @param dArr2            target per instance, indexed by instance index
     */
    private static void labelLearnDataset(AbstractGameData abstractGameData, FitnessNode fitnessNode, int i, double[] dArr, double[] dArr2) {
        int[][] foldIndexes = getFoldIndexes(abstractGameData, i);
        boolean classification = abstractGameData.getDataType() == MiningType.CLASSIFICATION;
        for (int fold = 0; fold < foldIndexes.length; fold++) {
            if (classification) {
                labelClassificationSet(abstractGameData, (ClassifierConfig) fitnessNode, foldIndexes, fold, dArr, dArr2);
            } else {
                labelRegressionSet(abstractGameData, (ModelConfig) fitnessNode, foldIndexes, fold, dArr, dArr2);
            }
        }
    }

    /** Returns the number of instance indexes in all folds except {@code heldOut}. */
    private static int countTrainingVectors(int[][] folds, int heldOut) {
        int count = 0;
        for (int fold = 0; fold < folds.length; fold++) {
            if (fold != heldOut) {
                count += folds[fold].length;
            }
        }
        return count;
    }

    /**
     * Trains a classifier on every fold except {@code heldOut} and records the output
     * probabilities, ID and target of each instance in the held-out fold.
     */
    private static void labelClassificationSet(AbstractGameData data, ClassifierConfig classifierConfig, int[][] folds, int heldOut, double[] ids, double[] target) {
        ConnectableClassifier classifier = new ConnectableClassifier();
        classifier.init(classifierConfig);
        classifier.setMaxLearningVectors(countTrainingVectors(folds, heldOut));
        for (int fold = 0; fold < folds.length; fold++) {
            if (fold == heldOut) {
                continue;
            }
            for (int k = 0; k < folds[fold].length; k++) {
                int instance = folds[fold][k];
                classifier.storeLearningVector(data.getInputVector(instance), data.getOutputAttributes(instance));
            }
        }
        log.info("learning " + heldOut);
        classifier.learn();
        for (int k = 0; k < folds[heldOut].length; k++) {
            int instance = folds[heldOut][k];
            finalLabels.add(classifier.getOutputProbabilities(data.getInputVector(instance)));
            finalIdSet.add(Double.valueOf(ids[instance]));
            finalTargetSet.add(Double.valueOf(target[instance]));
        }
        log.info("output to index;" + finalLabels.size());
    }

    /**
     * Trains a model on every fold except {@code heldOut} and records the model output, ID and
     * target of each instance in the held-out fold.
     */
    private static void labelRegressionSet(AbstractGameData data, ModelConfig modelConfig, int[][] folds, int heldOut, double[] ids, double[] target) {
        ConnectableModel model = new ConnectableModel();
        model.init(modelConfig);
        model.setMaxLearningVectors(countTrainingVectors(folds, heldOut));
        for (int fold = 0; fold < folds.length; fold++) {
            if (fold == heldOut) {
                continue;
            }
            for (int k = 0; k < folds[fold].length; k++) {
                int instance = folds[fold][k];
                // NOTE(review): instance weights are not passed here, although the test-set
                // path does pass them when present - confirm whether that difference is wanted.
                model.storeLearningVector(data.getInputVector(instance), data.getOutputAttributes(instance)[0]);
            }
        }
        log.info("learning " + heldOut);
        model.learn();
        for (int k = 0; k < folds[heldOut].length; k++) {
            int instance = folds[heldOut][k];
            finalLabels.add(new double[]{model.getOutput(data.getInputVector(instance))});
            finalIdSet.add(Double.valueOf(ids[instance]));
            finalTargetSet.add(Double.valueOf(target[instance]));
        }
        log.info("output to index;" + finalLabels.size());
    }

    /**
     * Asks the cross-validation context matching the dataset's mining type for its fold indexes,
     * configured with no test data and a validation share of {@code 1 / i}.
     *
     * @param abstractGameData dataset to split
     * @param i                number of folds; must be at least 1
     * @return the first element of the context's {@code getFoldsIndexes()}, or an empty array
     *         when the mining type is neither classification nor regression
     * @throws IllegalArgumentException if {@code i} is smaller than 1
     */
    private static int[][] getFoldIndexes(AbstractGameData abstractGameData, int i) {
        if (i < 1) {
            // 1.0 / i would otherwise yield an infinite or negative validation share.
            throw new IllegalArgumentException("-folds must be at least 1, got " + i);
        }
        double validationShare = 1.0d / i;
        if (abstractGameData.getDataType() == MiningType.CLASSIFICATION) {
            MultiCVClassifierContext context = new MultiCVClassifierContext();
            context.setModelsBeforeCacheUse(i);
            context.setTestDataPercent(0.0d);
            context.setValidDataPercent(validationShare);
            context.init(abstractGameData);
            return context.getFoldsIndexes()[0];
        }
        if (abstractGameData.getDataType() == MiningType.REGRESSION) {
            MultiCVModelContext context = new MultiCVModelContext();
            context.setModelsBeforeCacheUse(i);
            context.setTestDataPercent(0.0d);
            context.setValidDataPercent(validationShare);
            context.init(abstractGameData);
            return context.getFoldsIndexes()[0];
        }
        log.error("unspecified mining type.");
        return new int[0][0];
    }

    /**
     * Configures log4j with a console appender at the given root level and returns the
     * {@code AutomatedDataMining} logger.
     *
     * @param str root logger level, e.g. {@code info}
     */
    private static Logger setupLogger(String str) {
        Logger logger = Logger.getLogger("AutomatedDataMining");
        Properties settings = new Properties();
        settings.setProperty("log4j.rootLogger", str + ", A1");
        settings.setProperty("log4j.appender.A1", "org.apache.log4j.ConsoleAppender");
        settings.setProperty("log4j.appender.A1.layout", "org.apache.log4j.PatternLayout");
        settings.setProperty("log4j.appender.A1.layout.ConversionPattern", "%d{ABSOLUTE};%m%n");
        PropertyConfigurator.configure(settings);
        return logger;
    }

    /**
     * Registers every flag this tool reads and parses the command line.
     *
     * @param strArr command-line arguments
     * @return the parser holding the flag values
     */
    private static CommandParse parseInput(String[] strArr) {
        CommandParse commandParse = new CommandParse();
        commandParse.saveFlagValue("-template");
        commandParse.saveFlagValue("-data");
        commandParse.saveFlagValue("-testData");
        commandParse.saveFlagValue("-outputPath");
        // -log is read in main but was never registered alongside the other flags.
        // NOTE(review): assumes saveFlagValue is what makes a flag's value retrievable - confirm.
        commandParse.saveFlagValue("-log");
        commandParse.saveFlagValue("-folds");
        commandParse.saveFlagValue("-idIndex");
        commandParse.parse(strArr);
        return commandParse;
    }

    /** Result of {@link #splitOffIds}: the instance IDs and the dataset without its ID column. */
    private static final class IdSplit {
        final double[] ids;
        final AbstractGameData data;

        IdSplit(double[] ids, AbstractGameData data) {
            this.ids = ids;
            this.data = data;
        }
    }
}
