package cc.mallet.cluster.tui;

import cc.mallet.cluster.Clustering;
import cc.mallet.cluster.Clusterings;
import cc.mallet.cluster.util.ClusterUtils;
import cc.mallet.pipe.Noop;
import cc.mallet.types.Alphabet;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.util.CommandOption;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
import com.meaningcloud.LangRequest;
import gnu.trove.TIntHashSet;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

/* loaded from: input_file:cc/mallet/cluster/tui/Clusterings2Clusterings.class */
/**
 * Command-line tool that post-processes a serialized {@link Clusterings} file.
 *
 * <p>Two independent operations are supported, selected by command-line options:
 * <ul>
 *   <li>{@code --min-cluster-size N} (N &gt; 1): every clustering is rebuilt from only those
 *       instances whose cluster holds at least N instances, and the result is serialized to
 *       the {@code --output-prefix} file.</li>
 *   <li>{@code --training-proportion P} (0 &lt; P &lt;= 1): the single clustering in the input is
 *       split cluster-by-cluster into a training and a testing clustering, serialized to
 *       {@code <output-prefix>.train} and {@code <output-prefix>.test}.</li>
 * </ul>
 */
public class Clusterings2Clusterings {
    private static final Logger logger = MalletLogger.getLogger(Clusterings2Clusterings.class.getName());
    static CommandOption.String inputFile = new CommandOption.String(Clusterings2Clusterings.class, "input", "FILENAME", true, "text.clusterings", "The filename from which to read the list of instances.", null);
    static CommandOption.String outputPrefixFile = new CommandOption.String(Clusterings2Clusterings.class, "output-prefix", "FILENAME", false, "text.clusterings", "The filename prefix to write output. Suffices 'train' and 'test' appended.", null);
    static CommandOption.Integer minClusterSize = new CommandOption.Integer(Clusterings2Clusterings.class, "min-cluster-size", "INTEGER", false, 1, "Remove clusters with fewer than this many Instances.", null);
    static CommandOption.Double trainingProportion = new CommandOption.Double(Clusterings2Clusterings.class, "training-proportion", "DOUBLE", false, 0.0d, "Split into training and testing, with this percentage of instances reserved for training.", null);

    /** Fixed seed so that the train/test split is reproducible between runs. */
    private static final int SPLIT_SEED = 123;

    /**
     * Entry point: parses the options, loads the clusterings and applies the requested operations.
     *
     * @param strArr the command-line arguments
     * @throws IllegalArgumentException if the input file cannot be read, if
     *         {@code training-proportion} is greater than 1, or if a train/test split is
     *         requested on an input that does not hold exactly one clustering
     */
    public static void main(String[] strArr) {
        CommandOption.setSummary(Clusterings2Clusterings.class, "A tool to manipulate Clusterings.");
        CommandOption.process(Clusterings2Clusterings.class, strArr);

        Clusterings clusterings;
        try {
            clusterings = readClusterings(inputFile.value);
        } catch (Exception e) {
            // NOTE(review): LangRequest.DEFAULT_SELECTION is used as the message separator; it comes
            // from an unrelated library and is presumably a decompiler artifact for a string
            // literal -- confirm against the upstream MALLET source before replacing it.
            System.err.println("Exception reading clusterings from " + inputFile.value + LangRequest.DEFAULT_SELECTION + e);
            // Nothing below can work without the input, so fail here with the real cause instead
            // of running into a NullPointerException a few lines later.
            throw new IllegalArgumentException("Cannot read clusterings from " + inputFile.value, e);
        }
        if (clusterings == null) {
            throw new IllegalArgumentException("No clusterings found in " + inputFile.value);
        }
        logger.info("number clusterings=" + clusterings.size());

        if (minClusterSize.value > 1) {
            for (int i = 0; i < clusterings.size(); i++) {
                clusterings.set(i, removeSmallClusters(clusterings.get(i), minClusterSize.value));
            }
            if (outputPrefixFile.value != null) {
                try {
                    writeClusterings(new File(outputPrefixFile.value), clusterings);
                } catch (Exception e) {
                    // Best effort: a failed write is reported but does not abort the run.
                    logger.warning("Exception writing clustering to file " + outputPrefixFile.value + LangRequest.DEFAULT_SELECTION + e);
                    e.printStackTrace();
                }
            }
        }

        if (trainingProportion.value > 0.0d) {
            splitIntoTrainAndTest(clusterings);
        }
    }

    /**
     * Deserializes a {@link Clusterings} object from the named file, always releasing the file handle.
     *
     * @param filename path of the serialized clusterings
     * @return the deserialized object
     * @throws IOException if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    private static Clusterings readClusterings(String filename) throws IOException, ClassNotFoundException {
        FileInputStream fileIn = new FileInputStream(filename);
        try {
            return (Clusterings) new ObjectInputStream(fileIn).readObject();
        } finally {
            fileIn.close();
        }
    }

    /**
     * Serializes {@code toWrite} to {@code destination}, closing the file even when writing fails.
     *
     * @param destination file to create or overwrite
     * @param toWrite the clusterings to serialize
     * @throws IOException if the file cannot be opened or written
     */
    private static void writeClusterings(File destination, Clusterings toWrite) throws IOException {
        FileOutputStream fileOut = new FileOutputStream(destination);
        try {
            ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
            objectOut.writeObject(toWrite);
            objectOut.close();
        } finally {
            // Covers the failure paths; a second close after a successful write is harmless.
            fileOut.close();
        }
    }

