package org.data2semantics.mustard.kernels.graphkernels.graphlist;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.FeatureInspector;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.GraphList;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.utils.WalkCountUtils;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/graphlist/WalkCountKernel.class */
/**
 * Walk-count graph kernel over a {@link GraphList} of directed graphs with string node labels
 * and string link tags.
 *
 * <p>Each node label and link tag is first compressed to a token {@code "_" + index}, using a
 * label dictionary shared by all graphs of one run. Then, starting from every node and from every
 * link of every graph, all walks that follow outgoing links for at most {@code depth} further
 * steps are enumerated. A walk becomes a feature string by concatenating the tokens it visits.
 * Because every token starts with {@code '_'}, distinct walks never collide. A graph's feature
 * vector counts how often each distinct walk string occurs.
 *
 * <p>Instances are stateful and not safe for concurrent use. {@link #computeFeatureVectors}
 * rebuilds the dictionaries that {@link #getFeatureDescriptions} later reads.
 */
public class WalkCountKernel implements GraphKernel<GraphList<DTGraph<String, String>>>, FeatureVectorKernel<GraphList<DTGraph<String, String>>>, ComputationTimeTracker, FeatureInspector {
    /** Maximum number of steps a walk may take after its starting node or link. */
    private final int depth;
    /** Milliseconds spent counting walks during the last {@code computeFeatureVectors} call. */
    private long compTime;
    /** Whether feature vectors are passed through {@code KernelUtils.normalize}. */
    protected boolean normalize;
    /** Walk string -> feature index; rebuilt on every {@code computeFeatureVectors} call. */
    private Map<String, Integer> pathDict;
    /** Original node label / link tag -> compressed token index. */
    private Map<String, Integer> labelDict;
    /** Inverse of {@code pathDict}; null until a {@code computeFeatureVectors} call completes. */
    private Map<Integer, String> reversePathDict;
    /** Inverse of {@code labelDict}; null until a {@code computeFeatureVectors} call completes. */
    private Map<Integer, String> reverseLabelDict;

    /**
     * Creates a walk-count kernel.
     *
     * @param depth maximum number of steps followed after the start of each walk
     * @param normalize whether to normalize the computed feature vectors
     */
    public WalkCountKernel(int depth, boolean normalize) {
        this.normalize = normalize;
        this.depth = depth;
    }

    /**
     * Creates a walk-count kernel whose feature vectors are normalized.
     *
     * @param depth maximum number of steps followed after the start of each walk
     */
    public WalkCountKernel(int depth) {
        this(depth, true);
    }

    @Override // org.data2semantics.mustard.kernels.Kernel
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override // org.data2semantics.mustard.kernels.Kernel
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** Returns the walk-counting time of the last {@code computeFeatureVectors} call, in ms. */
    @Override // org.data2semantics.mustard.kernels.ComputationTimeTracker
    public long getComputationTime() {
        return compTime;
    }

    /**
     * Computes one walk-count feature vector per graph.
     *
     * <p>This rebuilds the path and label dictionaries (and their inverses) used by
     * {@link #getFeatureDescriptions}, and records the walk-counting time.
     *
     * @param data the graphs to featurize
     * @return one feature vector per graph, in the order of {@code data.getGraphs()}; normalized
     *     via {@code KernelUtils.normalize} when {@code normalize} is set
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel
    public SparseVector[] computeFeatureVectors(GraphList<DTGraph<String, String>> data) {
        pathDict = new HashMap<String, Integer>();
        labelDict = new HashMap<String, Integer>();
        // Drop the inverses of any earlier run, so a failure part-way through this call cannot leave
        // getFeatureDescriptions reading stale reverse maps that disagree with the new dictionaries.
        reversePathDict = null;
        reverseLabelDict = null;

        List<DTGraph<String, String>> graphs = copyGraphs(data.getGraphs());

        SparseVector[] featureVectors = new SparseVector[graphs.size()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }

        // Only walk counting is timed; label compression in copyGraphs already happened above.
        long tic = System.currentTimeMillis();
        for (int i = 0; i < featureVectors.length; i++) {
            DTGraph<String, String> graph = graphs.get(i);
            for (DTNode<String, String> node : graph.nodes()) {
                countPathRec(featureVectors[i], node, "", depth);
            }
            for (DTLink<String, String> link : graph.links()) {
                countPathRec(featureVectors[i], link, "", depth);
            }
        }
        // All vectors share the same last index: the highest walk index assigned in this run.
        for (SparseVector featureVector : featureVectors) {
            featureVector.setLastIndex(pathDict.size() - 1);
        }
        compTime = System.currentTimeMillis() - tic;

        reversePathDict = invert(pathDict);
        reverseLabelDict = invert(labelDict);

        if (normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the kernel matrix of the graphs from their walk-count feature vectors.
     *
     * @param data the graphs to compare
     * @return an {@code n x n} matrix, where {@code n} is the number of graphs
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.GraphKernel
    public double[][] compute(GraphList<DTGraph<String, String>> data) {
        SparseVector[] featureVectors = computeFeatureVectors(data);
        int size = data.getGraphs().size();
        return KernelUtils.computeKernelMatrix(featureVectors, KernelUtils.initMatrix(size, size));
    }

    /**
     * Counts the walk {@code prefix + node.label()} and, while steps remain, extends it along
     * every outgoing link of {@code node}.
     */
    private void countPathRec(SparseVector featureVector, DTNode<String, String> node, String prefix, int stepsLeft) {
        String path = prefix + node.label();
        increment(featureVector, path);
        if (stepsLeft > 0) {
            for (DTLink<String, String> link : node.linksOut()) {
                countPathRec(featureVector, link, path, stepsLeft - 1);
            }
        }
    }

