package org.data2semantics.mustard.kernels.graphkernels.singledtgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.utils.Pair;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/singledtgraph/DTGraphTreeWalkCountKernelMkII.class */
/**
 * Walk-count feature-vector kernel over a single {@link DTGraph}.
 *
 * <p>For every instance vertex, {@link #init} copies the vertices and edges reachable within
 * {@code depth} outgoing steps into an internal graph whose labels are {@link PathStringMapLabel}s.
 * It records each copied vertex/edge together with its remaining depth: the instance itself gets
 * {@code depth}, and anything {@code k} hops away gets {@code depth - k}. Labels are replaced by
 * {@code "_" + id} strings, where ids come from {@link #labelDict}.
 *
 * <p>{@link #computeFeatureVectors} then performs {@code pathLength} rounds of path extension
 * (vertex label + out-edge paths, edge label + target-vertex paths). After the initial state and
 * after each round, it adds one count per path string found on the recorded vertices/edges of an
 * instance. Path strings are mapped to feature indices by {@link #pathDict}.
 */
public class DTGraphTreeWalkCountKernelMkII implements GraphKernel<SingleDTGraph>, FeatureVectorKernel<SingleDTGraph>, ComputationTimeTracker {
    // NOTE: field names and declaration order are kept as-is; KernelUtils.createLabel(this) is
    // handed the whole object and may derive the label from its fields.

    /** Per instance vertex: the copied vertices in its neighbourhood with their remaining depth. */
    private Map<DTNode<PathStringMapLabel, PathStringMapLabel>, List<Pair<DTNode<PathStringMapLabel, PathStringMapLabel>, Integer>>> instanceVertexIndexMap;
    /** Per instance vertex: the copied edges in its neighbourhood with their remaining depth. */
    private Map<DTNode<PathStringMapLabel, PathStringMapLabel>, List<Pair<DTLink<PathStringMapLabel, PathStringMapLabel>, Integer>>> instanceEdgeIndexMap;
    /** Internal copy of the explored part of the input graph, labelled with path maps. */
    private DTGraph<PathStringMapLabel, PathStringMapLabel> rdfGraph;
    /** Copied instance vertices, in the order of the input instance list. */
    private List<DTNode<PathStringMapLabel, PathStringMapLabel>> instanceVertices;
    private final int depth;
    private final int pathLength;
    private boolean normalize;
    private long compTime;
    /** Path string -> feature index, assigned in first-seen order. */
    private Map<String, Integer> pathDict;
    /** Original vertex/edge label -> integer id used to build {@code "_" + id} labels. */
    private Map<String, Integer> labelDict;

    /**
     * Vertex/edge label that stores, per depth, the path strings currently ending here.
     *
     * <p>New paths are staged by {@link #addPaths} and only replace the current ones on
     * {@link #setNewPaths()}. One propagation round can therefore read the previous round's
     * paths everywhere while writing the next round's.
     */
    private static final class PathStringMapLabel {
        private final String label;
        private final Map<Integer, List<String>> pathsMap = new HashMap<>();
        private final Map<Integer, List<String>> newPathsMap = new HashMap<>();

        PathStringMapLabel(String label) {
            this.label = label;
        }

        /** (Re)initialises {@code depth}: current paths become {@code [label]}, staged paths empty. */
        void initDepth(int depth) {
            List<String> initial = new ArrayList<>();
            initial.add(label);
            pathsMap.put(depth, initial);
            newPathsMap.put(depth, new ArrayList<String>());
        }

        /** Returns the live depth -> current-paths map (not a copy). */
        Map<Integer, List<String>> getPathsMap() {
            return pathsMap;
        }

        /** Stages {@code label + suffix} at {@code depth} for every suffix in {@code suffixes}. */
        void addPaths(List<String> suffixes, int depth) {
            List<String> staged = newPathsMap.get(depth);
            for (String suffix : suffixes) {
                staged.add(label + suffix);
            }
        }

        /** Replaces the current paths of every depth with the staged ones and clears the stage. */
        void setNewPaths() {
            for (Map.Entry<Integer, List<String>> entry : pathsMap.entrySet()) {
                List<String> staged = newPathsMap.get(entry.getKey());
                List<String> current = entry.getValue();
                current.clear();
                current.addAll(staged);
                staged.clear();
            }
        }
    }

    /**
     * @param pathLength number of path-extension rounds performed after the initial state
     * @param depth      number of outgoing hops explored around each instance vertex
     * @param normalize  whether feature vectors are passed through {@link KernelUtils#normalize}
     */
    public DTGraphTreeWalkCountKernelMkII(int pathLength, int depth, boolean normalize) {
        this.normalize = normalize;
        this.depth = depth;
        this.pathLength = pathLength;
    }

    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /**
     * Returns the time in milliseconds spent in the last feature-vector computation, measured after
     * graph initialisation. {@link #compute} adds the kernel-matrix time on top.
     */
    @Override
    public long getComputationTime() {
        return compTime;
    }

