package edu.pitt.dbmi.edda.operator.ldaop;

import cc.mallet.pipe.Pipe;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.learner.SimpleBinaryPredictionModel;
import edu.pitt.dbmi.edda.operator.c45bayes.C45BayesModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/pitt/dbmi/edda/operator/ldaop/TopicModelAdapter.class */
/**
 * Prediction model that wraps a trained {@link ParallelTopicModel} and decides between
 * "exclude" and "include" by comparing an example's topic probabilities with per-class
 * reference vectors.
 *
 * <p>The reference vectors (per-topic averages and per-topic medians) are computed once, in the
 * constructor, from the instances stored in the topic model. {@link #predict(Example)} compares
 * against the medians, {@link #predictWithKlDivergenceAgainstMeans(Example)} against the
 * averages. Both return {@code 0.0} for exclude and {@code 1.0} for include.
 */
public class TopicModelAdapter extends SimpleBinaryPredictionModel {
    private static final long serialVersionUID = 7307611030011692827L;
    private static final Logger logger = Logger.getLogger(TopicModelAdapter.class.getPackage().getName());

    /** Value returned by the predict methods when the exclude reference is strictly closer. */
    private static final double PREDICTION_EXCLUDE = 0.0d;
    /** Value returned by the predict methods in every other case. */
    private static final double PREDICTION_INCLUDE = 1.0d;
    /** Target strings starting with this prefix are counted as "exclude" when averaging. */
    private static final String EXCLUDE_LABEL_PREFIX = "e";
    /** Label text used in log output when the example's label cannot be read. */
    private static final String UNKNOWN_LABEL = "unknown";

    private ParallelTopicModel parallelTopicModel;
    private Pipe pipe;
    /** Number of instances held by the topic model. */
    private int numberOfRows;
    /** Number of topics of the topic model. */
    private int numberOfCols;
    private double excludeCount;
    private double includeCount;
    private double[] medianExcludeProbabilities;
    private double[] medianIncludeProbabilities;
    private double[] averageExcludeProbabilities;
    private double[] averageIncludeProbabilities;
    private Attribute labelAttribute;
    private int numberMostProbableWordsForDisplay;

    /**
     * Builds the adapter and immediately derives the per-class average and median topic
     * probabilities from the instances stored in {@code parallelTopicModel}.
     *
     * @param parallelTopicModel trained topic model whose data is used as reference
     * @param exampleSet example set handed to the superclass (together with the value 0.5) and
     *        used to look up the label attribute
     */
    public TopicModelAdapter(ParallelTopicModel parallelTopicModel, ExampleSet exampleSet) {
        super(exampleSet, 0.5d);
        this.numberOfRows = 0;
        this.numberOfCols = 0;
        this.excludeCount = 0.0d;
        this.includeCount = 0.0d;
        logger.fine("Constructing a TopicModelAdapter");
        this.labelAttribute = exampleSet.getAttributes().getLabel();
        this.parallelTopicModel = parallelTopicModel;
        this.numberOfRows = parallelTopicModel.getData().size();
        this.numberOfCols = parallelTopicModel.getNumTopics();
        // Averages must run first: they populate excludeCount/includeCount, which the median
        // pass uses as capacity hints.
        calculateAverageDivergences();
        calculateMedianDivergences();
    }

    /** Computes and logs the per-topic median probabilities of both classes. */
    private void calculateMedianDivergences() {
        this.medianExcludeProbabilities =
                calculateMedians(C45BayesModel.CONST_C45_BAYES_MODEL_CLS_EXCLUDE, (int) this.excludeCount);
        this.medianIncludeProbabilities =
                calculateMedians(C45BayesModel.CONST_C45_BAYES_MODEL_CLS_INCLUDE, (int) this.includeCount);
        displayVector("Median exclude probs", this.medianExcludeProbabilities);
        displayVector("Median include probs", this.medianIncludeProbabilities);
    }

