package org.data2semantics.mustard.kernels.graphkernels.singledtgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.FeatureInspector;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.utils.WalkCountUtils;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/singledtgraph/DTGraphTreeWalkCountKernel.class */
/**
 * Tree walk count kernel on a single {@link SingleDTGraph}.
 *
 * <p>Each instance node is described by a sparse vector of walk counts. A walk is identified by the
 * concatenation of the re-coded labels of the nodes and link tags it passes through. Nodes and
 * links each count as one step. Walks are started from the instance node and from the links and
 * nodes reached by expanding outgoing links level by level, for at most {@code depth} levels.
 *
 * <p>Per-computation state (dictionaries, relabeled graph, timings) is stored in fields and is
 * rebuilt on every call to {@link #computeFeatureVectors}. A single instance should therefore not
 * be used from several threads at once.
 */
public class DTGraphTreeWalkCountKernel implements GraphKernel<SingleDTGraph>, FeatureVectorKernel<SingleDTGraph>, ComputationTimeTracker, FeatureInspector {
    /** Relabeled copy of the most recently processed input graph (built by {@link #init}). */
    private DTGraph<String, String> rdfGraph;
    /** Nodes of {@link #rdfGraph} that correspond to the instances, in instance order. */
    private List<DTNode<String, String>> instanceVertices;
    /** Configured maximum walk length; capped to {@code 2 * depth} at computation time. */
    private int pathLength;
    /** Number of levels (link hops) expanded from each instance node. */
    private int depth;
    /** Whether feature vectors are normalized before being returned. */
    private boolean normalize;
    /**
     * Wall-clock milliseconds of the last computation. This covers walk counting, plus kernel
     * matrix construction when {@link #compute} is used. Graph relabeling is excluded.
     */
    private long compTime;
    /** Walk string to feature index. */
    private Map<String, Integer> pathDict;
    /** Original node label / link tag to the integer code used in the relabeled graph. */
    private Map<String, Integer> labelDict;
    /** Inverse of {@link #pathDict}; built after feature computation. */
    private Map<Integer, String> reversePathDict;
    /** Inverse of {@link #labelDict}; built after feature computation. */
    private Map<Integer, String> reverseLabelDict;

    /**
     * Creates the kernel.
     *
     * @param pathLength maximum walk length (nodes and links each count as one step); values above
     *     {@code 2 * depth} are treated as {@code 2 * depth}
     * @param depth number of levels expanded from each instance node
     * @param normalize whether the produced feature vectors are normalized
     */
    public DTGraphTreeWalkCountKernel(int pathLength, int depth, boolean normalize) {
        this.normalize = normalize;
        this.pathLength = pathLength;
        this.depth = depth;
    }

    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    @Override
    public long getComputationTime() {
        return compTime;
    }