    /**
     * Rebuilds a clustering from only those instances whose cluster, as reported by
     * {@link Clustering#size(int)}, holds at least {@code minSize} instances.
     *
     * @param clustering the clustering to filter
     * @param minSize minimum number of instances a cluster needs to be kept
     * @return a new clustering over the retained instances
     */
    private static Clustering removeSmallClusters(Clustering clustering, int minSize) {
        InstanceList instances = clustering.getInstances();
        Alphabet dataAlphabet = instances.getDataAlphabet();
        LabelAlphabet labelAlphabet = (LabelAlphabet) instances.getTargetAlphabet();
        if (dataAlphabet == null) {
            dataAlphabet = new Alphabet();
        }
        if (labelAlphabet == null) {
            labelAlphabet = new LabelAlphabet();
        }
        Noop noop = new Noop(dataAlphabet, labelAlphabet);
        InstanceList kept = new InstanceList(noop);
        for (int i = 0; i < instances.size(); i++) {
            int label = clustering.getLabel(i);
            if (clustering.size(label) >= minSize) {
                Instance instance = instances.get(i);
                // The cluster id becomes the target label of the copied instance.
                kept.add(noop.pipe(new Instance(instance.getData(), labelAlphabet.lookupLabel(Integer.valueOf(label)), instance.getName(), instance.getSource())));
            }
        }
        return createSmallerClustering(kept);
    }

    /**
     * Splits the single clustering in {@code clusterings} into training and testing clusterings
     * and writes them to {@code <output-prefix>.train} and {@code <output-prefix>.test}.
     *
     * <p>Whole clusters are drawn at random (fixed seed) into the training set until it holds at
     * least {@code training-proportion * numInstances} instances; every cluster that was not
     * drawn goes to the testing set.
     *
     * @param clusterings the loaded clusterings; must contain exactly one clustering
     * @throws IllegalArgumentException if {@code training-proportion} exceeds 1 or
     *         {@code clusterings} does not contain exactly one clustering
     */
    private static void splitIntoTrainAndTest(Clusterings clusterings) {
        if (trainingProportion.value > 1.0d) {
            // A value above 1 asks for more training instances than exist.
            throw new IllegalArgumentException("training-proportion must be in (0, 1], not " + trainingProportion.value);
        }
        if (clusterings.size() != 1) {
            throw new IllegalArgumentException("Expect one clustering to do train/test split, not " + clusterings.size());
        }
        Clustering original = clusterings.get(0);
        int numClusters = original.getNumClusters();
        int targetTrainingSize = (int) (trainingProportion.value * original.getNumInstances());

        TIntHashSet trainingClusters = new TIntHashSet();
        int clustersPicked = 0;
        Randoms randoms = new Randoms(SPLIT_SEED);
        LabelAlphabet labelAlphabet = new LabelAlphabet();

        InstanceList trainingInstances = new InstanceList(new Noop(null, labelAlphabet));
        // The second condition guarantees termination once every cluster has been drawn.
        while (trainingInstances.size() < targetTrainingSize && clustersPicked < numClusters) {
            int candidate = randoms.nextInt(numClusters);
            if (!trainingClusters.contains(candidate)) {
                trainingClusters.add(candidate);
                clustersPicked++;
                addClusterInstances(original, candidate, labelAlphabet, trainingInstances);
            }
        }
        trainingInstances.shuffle(randoms);
        Clustering training = createSmallerClustering(trainingInstances);

        InstanceList testingInstances = new InstanceList((Alphabet) null, labelAlphabet);
        for (int cluster = 0; cluster < numClusters; cluster++) {
            if (!trainingClusters.contains(cluster)) {
                addClusterInstances(original, cluster, labelAlphabet, testingInstances);
            }
        }
        testingInstances.shuffle(randoms);
        Clustering testing = createSmallerClustering(testingInstances);

        logger.info(outputPrefixFile.value + ".train : " + training.getNumClusters() + " objects");
        logger.info(outputPrefixFile.value + ".test : " + testing.getNumClusters() + " objects");
        if (outputPrefixFile.value != null) {
            try {
                writeClusterings(new File(outputPrefixFile.value + ".train"), new Clusterings(new Clustering[]{training}));
                writeClusterings(new File(outputPrefixFile.value + ".test"), new Clusterings(new Clustering[]{testing}));
            } catch (Exception e) {
                // Best effort: a failed write is reported but does not abort the run.
                logger.warning("Exception writing clustering to file " + outputPrefixFile.value + LangRequest.DEFAULT_SELECTION + e);
                e.printStackTrace();
            }
        }
    }

    /**
     * Copies every instance of one cluster into {@code destination}, using the cluster index as
     * the target label of each copy.
     *
     * @param source clustering that owns the cluster
     * @param clusterIndex index of the cluster to copy
     * @param labelAlphabet alphabet in which the cluster index is looked up as a label
     * @param destination list receiving the copied instances
     */
    private static void addClusterInstances(Clustering source, int clusterIndex, LabelAlphabet labelAlphabet, InstanceList destination) {
        InstanceList cluster = source.getCluster(clusterIndex);
        for (int i = 0; i < cluster.size(); i++) {
            Instance instance = cluster.get(i);
            destination.add(new Instance(instance.getData(), labelAlphabet.lookupLabel(Integer.valueOf(clusterIndex)), instance.getName(), instance.getSource()));
        }
    }

    /**
     * Builds a clustering from a labelled instance list by creating a singleton clustering and
     * then merging instances with the same label (both steps delegated to {@link ClusterUtils}).
     *
     * @param instanceList instances whose targets carry the cluster labels
     * @return the merged clustering
     */
    private static Clustering createSmallerClustering(InstanceList instanceList) {
        return ClusterUtils.mergeInstancesWithSameLabel(ClusterUtils.createSingletonClustering(instanceList));
    }
}