    /**
     * Computes, for every topic, the median inferred probability over all model instances whose
     * target string equals {@code label}.
     *
     * <p>Values are appended as matching instances are found, so the result does not depend on
     * {@code expectedCount} being accurate. The class counts are produced by a prefix test while
     * this method uses an exact comparison, so the two may disagree.
     *
     * @param label exact target string an instance must have to contribute
     * @param expectedCount anticipated number of matching instances; used only as a capacity hint
     * @return one median per topic; entries are {@code NaN} when no instance matched
     */
    private double[] calculateMedians(String label, int expectedCount) {
        int capacity = Math.max(expectedCount, 0);
        List<List<Double>> valuesPerTopic = new ArrayList<>(this.numberOfCols);
        for (int topic = 0; topic < this.numberOfCols; topic++) {
            valuesPerTopic.add(new ArrayList<Double>(capacity));
        }

        TopicInferencer inferencer = this.parallelTopicModel.getInferencer();
        int matched = 0;
        for (int row = 0; row < this.numberOfRows; row++) {
            Instance instance = this.parallelTopicModel.getData().get(row).instance;
            if (!instance.getTarget().toString().equals(label)) {
                continue;
            }
            double[] topicProbabilities = TopicModelUtils.inferTopicForInstance(inferencer, instance, false);
            // Copy the values right away instead of keeping a reference to the returned array.
            for (int topic = 0; topic < this.numberOfCols; topic++) {
                valuesPerTopic.get(topic).add(topicProbabilities[topic]);
            }
            matched++;
        }
        if (matched != expectedCount) {
            logger.warning("Expected " + expectedCount + " instances labelled '" + label
                    + "' but found " + matched);
        }

        double[] medians = new double[this.numberOfCols];
        for (int topic = 0; topic < this.numberOfCols; topic++) {
            medians[topic] = median(valuesPerTopic.get(topic));
        }
        return medians;
    }

    /**
     * Returns the median of {@code values}. The list is sorted in place.
     *
     * @param values values to evaluate; reordered by this call
     * @return the middle value, the mean of the two middle values for an even count, or
     *         {@code NaN} for an empty list
     */
    private double median(List<Double> values) {
        if (values.isEmpty()) {
            // Mirrors the 0/0 = NaN result the averages produce for a class without instances.
            return Double.NaN;
        }
        Collections.sort(values);
        int middle = values.size() / 2;
        if (values.size() % 2 == 0) {
            return (values.get(middle - 1) + values.get(middle)) / 2.0d;
        }
        return values.get(middle);
    }

    /**
     * Computes and logs the per-topic average probabilities of both classes and counts the
     * instances of each class.
     *
     * <p>An instance counts as "exclude" when its target string starts with {@code "e"}; every
     * other instance counts as "include". A class without instances yields {@code NaN} averages
     * because the sums are divided by a zero count.
     */
    private void calculateAverageDivergences() {
        this.averageExcludeProbabilities = new double[this.numberOfCols];
        this.averageIncludeProbabilities = new double[this.numberOfCols];
        if (this.numberOfRows > 0) {
            // One inferencer for all rows, as in calculateMedians.
            TopicInferencer inferencer = this.parallelTopicModel.getInferencer();
            for (int row = 0; row < this.numberOfRows; row++) {
                Instance instance = this.parallelTopicModel.getData().get(row).instance;
                double[] topicProbabilities = TopicModelUtils.inferTopicForInstance(inferencer, instance, false);
                if (instance.getTarget().toString().startsWith(EXCLUDE_LABEL_PREFIX)) {
                    accumulate(this.averageExcludeProbabilities, topicProbabilities);
                    this.excludeCount += 1.0d;
                } else {
                    accumulate(this.averageIncludeProbabilities, topicProbabilities);
                    this.includeCount += 1.0d;
                }
            }
        }
        for (int topic = 0; topic < this.numberOfCols; topic++) {
            this.averageExcludeProbabilities[topic] /= this.excludeCount;
            this.averageIncludeProbabilities[topic] /= this.includeCount;
        }
        displayVector("Average exclude probs", this.averageExcludeProbabilities);
        displayVector("Average include probs", this.averageIncludeProbabilities);
    }

