package game.models.single.rbf;

import configuration.classifiers.single.weka.RBFNormalizationScale;
import configuration.models.ModelConfig;
import configuration.models.single.RBFModelConfig;
import game.cSerialization.CCodeUtils;
import game.cSerialization.XMLBuildUtils;
import game.clusters.ArrayKMeans;
import game.evolution.treeEvolution.exception.LearnException;
import game.models.ModelLearnableBase;
import game.tools.FastMath;
import game.utils.Utils;
import java.util.Arrays;
import java.util.Random;
import org.ytoh.configurations.ui.SelectionSetModel;
import utils.UtilsCommon;
import weka.core.RevisionUtils;

/* loaded from: input_file:game/models/single/rbf/RBFModel.class */
/**
 * Radial basis function (RBF) network regression model.
 *
 * <p>All trainable values live in one flat array, {@code m_RBFParameters}, laid out as:
 * <ol>
 *   <li>{@code m_numUnits} output weights starting at {@code OFFSET_WEIGHTS}, followed by the
 *       output bias at {@code OFFSET_WEIGHTS + m_numUnits};</li>
 *   <li>{@code inputsNumber} attribute weights at {@code OFFSET_ATTRIBUTE_WEIGHTS}, present only
 *       when {@code m_useAttributeWeights} is set (otherwise the offset is -1);</li>
 *   <li>{@code m_numUnits * inputsNumber} unit centers at {@code OFFSET_CENTERS};</li>
 *   <li>scales at {@code OFFSET_SCALES}: a single global value, one per unit, or one per unit and
 *       attribute, selected by {@code m_scaleOptimizationOption}.</li>
 * </ol>
 *
 * <p>Unit {@code j} responds with {@code exp(-sum_a (w_a * (c_ja - x_a))^2 / (2 * s^2))}, where
 * {@code s} is the applicable scale; when {@code m_useNormalizedBasisFunctions} is set the
 * responses are softmax-normalized across units instead. The model output is the weighted sum of
 * the unit responses plus the bias. Training minimizes
 * {@code 0.5 * SSE + m_ridge * sum(unitWeight^2)} with a conjugate-gradient optimizer.
 */
public class RBFModel extends ModelLearnableBase {
    /** Scale option: one scale shared by every unit and attribute. */
    public static final int USE_GLOBAL_SCALE = 1;
    /** Scale option: one scale per hidden unit. */
    public static final int USE_SCALE_PER_UNIT = 2;
    /** Scale option: one scale per hidden unit and input attribute. */
    public static final int USE_SCALE_PER_UNIT_AND_ATTRIBUTE = 3;

    /** One of the {@code USE_*} options; any value other than 1 or 3 is treated as per-unit. */
    protected int m_scaleOptimizationOption = USE_SCALE_PER_UNIT;
    /** Number of hidden units; lowered if clustering returns fewer centroids. */
    protected int m_numUnits = 2;
    /** Flat parameter vector; see the class documentation for its layout. */
    protected double[] m_RBFParameters = null;
    /** Coefficient of the squared unit-weight penalty added to the error. */
    protected double m_ridge = 0.01d;
    /** When set, unit responses are softmax-normalized across units. */
    protected boolean m_useNormalizedBasisFunctions = false;
    /** When set, every input attribute gets a trainable multiplicative weight. */
    protected boolean m_useAttributeWeights = false;
    // Offsets into m_RBFParameters; -1 until initializeClassifier() lays out the vector.
    protected int OFFSET_WEIGHTS = -1;
    protected int OFFSET_SCALES = -1;
    protected int OFFSET_CENTERS = -1;
    protected int OFFSET_ATTRIBUTE_WEIGHTS = -1;
    /** Iteration cap handed to the optimizer in {@link #learn()}. */
    private int maxOptimizationSteps = 100;
    /** Scale selection copied from / back into the model configuration. */
    private SelectionSetModel<RBFNormalizationScale> rbfScale;

    /**
     * Optimizer adapter whose objective and gradient are this model's penalized squared error.
     * Every evaluation installs the candidate vector as the enclosing model's parameters.
     */
    protected class OptEng extends RBFOptimization {
        protected OptEng() {
        }

        @Override // weka.core.Optimization
        public double objectiveFunction(double[] params) {
            RBFModel.this.m_RBFParameters = params;
            return RBFModel.this.calculateSE();
        }

