package org.data2semantics.mustard.kernels.graphkernels.graphlist;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.FeatureInspector;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.GraphList;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.weisfeilerlehman.StringLabel;
import org.data2semantics.mustard.weisfeilerlehman.WLUtils;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanDTGraphIterator;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/graphlist/WLSubTreeKernel.class */
/**
 * Weisfeiler-Lehman subtree kernel for a list of {@code DTGraph<String, String>} instances.
 *
 * <p>The input graphs are first copied into graphs labelled with {@link StringLabel}s. A
 * {@link WeisfeilerLehmanDTGraphIterator} then (re)labels those copies, and after the
 * initialization step and after every iteration the labels of all nodes and links are counted
 * into one {@link SparseVector} per graph. The kernel matrix is computed from these feature
 * vectors via {@link KernelUtils#computeKernelMatrix}.
 *
 * <p>Instances are stateful: {@link #computeFeatureVectors(GraphList)} stores the computation
 * time and the (inverted) label dictionary of the most recent run.
 */
public class WLSubTreeKernel implements GraphKernel<GraphList<DTGraph<String, String>>>, FeatureVectorKernel<GraphList<DTGraph<String, String>>>, ComputationTimeTracker, FeatureInspector {
    /** Number of {@code wlIterate} rounds performed after initialization. */
    private int iterations;
    /** Whether the feature vectors are passed through {@link KernelUtils#normalize} before being returned. */
    protected boolean normalize;
    /** First constructor argument handed to {@link WeisfeilerLehmanDTGraphIterator}. */
    private boolean reverse;
    /** Second constructor argument handed to {@link WeisfeilerLehmanDTGraphIterator}. */
    private boolean trackPrevNBH;
    /** Milliseconds spent on the WL relabeling and counting in the most recent run. */
    private long compTime;
    /** Inverse of the iterator's label dictionary (value -> key); {@code null} until a run has completed. */
    private Map<String, String> dict;

    /**
     * Creates a kernel with all options given explicitly.
     *
     * @param iterations   number of WL iterations to run after initialization
     * @param reverse      forwarded to the {@link WeisfeilerLehmanDTGraphIterator} constructor
     * @param trackPrevNBH forwarded to the {@link WeisfeilerLehmanDTGraphIterator} constructor
     * @param normalize    whether to normalize the computed feature vectors
     */
    public WLSubTreeKernel(int iterations, boolean reverse, boolean trackPrevNBH, boolean normalize) {
        this.reverse = reverse;
        this.trackPrevNBH = trackPrevNBH;
        this.normalize = normalize;
        this.iterations = iterations;
    }

    /**
     * Creates a kernel with {@code trackPrevNBH} set to {@code false}.
     *
     * @param iterations number of WL iterations to run after initialization
     * @param reverse    forwarded to the {@link WeisfeilerLehmanDTGraphIterator} constructor
     * @param normalize  whether to normalize the computed feature vectors
     */
    public WLSubTreeKernel(int iterations, boolean reverse, boolean normalize) {
        this(iterations, reverse, false, normalize);
    }

    /**
     * Creates a kernel with {@code reverse} and {@code trackPrevNBH} both set to {@code false}.
     *
     * @param iterations number of WL iterations to run after initialization
     * @param normalize  whether to normalize the computed feature vectors
     */
    public WLSubTreeKernel(int iterations, boolean normalize) {
        this(iterations, false, false, normalize);
    }

    /**
     * Creates a normalizing kernel with {@code reverse} and {@code trackPrevNBH} both set to
     * {@code false}.
     *
     * @param iterations number of WL iterations to run after initialization
     */
    public WLSubTreeKernel(int iterations) {
        this(iterations, true);
    }

    /**
     * Returns the label produced by {@link KernelUtils#createLabel} for this kernel.
     */
    @Override // org.data2semantics.mustard.kernels.Kernel
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    /**
     * Enables or disables normalization of the feature vectors for subsequent computations.
     *
     * @param normalize {@code true} to normalize
     */
    @Override // org.data2semantics.mustard.kernels.Kernel
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /**
     * Returns the time, in milliseconds, measured during the most recent call to
     * {@link #computeFeatureVectors(GraphList)}. The measured span covers WL initialization,
     * all iterations and the label counting; it excludes the graph copy and normalization.
     */
    @Override // org.data2semantics.mustard.kernels.ComputationTimeTracker
    public long getComputationTime() {
        return this.compTime;
    }