    /** Adds the first {@code numberOfCols} entries of {@code addend} onto {@code sums}. */
    private void accumulate(double[] sums, double[] addend) {
        for (int topic = 0; topic < this.numberOfCols; topic++) {
            sums[topic] += addend[topic];
        }
    }

    /**
     * Logs {@code str}, a colon and the comma-separated values of {@code dArr} at FINE level.
     * Nothing is formatted when FINE logging is disabled.
     */
    private void displayVector(String str, double[] dArr) {
        if (!logger.isLoggable(Level.FINE)) {
            return;
        }
        StringBuilder text = new StringBuilder();
        text.append(str).append(":");
        for (double value : dArr) {
            text.append(value).append(", ");
        }
        text.append("\n");
        logger.fine(text.toString());
    }

    /** @return the wrapped topic model */
    public ParallelTopicModel getParallelTopicModel() {
        return this.parallelTopicModel;
    }

    /**
     * Reads the topic probabilities of {@code example} from its topic columns.
     *
     * <p>The attributes are scanned once in iteration order, looking for the column name of
     * topic 0, then topic 1, and so on. Topics whose column is missing or appears out of order
     * keep the value 0. Scanning stops once all {@code numberOfCols} entries are filled, so
     * additional topic columns are ignored.
     *
     * @param example example to read
     * @return vector with one entry per topic of the model
     */
    private double[] extractTopicProbsFromExample(Example example) {
        double[] probabilities = new double[this.numberOfCols];
        int nextTopic = 0;
        String expectedName = TopicModelUtils.buildColumnNameForTopic(nextTopic);
        for (Attribute attribute : example.getAttributes()) {
            if (nextTopic >= probabilities.length) {
                break;
            }
            if (attribute.getName().equals(expectedName)) {
                probabilities[nextTopic] = example.getValue(attribute);
                nextTopic++;
                expectedName = TopicModelUtils.buildColumnNameForTopic(nextTopic);
            }
        }
        return probabilities;
    }

    /**
     * Classifies {@code example} by its distance to the per-class average topic probabilities.
     *
     * <p>NOTE(review): despite the method name, the distance used here is the Euclidean distance
     * computed by {@code calculateDistance}, not a Kullback-Leibler divergence - confirm which
     * one is intended.
     *
     * @param example example to classify
     * @return {@code 0.0} when the exclude average is strictly closer, otherwise {@code 1.0}
     * @throws OperatorException declared for callers; not thrown by the visible code
     */
    public double predictWithKlDivergenceAgainstMeans(Example example) throws OperatorException {
        double[] topicProbabilities = extractTopicProbsFromExample(example);
        displayVector(pullLabelFromExample(example), topicProbabilities);
        double excludeDistance = calculateDistance(topicProbabilities, this.averageExcludeProbabilities);
        double includeDistance = calculateDistance(topicProbabilities, this.averageIncludeProbabilities);
        return decide(excludeDistance, includeDistance);
    }

    /**
     * Classifies {@code example} by its divergence from the per-class median topic probabilities
     * (see {@link #calculateExcludeDistance(double[])} and
     * {@link #calculateIncludeDistance(double[])}).
     *
     * @param example example to classify
     * @return {@code 0.0} when the exclude divergence is strictly smaller, otherwise {@code 1.0}
     * @throws OperatorException declared for callers; not thrown by the visible code
     */
    public double predict(Example example) throws OperatorException {
        double[] topicProbabilities = extractTopicProbsFromExample(example);
        displayVector(pullLabelFromExample(example), topicProbabilities);
        double excludeDistance = calculateExcludeDistance(topicProbabilities);
        double includeDistance = calculateIncludeDistance(topicProbabilities);
        return decide(excludeDistance, includeDistance);
    }

