package org.data2semantics.mustard.kernels.graphkernels.graphlist;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.GraphList;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/graphlist/WalkCountKernelMkII.class */
/**
 * Walk-count graph kernel over graphs with string node labels and string link tags.
 *
 * <p>Every node label and link tag is first remapped to a compact token {@code "_<id>"}. Each node
 * and each link then starts with a single walk: its own token. In each of {@code depth}
 * iterations:
 * <ul>
 *   <li>every node's walks become its token prepended to the walks of each of its outgoing
 *       links;</li>
 *   <li>every link's walks become its token prepended to the walks of its target node.</li>
 * </ul>
 * Both updates read the previous iteration's walks (they are staged first and committed
 * afterwards). The feature vector of a graph counts how often each walk string occurs across
 * all iterations, including iteration 0.
 *
 * <p>Because every token starts with {@code '_'} followed only by digits, a concatenated walk
 * string can be split back into its tokens unambiguously.
 *
 * <p>NOTE(review): node labels and link tags share one id dictionary. A node label and a link tag
 * with identical text therefore map to the same token, and their walks share features. Confirm
 * that this is intended.
 */
public class WalkCountKernelMkII implements GraphKernel<GraphList<DTGraph<String, String>>>, FeatureVectorKernel<GraphList<DTGraph<String, String>>>, ComputationTimeTracker {
    /** Number of walk-extension iterations applied to every graph. */
    private final int depth;
    /** Whether feature vectors are passed through {@link KernelUtils#normalize} before returning. */
    protected boolean normalize;
    /** Milliseconds spent computing the last feature vectors; excludes copying the input graphs. */
    private long compTime;
    /** Walk string to feature index; rebuilt on every {@link #computeFeatureVectors} call. */
    private Map<String, Integer> pathDict;
    /** Original node label / link tag to integer id; rebuilt on every {@link #computeFeatureVectors} call. */
    private Map<String, Integer> labelDict;

    /**
     * Mutable walk state attached to a single node or link of a copied graph.
     *
     * <p>Extensions are staged in {@code newPaths} by {@link #addPaths} and only become visible
     * through {@link #getPaths} after {@link #setNewPaths}. This lets all elements of a graph be
     * extended from their neighbours' previous walks.
     */
    private static class PathStringLabel {
        /** The token of this node/link, prepended to every staged extension. */
        private final String label;
        /** Walks committed in the most recent iteration (initially just {@code label}). */
        private final List<String> paths = new ArrayList<String>();
        /** Walks staged for the next iteration. */
        private final List<String> newPaths = new ArrayList<String>();

        /**
         * @param label the token for this node/link; it also becomes the single initial walk
         */
        public PathStringLabel(String label) {
            this.label = label;
            this.paths.add(label);
        }

        /** @return the committed walks (live list, not a copy) */
        public List<String> getPaths() {
            return paths;
        }

        /**
         * Stages one new walk per suffix: this element's token followed by the suffix.
         *
         * @param suffixes walks of a neighbouring element to extend
         */
        public void addPaths(List<String> suffixes) {
            for (String suffix : suffixes) {
                newPaths.add(label + suffix);
            }
        }

        /** Replaces the committed walks with the staged ones and clears the staging list. */
        public void setNewPaths() {
            paths.clear();
            paths.addAll(newPaths);
            newPaths.clear();
        }
    }

    /**
     * @param depth number of walk-extension iterations (values {@code <= 0} count only single labels)
     * @param normalize whether to normalize the feature vectors via {@link KernelUtils#normalize}
     */
    public WalkCountKernelMkII(int depth, boolean normalize) {
        this.normalize = normalize;
        this.depth = depth;
    }