    /**
     * Counts the walk {@code prefix + link.tag()} and, while steps remain, extends it to the
     * link's target node.
     */
    private void countPathRec(SparseVector featureVector, DTLink<String, String> link, String prefix, int stepsLeft) {
        String path = prefix + link.tag();
        increment(featureVector, path);
        if (stepsLeft > 0) {
            countPathRec(featureVector, link.to(), path, stepsLeft - 1);
        }
    }

    /** Adds one to the count of {@code path} in {@code featureVector}, indexing the path if new. */
    private void increment(SparseVector featureVector, String path) {
        int index = indexOf(pathDict, path);
        featureVector.setValue(index, featureVector.getValue(index) + 1.0d);
    }

    /**
     * Copies each graph, replacing node labels and link tags with compressed {@code "_" + index}
     * tokens from {@code labelDict}. Nodes are added in the same order as in the source graph, so
     * link endpoints are resolved by their {@code index()} in the source graph.
     */
    private List<DTGraph<String, String>> copyGraphs(List<DTGraph<String, String>> graphs) {
        List<DTGraph<String, String>> copies = new ArrayList<DTGraph<String, String>>(graphs.size());
        for (DTGraph<String, String> graph : graphs) {
            LightDTGraph<String, String> copy = new LightDTGraph<String, String>();
            for (DTNode<String, String> node : graph.nodes()) {
                copy.add(token(node.label()));
            }
            for (DTLink<String, String> link : graph.links()) {
                String tag = token(link.tag());
                DTNode<String, String> from = copy.nodes().get(link.from().index());
                DTNode<String, String> to = copy.nodes().get(link.to().index());
                from.connect(to, tag);
            }
            copies.add(copy);
        }
        return copies;
    }

    /** Returns the compressed token for an original label or tag, registering it if unseen. */
    private String token(String label) {
        return "_" + Integer.toString(indexOf(labelDict, label));
    }

    /**
     * Returns the index of {@code key} in {@code dict}. If the key is absent, it is first assigned
     * the next free index, {@code dict.size()}.
     */
    private static int indexOf(Map<String, Integer> dict, String key) {
        Integer index = dict.get(key);
        if (index == null) {
            index = Integer.valueOf(dict.size());
            dict.put(key, index);
        }
        return index.intValue();
    }

    /** Builds the inverse (index -> key) of a dictionary whose values are distinct indices. */
    private static Map<Integer, String> invert(Map<String, Integer> dict) {
        Map<Integer, String> inverse = new HashMap<Integer, String>();
        for (Map.Entry<String, Integer> entry : dict.entrySet()) {
            inverse.put(entry.getValue(), entry.getKey());
        }
        return inverse;
    }

    /**
     * Describes the given feature indices in terms of the original labels and tags.
     *
     * @param indices feature indices produced by the last {@code computeFeatureVectors} call
     * @return one description per index, in the same order
     * @throws IllegalStateException if {@code computeFeatureVectors} has not completed successfully
     */
    @Override // org.data2semantics.mustard.kernels.FeatureInspector
    public List<String> getFeatureDescriptions(List<Integer> indices) {
        // Guard on the maps actually read below; they are only set once a run has completed.
        if (reverseLabelDict == null || reversePathDict == null) {
            throw new IllegalStateException("Should run computeFeatureVectors first");
        }
        List<String> descriptions = new ArrayList<String>(indices.size());
        for (Integer index : indices) {
            descriptions.add(WalkCountUtils.getFeatureDecription(reverseLabelDict, reversePathDict, index.intValue()));
        }
        return descriptions;
    }
}