    /**
     * Turns the two distances into a prediction and logs the outcome at FINE level.
     *
     * @return {@code 0.0} only when {@code excludeDistance < includeDistance}; ties and NaN
     *         comparisons therefore yield {@code 1.0}
     */
    private double decide(double excludeDistance, double includeDistance) {
        boolean exclude = excludeDistance < includeDistance;
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("\n\tExclude Distance: " + excludeDistance);
            logger.fine("\n\tInclude Distance: " + includeDistance);
            logger.fine("\n\tResult: " + (exclude ? "EXCLUDE" : "INCLUDE"));
        }
        return exclude ? PREDICTION_EXCLUDE : PREDICTION_INCLUDE;
    }

    /**
     * Returns the nominal label value of {@code example} for log output, or {@code "unknown"}
     * when there is no label attribute or the value cannot be read.
     */
    private String pullLabelFromExample(Example example) {
        if (this.labelAttribute == null) {
            return UNKNOWN_LABEL;
        }
        try {
            return example.getNominalValue(this.labelAttribute);
        } catch (Exception ignored) {
            // Best effort: the label is only used to annotate log output, so a failure to read
            // it must not break prediction.
            return UNKNOWN_LABEL;
        }
    }

    /** @return divergence between the example's topic probabilities and the exclude medians */
    public double getExcludeDistance(Example example) {
        return calculateExcludeDistance(extractTopicProbsFromExample(example));
    }

    /** @return divergence between the example's topic probabilities and the include medians */
    public double getIncludeDistance(Example example) {
        return calculateIncludeDistance(extractTopicProbsFromExample(example));
    }

    /**
     * Delegates to
     * {@code CalculationsKullbackLeibler.tryCalculateSymmetricKullbackLeiblerDivergence} with
     * {@code dArr} and the include medians.
     */
    public double calculateIncludeDistance(double[] dArr) {
        return CalculationsKullbackLeibler.tryCalculateSymmetricKullbackLeiblerDivergence(
                dArr, this.medianIncludeProbabilities);
    }

    /**
     * Delegates to
     * {@code CalculationsKullbackLeibler.tryCalculateSymmetricKullbackLeiblerDivergence} with
     * {@code dArr} and the exclude medians.
     */
    public double calculateExcludeDistance(double[] dArr) {
        return CalculationsKullbackLeibler.tryCalculateSymmetricKullbackLeiblerDivergence(
                dArr, this.medianExcludeProbabilities);
    }

    /**
     * Euclidean distance between two vectors, iterating over the length of {@code dArr};
     * {@code dArr2} must be at least as long.
     */
    private double calculateDistance(double[] dArr, double[] dArr2) {
        double sumOfSquares = 0.0d;
        for (int index = 0; index < dArr.length; index++) {
            double difference = dArr[index] - dArr2[index];
            sumOfSquares += difference * difference;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Sets how many of the most probable words are displayed. When a topic model is present the
     * value is capped at the size of the model's alphabet.
     *
     * @param i requested number of words
     */
    public void setNumberMostProbableWordsForDisplay(int i) {
        logger.fine("Entering setNumberMostProbableWordsForDisplay: " + i);
        this.numberMostProbableWordsForDisplay = i;
        if (this.parallelTopicModel != null) {
            this.numberMostProbableWordsForDisplay = Math.min(i, this.parallelTopicModel.getAlphabet().size());
        }
        logger.fine("Existing setNumberMostProbableWordsForDisplay: " + this.numberMostProbableWordsForDisplay);
    }

    /** @return the number of most probable words to display, as stored by the setter */
    public int getNumberMostProbableWordsForDisplay() {
        return this.numberMostProbableWordsForDisplay;
    }

    /** @return the pipe previously stored with {@link #setPipe(Pipe)}, or {@code null} */
    public Pipe getPipe() {
        return this.pipe;
    }

    /** Stores the pipe associated with this model. */
    public void setPipe(Pipe pipe) {
        this.pipe = pipe;
    }
}