        @Override // weka.core.Optimization
        public double[] evaluateGradient(double[] params) {
            RBFModel.this.m_RBFParameters = params;
            return RBFModel.this.calculateGradient();
        }

        @Override // weka.core.RevisionHandler
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 8109 $");
        }
    }

    /**
     * Conjugate-gradient counterpart of {@link OptEng}; this is the optimizer used by
     * {@link RBFModel#learn()}.
     */
    protected class OptEngCGD extends RBFConjugateOptimization {
        protected OptEngCGD() {
        }

        @Override // weka.core.Optimization
        public double objectiveFunction(double[] params) {
            RBFModel.this.m_RBFParameters = params;
            return RBFModel.this.calculateSE();
        }

        @Override // weka.core.Optimization
        public double[] evaluateGradient(double[] params) {
            RBFModel.this.m_RBFParameters = params;
            return RBFModel.this.calculateGradient();
        }

        @Override // weka.core.RevisionHandler
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 8109 $");
        }
    }

    /**
     * Reads the scale selection, ridge and neuron count from an {@link RBFModelConfig}.
     *
     * <p>The scale option is taken from the first enabled {@link RBFNormalizationScale}; the
     * configuration is assumed to enable at least one (an empty selection fails on index 0).
     *
     * @param modelConfig the configuration, which must be an {@link RBFModelConfig}
     */
    @Override // game.models.ModelLearnableBase, game.models.ModelLearnable
    public void init(ModelConfig modelConfig) {
        super.init(modelConfig);
        RBFModelConfig config = (RBFModelConfig) modelConfig;
        this.rbfScale = UtilsCommon.cloneSelectionSet(config.getRbfScale());
        this.m_ridge = config.getRidge();
        this.m_numUnits = config.getNeurons();
        this.m_scaleOptimizationOption = this.rbfScale.getEnabledElements(RBFNormalizationScale.class)[0].getType();
    }

    /**
     * Returns the base configuration extended with this model's scale selection, neuron count
     * and ridge. The neuron count reflects any reduction made during clustering.
     */
    @Override // game.models.ModelLearnableBase, game.models.Model
    public ModelConfig getConfig() {
        RBFModelConfig config = (RBFModelConfig) super.getConfig();
        config.setRbfScale(UtilsCommon.cloneSelectionSet(this.rbfScale));
        config.setNeurons(this.m_numUnits);
        config.setRidge(this.m_ridge);
        return config;
    }

    /**
     * Clusters the learning vectors with {@link ArrayKMeans} into {@code m_numUnits} groups and
     * returns the centroids. Only the first {@code inputsNumber} values of each vector are used.
     * If fewer centroids are returned, {@code m_numUnits} is lowered to match.
     */
    private double[][] createClusters() {
        double[][] points = new double[this.learning_vectors][this.inputsNumber];
        for (int v = 0; v < this.learning_vectors; v++) {
            System.arraycopy(this.inputVect[v], 0, points[v], 0, this.inputsNumber);
        }
        ArrayKMeans kMeans = new ArrayKMeans(points, this.m_numUnits);
        kMeans.setClusterSizeMultiplier(1.0d);
        kMeans.run();
        double[][] centroids = kMeans.getCentroids();
        if (centroids.length < this.m_numUnits) {
            this.m_numUnits = centroids.length;
        }
        return centroids;
    }

    /**
     * Computes, for each point, the Euclidean distance to its nearest other point.
     *
     * <p>A lone point yields {@code sqrt(Double.MAX_VALUE)}. Coincident points yield 0, which
     * becomes a zero initial scale. NOTE(review): a zero scale divides by zero in the
     * {@code sumSquaredDiff*} methods.
     *
     * @param points the points to compare, all of the same dimension
     * @return nearest-neighbour distances, one per point
     */
    protected double[] getMinDistances(double[][] points) {
        int n = points.length;
        double[][] sqDist = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                double sum = 0.0d;
                for (int a = 0; a < points[i].length; a++) {
                    double diff = points[i][a] - points[j][a];
                    sum += diff * diff;
                }
                sqDist[i][j] = sum;
                sqDist[j][i] = sum;
            }
        }
        double[] nearest = new double[n];
        for (int i = 0; i < n; i++) {
            // Exclude the zero self-distance from the minimum.
            sqDist[i][i] = Double.MAX_VALUE;
            nearest[i] = Math.sqrt(Utils.min(sqDist[i]));
        }
        return nearest;
    }

    /**
     * Lays out {@code m_RBFParameters} and fills it with starting values.
     *
     * <ul>
     *   <li>Centers are the cluster centroids (clustering may lower {@code m_numUnits}).</li>
     *   <li>Each unit's scale starts at the distance to its nearest other centroid; the global
     *       option uses the largest of those distances.</li>
     *   <li>Unit weights are drawn uniformly from [-0.25, 0.25).</li>
     *   <li>Attribute weights start at 1.</li>
     *   <li>The bias is drawn uniformly from [0.25, 0.75).</li>
     * </ul>
     */
    protected void initializeClassifier() {
        double[][] centroids = createClusters();
        this.OFFSET_WEIGHTS = 0;
        // The weights block holds m_numUnits unit weights plus the bias, hence the +1.
        if (this.m_useAttributeWeights) {
            this.OFFSET_ATTRIBUTE_WEIGHTS = this.m_numUnits + 1;
            this.OFFSET_CENTERS = this.OFFSET_ATTRIBUTE_WEIGHTS + this.inputsNumber;
        } else {
            this.OFFSET_ATTRIBUTE_WEIGHTS = -1;
            this.OFFSET_CENTERS = this.m_numUnits + 1;
        }
        this.OFFSET_SCALES = this.OFFSET_CENTERS + (this.m_numUnits * this.inputsNumber);
        switch (this.m_scaleOptimizationOption) {
            case USE_GLOBAL_SCALE:
                this.m_RBFParameters = new double[this.OFFSET_SCALES + 1];
                break;
            case USE_SCALE_PER_UNIT_AND_ATTRIBUTE:
                this.m_RBFParameters = new double[this.OFFSET_SCALES + (this.m_numUnits * this.inputsNumber)];
                break;
            default:
                this.m_RBFParameters = new double[this.OFFSET_SCALES + this.m_numUnits];
                break;
        }

        double[] minDistances = getMinDistances(centroids);
        switch (this.m_scaleOptimizationOption) {
            case USE_GLOBAL_SCALE:
                this.m_RBFParameters[this.OFFSET_SCALES] = Utils.max(minDistances);
                break;
            case USE_SCALE_PER_UNIT_AND_ATTRIBUTE:
                for (int unit = 0; unit < centroids.length; unit++) {
                    int scaleBase = this.OFFSET_SCALES + (unit * this.inputsNumber);
                    Arrays.fill(this.m_RBFParameters, scaleBase, scaleBase + centroids[unit].length, minDistances[unit]);
                }
                break;
            default:
                // Same fallback as the allocation above. Previously, options other than 1..3
                // left every scale at 0, which later divided by zero.
                System.arraycopy(minDistances, 0, this.m_RBFParameters, this.OFFSET_SCALES, minDistances.length);
                break;
        }

        Random random = new Random();
        for (int unit = 0; unit < centroids.length; unit++) {
            this.m_RBFParameters[this.OFFSET_WEIGHTS + unit] = (random.nextDouble() - 0.5d) / 2.0d;
            System.arraycopy(centroids[unit], 0, this.m_RBFParameters,
                    this.OFFSET_CENTERS + (unit * this.inputsNumber), centroids[unit].length);
        }
        if (this.m_useAttributeWeights) {
            Arrays.fill(this.m_RBFParameters, this.OFFSET_ATTRIBUTE_WEIGHTS,
                    this.OFFSET_ATTRIBUTE_WEIGHTS + this.inputsNumber, 1.0d);
        }
        this.m_RBFParameters[this.OFFSET_WEIGHTS + this.m_numUnits] = ((random.nextDouble() - 0.5d) / 2.0d) + 0.5d;
    }

    /**
     * Initializes the parameters and fits them with the conjugate-gradient optimizer, then runs
     * {@code postLearnActions()}. The optimizer is capped at {@code maxOptimizationSteps}
     * iterations.
     *
     * <p>Weka's {@code Optimization.findArgmin} returns {@code null} when it reaches the iteration
     * cap before converging. In that case the last iterate reported by {@code getVarbValues()} is
     * kept, so the model is not left with a {@code null} parameter vector.
     *
     * @throws LearnException if optimization or the post-learning actions fail
     */
    @Override // game.models.ModelLearnable
    public void learn() {
        initializeClassifier();
        // Row 0 holds lower bounds and row 1 upper bounds; NaN means "unconstrained" to Weka.
        double[][] bounds = new double[2][this.m_RBFParameters.length];
        Arrays.fill(bounds[0], Double.NaN);
        Arrays.fill(bounds[1], Double.NaN);
        OptEngCGD optimizer = new OptEngCGD();
        try {
            optimizer.setMaxIterations(this.maxOptimizationSteps);
            double[] optimum = optimizer.findArgmin(this.m_RBFParameters, bounds);
            if (optimum == null) {
                // Iteration cap hit before convergence: keep the optimizer's last iterate.
                optimum = optimizer.getVarbValues();
            }
            if (optimum != null) {
                this.m_RBFParameters = optimum;
            }
            postLearnActions();
        } catch (Exception e) {
            // NOTE(review): the cause is not chained because LearnException's constructors are
            // not visible here. At least keep the exception type when there is no message.
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            throw new LearnException(message);
        }
    }

    /**
     * Returns the training objective: half the sum of squared residuals over the learning
     * vectors, plus {@code m_ridge} times the sum of squared unit weights (the bias is not
     * penalized).
     */
    protected double calculateSE() {
        double sse = 0.0d;
        for (int v = 0; v < this.learning_vectors; v++) {
            double residual = getOutput(this.inputVect[v]) - this.target[v];
            sse += residual * residual;
        }
        double weightPenalty = 0.0d;
        for (int unit = 0; unit < this.m_numUnits; unit++) {
            double w = this.m_RBFParameters[this.OFFSET_WEIGHTS + unit];
            weightPenalty += w * w;
        }
        return (this.m_ridge * weightPenalty) + (0.5d * sse);
    }

    /**
     * Returns the gradient of {@link #calculateSE()} with respect to every entry of
     * {@code m_RBFParameters}, in the same layout.
     */
    protected double[] calculateGradient() {
        double[] grad = new double[this.m_RBFParameters.length];
        for (int v = 0; v < this.learning_vectors; v++) {
            double[] input = this.inputVect[v];
            double[] responses = calculateNeuronOutputs(input);
            double residual = getNetOutput(responses) - this.target[v];
            for (int unit = 0; unit < this.m_numUnits; unit++) {
                double unitWeight = this.m_RBFParameters[this.OFFSET_WEIGHTS + unit];
                grad[this.OFFSET_WEIGHTS + unit] += residual * responses[unit];
                // The derivative helpers also accumulate center and attribute-weight terms
                // directly into grad; only the scale term is returned.
                switch (this.m_scaleOptimizationOption) {
                    case USE_GLOBAL_SCALE:
                        grad[this.OFFSET_SCALES] += derivativeOneScale(grad, residual, unitWeight,
                                responses[unit], this.m_RBFParameters[this.OFFSET_SCALES], input, unit);
                        break;
                    case USE_SCALE_PER_UNIT_AND_ATTRIBUTE:
                        derivativeScalePerAttribute(grad, residual, unitWeight, responses[unit], input, unit);
                        break;
                    default:
                        grad[this.OFFSET_SCALES + unit] += derivativeOneScale(grad, residual, unitWeight,
                                responses[unit], this.m_RBFParameters[this.OFFSET_SCALES + unit], input, unit);
                        break;
                }
            }
            // d(output)/d(bias) is 1.
            grad[this.OFFSET_WEIGHTS + this.m_numUnits] += residual;
        }
        for (int unit = 0; unit < this.m_numUnits; unit++) {
            grad[this.OFFSET_WEIGHTS + unit] += this.m_ridge * 2.0d * this.m_RBFParameters[this.OFFSET_WEIGHTS + unit];
        }
        return grad;
    }

    /**
     * Accumulates into {@code grad} one input vector's contribution to one unit's per-attribute
     * scales, its center coordinates and (if enabled) the attribute weights.
     *
     * <p>NOTE(review): in normalized-basis mode only the diagonal softmax term
     * {@code r * (1 - r)} is applied; cross-unit softmax terms are not included.
     *
     * @param grad       gradient accumulator, in the {@code m_RBFParameters} layout
     * @param residual   model output minus target for this input vector
     * @param unitWeight the unit's output weight
     * @param response   the unit's response to this input vector
     * @param input      the input vector
     * @param unit       the unit index
     */
    protected void derivativeScalePerAttribute(double[] grad, double residual, double unitWeight, double response, double[] input, int unit) {
        double common = residual * unitWeight * response;
        if (this.m_useNormalizedBasisFunctions) {
            common *= 1.0d - response;
        }
        int centerBase = this.OFFSET_CENTERS + (unit * this.inputsNumber);
        int scaleBase = this.OFFSET_SCALES + (unit * this.inputsNumber);
        for (int a = 0; a < input.length; a++) {
            double diff = input[a] - this.m_RBFParameters[centerBase + a];
            double scale = this.m_RBFParameters[scaleBase + a];
            double scaleSq = scale * scale;
            double attrWeightSq = 1.0d;
            if (this.m_useAttributeWeights) {
                double attrWeight = this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + a];
                attrWeightSq = attrWeight * attrWeight;
                grad[this.OFFSET_ATTRIBUTE_WEIGHTS + a] -= attrWeight * common * diff * diff / scaleSq;
            }
            grad[scaleBase + a] += common * attrWeightSq * diff * diff / (scaleSq * scale);
            grad[centerBase + a] += common * attrWeightSq * diff / scaleSq;
        }
    }

    /**
     * Accumulates one input vector's contribution to one unit's center coordinates and (if
     * enabled) the attribute weights into {@code grad}. Returns the contribution to the
     * single scale {@code scale}, which the caller adds to the appropriate slot.
     *
     * <p>NOTE(review): in normalized-basis mode only the diagonal softmax term is applied.
     *
     * @param grad       gradient accumulator, in the {@code m_RBFParameters} layout
     * @param residual   model output minus target for this input vector
     * @param unitWeight the unit's output weight
     * @param response   the unit's response to this input vector
     * @param scale      the scale value used by this unit
     * @param input      the input vector
     * @param unit       the unit index
     * @return partial derivative of the error with respect to {@code scale}
     */
    protected double derivativeOneScale(double[] grad, double residual, double unitWeight, double response, double scale, double[] input, int unit) {
        double common = ((residual * unitWeight) * response) / (scale * scale);
        if (this.m_useNormalizedBasisFunctions) {
            common *= 1.0d - response;
        }
        double weightedSqDist = 0.0d;
        int centerBase = this.OFFSET_CENTERS + (unit * this.inputsNumber);
        for (int a = 0; a < input.length; a++) {
            double diff = input[a] - this.m_RBFParameters[centerBase + a];
            double diffSq = diff * diff;
            double attrWeightSq = 1.0d;
            if (this.m_useAttributeWeights) {
                int attrIndex = this.OFFSET_ATTRIBUTE_WEIGHTS + a;
                double attrWeight = this.m_RBFParameters[attrIndex];
                attrWeightSq = attrWeight * attrWeight;
                grad[attrIndex] -= attrWeight * common * diffSq;
            }
            weightedSqDist += attrWeightSq * diffSq;
            grad[centerBase + a] += common * attrWeightSq * diff;
        }
        return (common * weightedSqDist) / scale;
    }

    /** Returns the model prediction for one input vector. */
    @Override // game.models.Model
    public double getOutput(double[] input) {
        return getNetOutput(calculateNeuronOutputs(input));
    }

    /** Combines unit responses linearly with the output weights and adds the bias. */
    protected double getNetOutput(double[] responses) {
        double sum = 0.0d;
        for (int unit = 0; unit < this.m_numUnits; unit++) {
            sum += this.m_RBFParameters[this.OFFSET_WEIGHTS + unit] * responses[unit];
        }
        return sum + this.m_RBFParameters[this.OFFSET_WEIGHTS + this.m_numUnits];
    }

    /**
     * Returns every unit's response to {@code input}, softmax-normalized when
     * {@code m_useNormalizedBasisFunctions} is set.
     */
    protected double[] calculateNeuronOutputs(double[] input) {
        double[] responses = new double[this.m_numUnits];
        for (int unit = 0; unit < this.m_numUnits; unit++) {
            responses[unit] = calculateNeuronOutput(input, unit);
        }
        if (this.m_useNormalizedBasisFunctions) {
            logs2probs(responses);
        }
        return responses;
    }

    /**
     * Returns one unit's raw response. This is {@code exp(-d)} for the scaled squared distance
     * {@code d}, or {@code -d} in normalized mode, where {@link #calculateNeuronOutputs} applies
     * the softmax.
     */
    protected double calculateNeuronOutput(double[] input, int unit) {
        double scaledSqDist;
        switch (this.m_scaleOptimizationOption) {
            case USE_GLOBAL_SCALE:
                scaledSqDist = sumSquaredDiffOneScale(this.m_RBFParameters[this.OFFSET_SCALES], input, unit, this.inputsNumber);
                break;
            case USE_SCALE_PER_UNIT_AND_ATTRIBUTE:
                scaledSqDist = sumSquaredDiffScalePerAttribute(input, unit, this.inputsNumber);
                break;
            default:
                scaledSqDist = sumSquaredDiffOneScale(this.m_RBFParameters[this.OFFSET_SCALES + unit], input, unit, this.inputsNumber);
                break;
        }
        return this.m_useNormalizedBasisFunctions ? -scaledSqDist : FastMath.exp(-scaledSqDist);
    }

    /**
     * Returns {@code sum_a (w_a * (c_a - x_a))^2 / (2 * s_a^2)}, using the unit's per-attribute
     * scales {@code s_a}. Each {@code w_a} is 1 unless attribute weights are enabled.
     */
    protected double sumSquaredDiffScalePerAttribute(double[] input, int unit, int numInputs) {
        int scaleBase = this.OFFSET_SCALES + (unit * numInputs);
        int centerBase = this.OFFSET_CENTERS + (unit * numInputs);
        double sum = 0.0d;
        for (int a = 0; a < input.length; a++) {
            double diff = this.m_RBFParameters[centerBase + a] - input[a];
            if (this.m_useAttributeWeights) {
                diff *= this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + a];
            }
            double scale = this.m_RBFParameters[scaleBase + a];
            sum += (diff * diff) / ((2.0d * scale) * scale);
        }
        return sum;
    }

    /**
     * Returns {@code sum_a (w_a * (c_a - x_a))^2 / (2 * scale^2)} for the given unit's center.
     * Each {@code w_a} is 1 unless attribute weights are enabled.
     */
    protected double sumSquaredDiffOneScale(double scale, double[] input, int unit, int numInputs) {
        int centerBase = this.OFFSET_CENTERS + (unit * numInputs);
        double sum = 0.0d;
        for (int a = 0; a < input.length; a++) {
            double diff = this.m_RBFParameters[centerBase + a] - input[a];
            if (this.m_useAttributeWeights) {
                diff *= this.m_RBFParameters[this.OFFSET_ATTRIBUTE_WEIGHTS + a];
            }
            sum += diff * diff;
        }
        return sum / ((2.0d * scale) * scale);
    }

    /**
     * Applies an in-place softmax. Each entry becomes {@code exp(x - max)}, divided by the sum of
     * those values. The maximum is subtracted first to avoid overflow.
     */
    public static void logs2probs(double[] values) {
        double max = Utils.max(values);
        double sum = 0.0d;
        for (int i = 0; i < values.length; i++) {
            values[i] = FastMath.exp(values[i] - max);
            sum += values[i];
        }
        if (sum != 0.0d) {
            for (int i = 0; i < values.length; i++) {
                values[i] /= sum;
            }
        }
    }

    @Override // game.configuration.Configurable
    public Class getConfigClass() {
        return RBFModelConfig.class;
    }

    /**
     * Emits C code for this model into {@code sb}. The code consists of an include of
     * {@code RBFModel.h}, a function header, the parameter array, and a call to the
     * {@code rbfModelOutput} template with this model's layout offsets and flags. XML for the
     * model is written to {@code sb2}.
     *
     * @return the generated function name
     */
    @Override // game.cSerialization.CSerialization
    public String toCCode(StringBuilder sb, StringBuilder sb2) {
        sb.append("#include \"").append(CCodeUtils.getRegressionModelPath()).append("RBFModel.h\"\n");
        String functionName = CCodeUtils.getUniqueFunctionName(getClass());
        CCodeUtils.getCRegressionHeader(functionName, this.inputsNumber, sb);
        CCodeUtils.convertArray(this.m_RBFParameters, "rbfParameters", sb);
        sb.append("return rbfModelOutput<")
                .append(this.inputsNumber).append(",")
                .append(this.m_RBFParameters.length).append(",")
                .append(this.m_numUnits)
                .append(">(input,rbfParameters,")
                .append(this.OFFSET_WEIGHTS).append(",")
                .append(this.OFFSET_SCALES).append(",")
                .append(this.OFFSET_CENTERS).append(",")
                .append(this.OFFSET_ATTRIBUTE_WEIGHTS).append(",")
                .append(this.m_useNormalizedBasisFunctions).append(",")
                .append(this.m_useAttributeWeights).append(",")
                .append(this.m_scaleOptimizationOption)
                .append(");\n");
        sb.append("}\n");
        XMLBuildUtils.outputXML(sb2, this, functionName);
        return functionName;
    }
}