    /**
     * Creates a kernel with normalization enabled.
     *
     * @param depth number of walk-extension iterations
     */
    public WalkCountKernelMkII(int depth) {
        this(depth, true);
    }

    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** @return milliseconds spent in the last {@link #computeFeatureVectors} call, excluding graph copying */
    @Override
    public long getComputationTime() {
        return compTime;
    }

    /**
     * Computes one walk-count feature vector per input graph.
     *
     * <p>Feature indices are assigned in first-seen order and are shared by all graphs of this
     * call. Every vector's last index is set to the final dictionary size minus one.
     *
     * <p>Each element stores all of its walks of the current length. On densely connected graphs,
     * the number of stored walks can grow very quickly (up to exponentially) with {@code depth}.
     *
     * @param graphList graphs to featurize
     * @return one feature vector per graph, in input order, normalized if {@code normalize} is set
     */
    @Override
    public SparseVector[] computeFeatureVectors(GraphList<DTGraph<String, String>> graphList) {
        pathDict = new HashMap<String, Integer>();
        labelDict = new HashMap<String, Integer>();
        List<DTGraph<PathStringLabel, PathStringLabel>> graphs = copyGraphs(graphList.getGraphs());

        SparseVector[] featureVectors = new SparseVector[graphs.size()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }

        long start = System.currentTimeMillis();
        for (int i = 0; i < featureVectors.length; i++) {
            DTGraph<PathStringLabel, PathStringLabel> graph = graphs.get(i);
            SparseVector fv = featureVectors[i];

            // Iteration 0: every node label and link tag is a walk on its own.
            for (DTNode<PathStringLabel, PathStringLabel> node : graph.nodes()) {
                countPaths(node.label().getPaths(), fv);
            }
            for (DTLink<PathStringLabel, PathStringLabel> link : graph.links()) {
                countPaths(link.tag().getPaths(), fv);
            }

            for (int d = 0; d < depth; d++) {
                // Stage all extensions from the previous iteration's walks before committing any,
                // so nodes and links are updated synchronously.
                for (DTNode<PathStringLabel, PathStringLabel> node : graph.nodes()) {
                    for (DTLink<PathStringLabel, PathStringLabel> out : node.linksOut()) {
                        node.label().addPaths(out.tag().getPaths());
                    }
                }
                for (DTLink<PathStringLabel, PathStringLabel> link : graph.links()) {
                    link.tag().addPaths(link.to().label().getPaths());
                }

                // Commit the staged walks and count them.
                for (DTNode<PathStringLabel, PathStringLabel> node : graph.nodes()) {
                    node.label().setNewPaths();
                    countPaths(node.label().getPaths(), fv);
                }
                for (DTLink<PathStringLabel, PathStringLabel> link : graph.links()) {
                    link.tag().setNewPaths();
                    countPaths(link.tag().getPaths(), fv);
                }
            }
        }

        for (SparseVector fv : featureVectors) {
            fv.setLastIndex(pathDict.size() - 1);
        }
        compTime = System.currentTimeMillis() - start;

        if (normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the kernel matrix for the given graphs.
     *
     * <p>It allocates an n-by-n matrix via {@link KernelUtils#initMatrix}. It then fills it with
     * {@link KernelUtils#computeKernelMatrix} from this kernel's feature vectors (presumably
     * pairwise inner products; see KernelUtils).
     *
     * @param graphList graphs to compare
     * @return an n-by-n matrix, where n is the number of graphs
     */
    @Override
    public double[][] compute(GraphList<DTGraph<String, String>> graphList) {
        int size = graphList.getGraphs().size();
        double[][] matrix = KernelUtils.initMatrix(size, size);
        KernelUtils.computeKernelMatrix(computeFeatureVectors(graphList), matrix);
        return matrix;
    }

    /**
     * Adds one occurrence of each walk to {@code fv}, assigning new feature indices as needed.
     *
     * @param paths walk strings to count (duplicates are counted repeatedly)
     * @param fv feature vector of the current graph
     */
    private void countPaths(List<String> paths, SparseVector fv) {
        for (String path : paths) {
            Integer index = pathDict.get(path);
            if (index == null) {
                index = pathDict.size();
                pathDict.put(path, index);
            }
            fv.setValue(index, fv.getValue(index) + 1.0);
        }
    }

    /**
     * Returns the compact token {@code "_<id>"} for an original label/tag, registering it in
     * {@code labelDict} on first sight.
     *
     * @param label original node label or link tag
     * @return its token
     */
    private String labelId(String label) {
        Integer id = labelDict.get(label);
        if (id == null) {
            id = labelDict.size();
            labelDict.put(label, id);
        }
        return "_" + id;
    }

    /**
     * Copies the input graphs into graphs whose node labels and link tags are fresh
     * {@link PathStringLabel}s holding the remapped tokens.
     *
     * <p>Links are reconnected by node index. This assumes the copy's nodes receive the same
     * indices as the originals, because they are added in {@code nodes()} order. TODO confirm
     * that {@code nodes()} iterates in index order.
     *
     * @param graphs input graphs
     * @return the copies, in input order
     */
    private List<DTGraph<PathStringLabel, PathStringLabel>> copyGraphs(List<DTGraph<String, String>> graphs) {
        List<DTGraph<PathStringLabel, PathStringLabel>> copies =
                new ArrayList<DTGraph<PathStringLabel, PathStringLabel>>(graphs.size());
        for (DTGraph<String, String> graph : graphs) {
            LightDTGraph<PathStringLabel, PathStringLabel> copy =
                    new LightDTGraph<PathStringLabel, PathStringLabel>();
            for (DTNode<String, String> node : graph.nodes()) {
                copy.add(new PathStringLabel(labelId(node.label())));
            }
            for (DTLink<String, String> link : graph.links()) {
                DTNode<PathStringLabel, PathStringLabel> from = copy.nodes().get(link.from().index());
                DTNode<PathStringLabel, PathStringLabel> to = copy.nodes().get(link.to().index());
                from.connect(to, new PathStringLabel(labelId(link.tag())));
            }
            copies.add(copy);
        }
        return copies;
    }
}