    /**
     * Computes one sparse feature vector per instance of {@code data}. Entry {@code j} counts
     * occurrences of the path string with index {@code j} in {@link #pathDict}, summed over the
     * initial state and all {@code pathLength} extension rounds.
     */
    @Override
    public SparseVector[] computeFeatureVectors(SingleDTGraph data) {
        instanceVertices = new ArrayList<>();
        instanceVertexIndexMap = new HashMap<>();
        instanceEdgeIndexMap = new HashMap<>();
        pathDict = new HashMap<>();
        labelDict = new HashMap<>();
        init(data.getGraph(), data.getInstances());

        SparseVector[] featureVectors = new SparseVector[data.numInstances()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }

        long start = System.currentTimeMillis();

        // Round 0: every vertex/edge only carries its own label.
        indexAllPaths();
        computeFVs(instanceVertices, featureVectors, pathDict.size() - 1);

        for (int round = 0; round < pathLength; round++) {
            extendPaths();
            computeFVs(instanceVertices, featureVectors, pathDict.size() - 1);
        }

        compTime = System.currentTimeMillis() - start;

        if (normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /** Computes the kernel matrix as the pairwise products of {@link #computeFeatureVectors}. */
    @Override
    public double[][] compute(SingleDTGraph data) {
        SparseVector[] featureVectors = computeFeatureVectors(data);
        int n = data.getInstances().size();
        double[][] kernel = KernelUtils.initMatrix(n, n);

        long start = System.currentTimeMillis();
        kernel = KernelUtils.computeKernelMatrix(featureVectors, kernel);
        compTime += System.currentTimeMillis() - start;
        return kernel;
    }

    /**
     * Copies the {@code depth}-hop outgoing neighbourhood of every instance into {@link #rdfGraph}
     * and records, per instance, each copied vertex/edge with its remaining depth.
     *
     * <p>Vertices and edges shared between neighbourhoods are copied once (via the old->new maps)
     * and get one path slot per depth at which they are reached. A vertex reached several times at
     * the same level is recorded (and expanded) several times.
     */
    private void init(DTGraph<String, String> graph, List<DTNode<String, String>> instances) {
        Map<DTNode<String, String>, DTNode<PathStringMapLabel, PathStringMapLabel>> nodeMap = new HashMap<>();
        Map<DTLink<String, String>, DTLink<PathStringMapLabel, PathStringMapLabel>> linkMap = new HashMap<>();
        rdfGraph = new LightDTGraph<>();

        for (DTNode<String, String> oldRoot : instances) {
            List<Pair<DTNode<PathStringMapLabel, PathStringMapLabel>, Integer>> vertexIndex = new ArrayList<>();
            List<Pair<DTLink<PathStringMapLabel, PathStringMapLabel>, Integer>> edgeIndex = new ArrayList<>();

            DTNode<PathStringMapLabel, PathStringMapLabel> root = nodeMap.get(oldRoot);
            if (root == null) {
                root = rdfGraph.add(newLabel(oldRoot.label()));
                nodeMap.put(oldRoot, root);
            }
            root.label().initDepth(depth);
            instanceVertices.add(root);
            instanceVertexIndexMap.put(root, vertexIndex);
            instanceEdgeIndexMap.put(root, edgeIndex);
            vertexIndex.add(new Pair<>(root, depth));

            List<DTNode<String, String>> frontier = new ArrayList<>();
            frontier.add(oldRoot);
            for (int i = depth - 1; i >= 0; i--) {
                List<DTNode<String, String>> nextFrontier = new ArrayList<>();
                for (DTNode<String, String> oldSource : frontier) {
                    // Every frontier vertex was copied when it was first reached.
                    DTNode<PathStringMapLabel, PathStringMapLabel> source = nodeMap.get(oldSource);
                    for (DTLink<String, String> oldLink : oldSource.linksOut()) {
                        DTNode<String, String> oldTarget = oldLink.to();

                        DTNode<PathStringMapLabel, PathStringMapLabel> target = nodeMap.get(oldTarget);
                        if (target == null) {
                            target = rdfGraph.add(newLabel(oldTarget.label()));
                            nodeMap.put(oldTarget, target);
                        }
                        target.label().initDepth(i);
                        vertexIndex.add(new Pair<>(target, i));

                        DTLink<PathStringMapLabel, PathStringMapLabel> link = linkMap.get(oldLink);
                        if (link == null) {
                            link = source.connect(target, newLabel(oldLink.tag()));
                            linkMap.put(oldLink, link);
                        }
                        link.tag().initDepth(i);
                        edgeIndex.add(new Pair<>(link, i));

                        if (i > 0) {
                            nextFrontier.add(oldTarget);
                        }
                    }
                }
                frontier = nextFrontier;
            }
        }
    }

    /** Creates a fresh path label {@code "_" + id}, assigning a new id to unseen labels. */
    private PathStringMapLabel newLabel(String original) {
        Integer id = labelDict.get(original);
        if (id == null) {
            id = labelDict.size();
            labelDict.put(original, id);
        }
        return new PathStringMapLabel("_" + id);
    }

    /** Assigns a new {@link #pathDict} index to every path on {@code label} not yet indexed. */
    private void indexPaths(PathStringMapLabel label) {
        for (List<String> paths : label.getPathsMap().values()) {
            for (String path : paths) {
                if (!pathDict.containsKey(path)) {
                    pathDict.put(path, pathDict.size());
                }
            }
        }
    }

    /** Indexes the current paths of all vertices, then of all edges, of {@link #rdfGraph}. */
    private void indexAllPaths() {
        for (DTNode<PathStringMapLabel, PathStringMapLabel> node : rdfGraph.nodes()) {
            indexPaths(node.label());
        }
        for (DTLink<PathStringMapLabel, PathStringMapLabel> link : rdfGraph.links()) {
            indexPaths(link.tag());
        }
    }

    /**
     * Performs one extension round.
     *
     * <ul>
     *   <li>A vertex at depth {@code d > 0} gets its label prepended to the depth-{@code (d-1)}
     *       paths of each out-edge.</li>
     *   <li>An edge at depth {@code d} gets its label prepended to its target's depth-{@code d}
     *       paths.</li>
     * </ul>
     *
     * <p>All reads use the previous round's paths. The staged paths are committed and indexed
     * only after both stages.
     */
    private void extendPaths() {
        for (DTNode<PathStringMapLabel, PathStringMapLabel> node : rdfGraph.nodes()) {
            PathStringMapLabel nodeLabel = node.label();
            for (DTLink<PathStringMapLabel, PathStringMapLabel> link : node.linksOut()) {
                // init() gives every out-edge of a depth-d (d > 0) vertex a depth-(d-1) slot.
                Map<Integer, List<String>> linkPaths = link.tag().getPathsMap();
                for (int d : nodeLabel.getPathsMap().keySet()) {
                    if (d > 0) {
                        nodeLabel.addPaths(linkPaths.get(d - 1), d);
                    }
                }
            }
        }
        for (DTLink<PathStringMapLabel, PathStringMapLabel> link : rdfGraph.links()) {
            PathStringMapLabel tag = link.tag();
            Map<Integer, List<String>> targetPaths = link.to().label().getPathsMap();
            for (int d : tag.getPathsMap().keySet()) {
                tag.addPaths(targetPaths.get(d), d);
            }
        }
        for (DTNode<PathStringMapLabel, PathStringMapLabel> node : rdfGraph.nodes()) {
            node.label().setNewPaths();
            indexPaths(node.label());
        }
        for (DTLink<PathStringMapLabel, PathStringMapLabel> link : rdfGraph.links()) {
            link.tag().setNewPaths();
            indexPaths(link.tag());
        }
    }

    /**
     * Adds the current path counts of each instance's recorded vertices/edges (at their recorded
     * depth) to that instance's feature vector.
     *
     * @param instances      copied instance vertices, aligned with {@code featureVectors}
     * @param featureVectors accumulators, one per instance
     * @param lastIndex      highest feature index currently in use
     */
    private void computeFVs(List<DTNode<PathStringMapLabel, PathStringMapLabel>> instances, SparseVector[] featureVectors, int lastIndex) {
        for (int i = 0; i < instances.size(); i++) {
            SparseVector fv = featureVectors[i];
            fv.setLastIndex(lastIndex);
            DTNode<PathStringMapLabel, PathStringMapLabel> instance = instances.get(i);
            for (Pair<DTNode<PathStringMapLabel, PathStringMapLabel>, Integer> entry : instanceVertexIndexMap.get(instance)) {
                countPaths(fv, entry.getFirst().label().getPathsMap().get(entry.getSecond()));
            }
            for (Pair<DTLink<PathStringMapLabel, PathStringMapLabel>, Integer> entry : instanceEdgeIndexMap.get(instance)) {
                countPaths(fv, entry.getFirst().tag().getPathsMap().get(entry.getSecond()));
            }
        }
    }

    /** Increments {@code fv} by one at the {@link #pathDict} index of every path in {@code paths}. */
    private void countPaths(SparseVector fv, List<String> paths) {
        for (String path : paths) {
            int index = pathDict.get(path);
            fv.setValue(index, fv.getValue(index) + 1.0d);
        }
    }
}