    /**
     * Computes one walk-count feature vector per instance.
     *
     * <p>Resets and refills the path and label dictionaries. This is needed before
     * {@link #getFeatureDescriptions} can be used.
     *
     * @param data graph plus instance nodes
     * @return one feature vector per instance, in instance order (normalized if configured)
     * @throws IllegalArgumentException if an instance node is not a node of the graph
     */
    @Override
    public SparseVector[] computeFeatureVectors(SingleDTGraph data) {
        pathDict = new HashMap<String, Integer>();
        labelDict = new HashMap<String, Integer>();
        init(data.getGraph(), data.getInstances());

        // A walk cannot usefully be longer than the depth-bounded neighbourhood (2 steps per level).
        // Use a local value so the configured pathLength is left untouched.
        int effectivePathLength = Math.min(pathLength, depth * 2);

        SparseVector[] featureVectors = new SparseVector[data.numInstances()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }

        long start = System.currentTimeMillis();
        for (int i = 0; i < featureVectors.length; i++) {
            DTNode<String, String> root = instanceVertices.get(i);
            if (root == null) {
                throw new IllegalArgumentException("Instance " + i + " is not a node of the graph");
            }
            countInstanceWalks(featureVectors[i], root, effectivePathLength);
        }
        compTime = System.currentTimeMillis() - start;

        reversePathDict = invert(pathDict);
        reverseLabelDict = invert(labelDict);

        if (normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the instance-by-instance kernel matrix from the walk-count feature vectors.
     *
     * @param data graph plus instance nodes
     * @return square kernel matrix indexed by instance position
     */
    @Override
    public double[][] compute(SingleDTGraph data) {
        SparseVector[] featureVectors = computeFeatureVectors(data);
        int n = data.getInstances().size();
        double[][] matrix = KernelUtils.initMatrix(n, n);
        long start = System.currentTimeMillis();
        double[][] kernel = KernelUtils.computeKernelMatrix(featureVectors, matrix);
        compTime += System.currentTimeMillis() - start;
        return kernel;
    }

    /**
     * Adds the walk counts of the neighbourhood of one instance to {@code fv}.
     *
     * <p>First counts the walks starting at the root itself. Then, for each of the {@code depth}
     * levels, it counts walks starting at the frontier links (budget {@code len}) and at the
     * frontier nodes (budget {@code len - 1}). The frontier only advances through nodes that still
     * have a positive budget. The budget shrinks by 2 per level (one link plus one node).
     *
     * @param fv vector to accumulate into
     * @param root relabeled instance node
     * @param maxPathLength effective maximum walk length
     */
    private void countInstanceWalks(SparseVector fv, DTNode<String, String> root, int maxPathLength) {
        countPathRec(fv, root, "", maxPathLength);

        List<DTLink<String, String>> frontierLinks = new ArrayList<DTLink<String, String>>(root.linksOut());
        List<DTNode<String, String>> frontierNodes = new ArrayList<DTNode<String, String>>();
        for (DTLink<String, String> link : root.linksOut()) {
            frontierNodes.add(link.to());
        }

        int remaining = maxPathLength - 1;
        for (int level = depth; level > 0; level--) {
            // 2 * level - 1 is the number of steps left before the depth boundary for a link at this level.
            int len = Math.min(level * 2 - 1, remaining);
            for (DTLink<String, String> link : frontierLinks) {
                countPathRec(fv, link, "", len);
            }

            List<DTLink<String, String>> nextLinks = new ArrayList<DTLink<String, String>>();
            List<DTNode<String, String>> nextNodes = new ArrayList<DTNode<String, String>>();
            for (DTNode<String, String> node : frontierNodes) {
                countPathRec(fv, node, "", len - 1);
                if (len - 1 > 0) {
                    nextLinks.addAll(node.linksOut());
                    for (DTLink<String, String> link : node.linksOut()) {
                        nextNodes.add(link.to());
                    }
                }
            }
            frontierLinks = nextLinks;
            frontierNodes = nextNodes;
            remaining = len - 2;
        }
    }

    /**
     * Counts the walk ending at {@code node} (prefixed by {@code prefix}). If budget remains, it
     * then recurses into the node's outgoing links.
     *
     * <p>The element itself is always counted, even when {@code remaining} is zero or negative.
     */
    private void countPathRec(SparseVector fv, DTNode<String, String> node, String prefix, int remaining) {
        String path = prefix + node.label();
        increment(fv, dictIndex(pathDict, path));
        if (remaining > 0) {
            for (DTLink<String, String> link : node.linksOut()) {
                countPathRec(fv, link, path, remaining - 1);
            }
        }
    }

    /**
     * Counts the walk ending at {@code link} (prefixed by {@code prefix}). If budget remains, it
     * then continues into the link's target node.
     *
     * <p>The element itself is always counted, even when {@code remaining} is zero or negative.
     */
    private void countPathRec(SparseVector fv, DTLink<String, String> link, String prefix, int remaining) {
        String path = prefix + link.tag();
        increment(fv, dictIndex(pathDict, path));
        if (remaining > 0) {
            countPathRec(fv, link.to(), path, remaining - 1);
        }
    }

    /**
     * Builds a copy of {@code graph} in which every node label and link tag is replaced by
     * {@code "_" + code}, where {@code code} comes from {@link #labelDict}. It also records which
     * new nodes correspond to the instances.
     *
     * <p>The leading underscore keeps concatenated walk strings unambiguous. Links are copied by
     * node index. This assumes nodes added to the {@link LightDTGraph} receive indices in the same
     * order as {@code graph.nodes()} iterates (TODO confirm against the nodes library).
     */
    private void init(DTGraph<String, String> graph, List<DTNode<String, String>> instances) {
        Map<DTNode<String, String>, Integer> instanceIndex = new HashMap<DTNode<String, String>, Integer>();
        instanceVertices = new ArrayList<DTNode<String, String>>(instances.size());
        for (int i = 0; i < instances.size(); i++) {
            instanceIndex.put(instances.get(i), i);
            instanceVertices.add(null);
        }

        LightDTGraph<String, String> relabeled = new LightDTGraph<String, String>();
        for (DTNode<String, String> node : graph.nodes()) {
            String label = "_" + Integer.toString(dictIndex(labelDict, node.label()));
            DTNode<String, String> copy = relabeled.add(label);
            Integer position = instanceIndex.get(node);
            if (position != null) {
                instanceVertices.set(position, copy);
            }
        }
        for (DTLink<String, String> link : graph.links()) {
            String tag = "_" + Integer.toString(dictIndex(labelDict, link.tag()));
            relabeled.nodes().get(link.from().index()).connect(relabeled.nodes().get(link.to().index()), tag);
        }
        rdfGraph = relabeled;
    }

    /**
     * Describes the given feature indices as readable walks.
     *
     * @param indices feature indices as produced by {@link #computeFeatureVectors}
     * @return one description per index, in the same order
     * @throws RuntimeException if {@link #computeFeatureVectors} has not completed yet
     */
    @Override
    public List<String> getFeatureDescriptions(List<Integer> indices) {
        if (reverseLabelDict == null || reversePathDict == null) {
            throw new RuntimeException("Should run computeFeatureVectors first");
        }
        List<String> descriptions = new ArrayList<String>(indices.size());
        for (Integer index : indices) {
            descriptions.add(WalkCountUtils.getFeatureDecription(reverseLabelDict, reversePathDict, index.intValue()));
        }
        return descriptions;
    }

    /** Returns the index of {@code key} in {@code dict}; a new key gets the next free index ({@code dict.size()}). */
    private static int dictIndex(Map<String, Integer> dict, String key) {
        Integer index = dict.get(key);
        if (index == null) {
            index = dict.size();
            dict.put(key, index);
        }
        return index;
    }

    /** Adds 1 to the entry at {@code index} of {@code fv}. */
    private static void increment(SparseVector fv, int index) {
        fv.setValue(index, fv.getValue(index) + 1.0d);
    }

    /** Returns a new map with keys and values of {@code dict} swapped (values are assumed unique). */
    private static Map<Integer, String> invert(Map<String, Integer> dict) {
        Map<Integer, String> inverse = new HashMap<Integer, String>();
        for (Map.Entry<String, Integer> entry : dict.entrySet()) {
            inverse.put(entry.getValue(), entry.getKey());
        }
        return inverse;
    }
}
