package com.rapidminer.optimplugin;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:com/rapidminer/optimplugin/OptimPluginUtil.class */
/**
 * Static helpers that turn a RapidMiner {@link ExampleSet} into commons-math matrix/vector form
 * and perform small related tasks (parameter initialisation, row shuffling, writing results).
 */
public class OptimPluginUtil {

    /**
     * Label value-type code that triggers the -1/+1 mapping in {@link #getLabelVector}.
     * NOTE(review): presumably {@code com.rapidminer.tools.Ontology.BINOMINAL} (6) — confirm and
     * switch to the named constant if that class is on the classpath.
     */
    private static final int BINOMINAL_VALUE_TYPE = 6;

    /**
     * Builds a data matrix from the regular attributes of an example set.
     *
     * <p>Row {@code i} corresponds to the {@code i}-th example in iteration order. Column 0 is a
     * constant {@code 1.0} (bias column); columns {@code 1..k} hold
     * {@code example.getNumericalValue(attribute)} for each of the {@code k} regular attributes,
     * in attribute iteration order.
     *
     * @param exampleSet the examples to convert
     * @return a new matrix with {@code exampleSet.size()} rows and {@code k + 1} columns
     */
    public static RealMatrix exampleSet2DataMatrix(ExampleSet exampleSet) {
        int rowCount = exampleSet.size();
        int attributeCount = exampleSet.getAttributes().size();
        // NOTE(review): an empty example set requests a 0-row matrix here; confirm that
        // MatrixUtils.createRealMatrix accepts this or that callers guarantee non-empty input.
        RealMatrix matrix = MatrixUtils.createRealMatrix(rowCount, attributeCount + 1);
        // Fetch the attribute collection once instead of once per example.
        Iterable<Attribute> attributes = exampleSet.getAttributes();

        int row = 0;
        for (Example example : exampleSet) {
            matrix.setEntry(row, 0, 1.0d);
            int column = 1;
            for (Attribute attribute : attributes) {
                matrix.setEntry(row, column++, example.getNumericalValue(attribute));
            }
            row++;
        }
        return matrix;
    }

    /**
     * Extracts the label value of every example into a vector, in example iteration order.
     *
     * <p>If the label's value type is {@code 6} ({@link #BINOMINAL_VALUE_TYPE}), a raw value of
     * exactly {@code 0.0} becomes {@code -1.0} and every other value (including NaN) becomes
     * {@code 1.0}. Otherwise the raw {@code example.getValue(label)} is copied unchanged.
     * NOTE(review): assumes getValue yields the nominal mapping index for such labels, so the
     * first mapped class becomes -1 — confirm.
     *
     * @param exampleSet the examples to read
     * @return a new vector with one entry per example
     * @throws NullPointerException if {@code exampleSet.getAttributes().getLabel()} returns null
     */
    public static RealVector getLabelVector(ExampleSet exampleSet) {
        ArrayRealVector labels = new ArrayRealVector(exampleSet.size());
        Attribute label = exampleSet.getAttributes().getLabel();
        if (label == null) {
            throw new NullPointerException("exampleSet has no label attribute");
        }
        boolean binominal = label.getValueType() == BINOMINAL_VALUE_TYPE;

        int row = 0;
        for (Example example : exampleSet) {
            double value = example.getValue(label);
            if (binominal) {
                // Compare against a plain 0.0 literal rather than an unrelated optimizer constant.
                value = (value == 0.0d) ? -1.0d : 1.0d;
            }
            labels.setEntry(row++, value);
        }
        return labels;
    }

    /**
     * Creates the initial parameter vector. It is all zeros, with one entry per regular attribute
     * plus one for the constant bias column produced by {@link #exampleSet2DataMatrix}.
     *
     * @param exampleSet the examples whose regular attribute count sizes the vector
     * @return a new zero vector of length {@code getAttributes().size() + 1}
     */
    public static RealVector initializeTheta(ExampleSet exampleSet) {
        // ArrayRealVector(int) already allocates a vector of zeros; no explicit fill is needed.
        return new ArrayRealVector(exampleSet.getAttributes().size() + 1);
    }

    /**
     * Shuffles the rows of {@code realMatrix} in place (Fisher–Yates). The same permutation is
     * applied to the first {@code realMatrix.getRowDimension()} entries of {@code realVector},
     * so row {@code i} and entry {@code i} stay paired.
     *
     * <p>{@code realVector} must have at least as many entries as {@code realMatrix} has rows;
     * any entries beyond that are left untouched.
     *
     * @param realMatrix matrix whose rows are permuted
     * @param realVector vector permuted in lockstep with the matrix rows
     * @param random source of randomness
     */
    public static void shuffleRows(RealMatrix realMatrix, RealVector realVector, Random random) {
        for (int last = realMatrix.getRowDimension() - 1; last > 0; last--) {
            int pick = random.nextInt(last + 1);
            if (pick == last) {
                continue; // swapping a row with itself is a no-op
            }
            RealVector pickedRow = realMatrix.getRowVector(pick);
            double pickedEntry = realVector.getEntry(pick);

            realMatrix.setRowVector(pick, realMatrix.getRowVector(last));
            realVector.setEntry(pick, realVector.getEntry(last));

            realMatrix.setRowVector(last, pickedRow);
            realVector.setEntry(last, pickedEntry);
        }
    }

    /**
     * Writes each entry of {@code realVector} to {@code file}, one per {@code '\n'}-terminated
     * line. Each entry is formatted with the pattern {@code "#.#########"}, i.e. at most nine
     * fractional digits. An existing file is overwritten.
     *
     * @param file destination file
     * @param realVector values to write
     * @throws IOException if the file cannot be opened or written
     */
    public static void writeProb(File file, RealVector realVector) throws IOException {
        // NOTE(review): DecimalFormat uses the JVM default locale, so the decimal separator may be
        // ',' on some systems; confirm what readers of this file expect before pinning a Locale.
        DecimalFormat format = new DecimalFormat("#.#########");
        BufferedWriter writer = new BufferedWriter(new FileWriter(file));
        try {
            for (int i = 0; i < realVector.getDimension(); i++) {
                writer.write(format.format(realVector.getEntry(i)));
                writer.write("\n");
            }
            writer.flush();
        } finally {
            // Release the file handle even if a write fails part-way.
            writer.close();
        }
    }
}