    /**
     * Computes one feature vector per graph in {@code graphList}.
     *
     * <p>As a side effect this updates the stored computation time and replaces the stored
     * dictionary used by {@link #getFeatureDescriptions(List)}.
     *
     * @param graphList the graphs to compute features for; the graphs themselves are not
     *                  modified, the relabeling runs on copies
     * @return the feature vectors, in the order of {@code graphList.getGraphs()}; normalized
     *         via {@link KernelUtils#normalize} when normalization is enabled
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel
    public SparseVector[] computeFeatureVectors(GraphList<DTGraph<String, String>> graphList) {
        List<DTGraph<StringLabel, StringLabel>> graphs = copyGraphs(graphList.getGraphs());
        SparseVector[] featureVectors = new SparseVector[graphs.size()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }

        WeisfeilerLehmanDTGraphIterator wl = new WeisfeilerLehmanDTGraphIterator(this.reverse, this.trackPrevNBH);

        long startTime = System.currentTimeMillis();
        wl.wlInitialize(graphs);
        // The dictionary grows as labels are introduced, so its size is re-read before every count.
        computeFVs(graphs, featureVectors, 1.0d, wl.getLabelDict().size() - 1);
        for (int iteration = 0; iteration < this.iterations; iteration++) {
            wl.wlIterate(graphs);
            computeFVs(graphs, featureVectors, 1.0d, wl.getLabelDict().size() - 1);
        }
        this.compTime = System.currentTimeMillis() - startTime;

        // Store the dictionary inverted (value -> key) for later feature inspection.
        Map<String, String> inverted = new HashMap<String, String>();
        for (Map.Entry<String, String> entry : wl.getLabelDict().entrySet()) {
            inverted.put(entry.getValue(), entry.getKey());
        }
        this.dict = inverted;

        if (this.normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the kernel matrix for {@code graphList} from its feature vectors.
     *
     * @param graphList the graphs to compare
     * @return a square matrix with one row and one column per graph
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.GraphKernel
    public double[][] compute(GraphList<DTGraph<String, String>> graphList) {
        int size = graphList.getGraphs().size();
        SparseVector[] featureVectors = computeFeatureVectors(graphList);
        return KernelUtils.computeKernelMatrix(featureVectors, KernelUtils.initMatrix(size, size));
    }

    /**
     * Adds {@code weight} to the feature of every node label and link tag currently present in
     * the graphs, skipping labels that report {@code isSameAsPrev()}.
     *
     * @param graphs         the relabeled graphs; labels must be parseable as integers
     * @param featureVectors the vectors to update, parallel to {@code graphs}
     * @param weight         the amount added per label occurrence
     * @param lastIndex      value passed to {@code SparseVector.setLastIndex} for every vector
     */
    private void computeFVs(List<DTGraph<StringLabel, StringLabel>> graphs, SparseVector[] featureVectors, double weight, int lastIndex) {
        for (int i = 0; i < graphs.size(); i++) {
            DTGraph<StringLabel, StringLabel> graph = graphs.get(i);
            SparseVector featureVector = featureVectors[i];
            featureVector.setLastIndex(lastIndex);

            for (DTNode<StringLabel, StringLabel> node : graph.nodes()) {
                if (!node.label().isSameAsPrev()) {
                    int index = Integer.parseInt(node.label().toString());
                    featureVector.setValue(index, featureVector.getValue(index) + weight);
                }
            }
            for (DTLink<StringLabel, StringLabel> link : graph.links()) {
                if (!link.tag().isSameAsPrev()) {
                    int index = Integer.parseInt(link.tag().toString());
                    featureVector.setValue(index, featureVector.getValue(index) + weight);
                }
            }
        }
    }

    /**
     * Copies each input graph into a new {@link LightDTGraph} whose node labels and link tags
     * are {@link StringLabel}s wrapping the original strings.
     *
     * @param graphs the graphs to copy
     * @return the copies, in the same order as {@code graphs}
     */
    private List<DTGraph<StringLabel, StringLabel>> copyGraphs(List<DTGraph<String, String>> graphs) {
        List<DTGraph<StringLabel, StringLabel>> copies = new ArrayList<DTGraph<StringLabel, StringLabel>>();
        for (DTGraph<String, String> graph : graphs) {
            LightDTGraph<StringLabel, StringLabel> copy = new LightDTGraph<StringLabel, StringLabel>();
            // Nodes are added in iteration order of the source graph; links are then wired up by
            // node index, which relies on the copy assigning matching indices.
            for (DTNode<String, String> node : graph.nodes()) {
                copy.add(new StringLabel(node.label()));
            }
            for (DTLink<String, String> link : graph.links()) {
                copy.nodes().get(link.from().index())
                        .connect(copy.nodes().get(link.to().index()), new StringLabel(link.tag()));
            }
            copies.add(copy);
        }
        return copies;
    }

    /**
     * Returns a description for each requested feature index, based on the dictionary stored
     * by the most recent call to {@link #computeFeatureVectors(GraphList)}.
     *
     * @param list the feature indices to describe
     * @return one description per index, in the order given
     * @throws RuntimeException if no feature vectors have been computed yet
     */
    @Override // org.data2semantics.mustard.kernels.FeatureInspector
    public List<String> getFeatureDescriptions(List<Integer> list) {
        if (this.dict == null) {
            throw new RuntimeException("Should run computeFeatureVectors() first");
        }
        List<String> descriptions = new ArrayList<String>();
        for (Integer index : list) {
            descriptions.add(WLUtils.getFeatureDecription(this.dict, index.intValue()));
        }
        return descriptions;
    }
}
