package org.data2semantics.mustard.kernels.graphkernels.singledtgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.nodes.DGraph;
import org.nodes.DNode;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDGraph;
import org.nodes.LightDTGraph;
import org.nodes.Node;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/singledtgraph/DTGraphIntersectionSubTreeKernel.class */
/**
 * Graph kernel that, for every pair of instance nodes, builds an "intersection tree" of the
 * neighbourhoods the two instances have in common (up to {@code depth} expansion steps) and
 * scores that tree with {@link #subTreeScore(DNode, double)}.
 *
 * <p>Before any tree is built the input graph is copied into a graph whose node labels and link
 * tags are replaced by small integers (as strings), see {@code toIntGraph}; the pair matching in
 * {@code getCommonChilds} relies on those labels being parseable as {@code int}.
 */
public class DTGraphIntersectionSubTreeKernel implements GraphKernel<SingleDTGraph>, ComputationTimeTracker {
    private int depth;
    private double discountFactor;
    private long compTime;
    protected boolean normalize;

    /**
     * Ordered pair of ints. In this class it holds (link tag id, target node label id) and is used
     * as element/key of sorted sets and maps.
     */
    public class Pair implements Comparable<Pair> {
        int first;
        int second;

        public Pair(int i, int i2) {
            this.first = i;
            this.second = i2;
        }

        public int getFirst() {
            return this.first;
        }

        public int getSecond() {
            return this.second;
        }

        /**
         * Typed overload kept for existing callers; delegates to {@link #equals(Object)}.
         *
         * @param pair the pair to compare with, may be {@code null}
         * @return {@code true} when both components are equal
         */
        public boolean equals(Pair pair) {
            return equals((Object) pair);
        }

        /**
         * Value equality on both components. Overrides {@link Object#equals(Object)} so that
         * equality is consistent with {@link #compareTo(Pair)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Pair)) {
                return false;
            }
            Pair other = (Pair) obj;
            return this.first == other.getFirst() && this.second == other.getSecond();
        }

        @Override
        public int hashCode() {
            return 31 * this.first + this.second;
        }

        /**
         * Orders by {@code first}, then by {@code second}. Uses explicit comparisons rather than
         * subtraction, which would overflow for operands of opposite sign and large magnitude.
         */
        @Override // java.lang.Comparable
        public int compareTo(Pair pair) {
            int byFirst = compareInts(this.first, pair.getFirst());
            return byFirst != 0 ? byFirst : compareInts(this.second, pair.getSecond());
        }

        @Override
        public String toString() {
            return "(" + this.first + "," + this.second + ")";
        }
    }

    /** A rooted tree stored as a directed graph of string-labelled nodes plus a reference to its root. */
    public class Tree {
        private DGraph<String> graph = new LightDGraph<String>();
        private DNode<String> root;

        public Tree() {
        }

        public DGraph<String> getGraph() {
            return this.graph;
        }

        public DNode<String> getRoot() {
            return this.root;
        }

        public void setRoot(DNode<String> dNode) {
            this.root = dNode;
        }
    }

    /**
     * Couples a graph vertex ({@code null} stands for the root marker) with a running counter.
     * Instances are used as map keys while expanding the intersection tree; since this class does
     * not override {@code equals}/{@code hashCode}, every instance is a distinct key, so the same
     * graph vertex can appear several times in one tree.
     */
    public class VertexTracker {
        DTNode<String, String> vertex;
        int count;

        public VertexTracker(DTNode<String, String> dTNode, int i) {
            this.vertex = dTNode;
            this.count = i;
        }

        public DTNode<String, String> getVertex() {
            return this.vertex;
        }

        public void setVertex(DTNode<String, String> dTNode) {
            this.vertex = dTNode;
        }

        public int getCount() {
            return this.count;
        }

        public void setCount(int i) {
            this.count = i;
        }
    }

    /** Creates a kernel with depth 2, discount factor 1.0 and normalization enabled. */
    public DTGraphIntersectionSubTreeKernel() {
        this(2, 1.0d, true);
    }

    /**
     * @param i number of expansion steps used when building an intersection tree
     * @param d discount factor applied per tree level in {@link #subTreeScore(DNode, double)}
     * @param z whether {@link #compute(SingleDTGraph)} normalizes the kernel matrix
     */
    public DTGraphIntersectionSubTreeKernel(int i, double d, boolean z) {
        this.normalize = z;
        this.depth = i;
        this.discountFactor = d;
    }

    @Override // org.data2semantics.mustard.kernels.Kernel
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override // org.data2semantics.mustard.kernels.Kernel
    public void setNormalize(boolean z) {
        this.normalize = z;
    }

    /** @return wall-clock milliseconds spent in the pairwise loop of the last {@code compute} call */
    @Override // org.data2semantics.mustard.kernels.ComputationTimeTracker
    public long getComputationTime() {
        return this.compTime;
    }

    /**
     * Computes the symmetric kernel matrix over all instance nodes of the data set.
     *
     * @param singleDTGraph graph plus the list of instance nodes to compare
     * @return the kernel matrix, normalized when {@code normalize} is set
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.GraphKernel
    public double[][] compute(SingleDTGraph singleDTGraph) {
        // toIntGraph overwrites the entries of the list it is given. Work on a private copy so the
        // caller's instance list is never altered (whether getInstances() exposes an internal list
        // is not visible from here, so do not rely on it).
        List<DTNode<String, String>> instances =
                new ArrayList<DTNode<String, String>>(singleDTGraph.getInstances());
        int size = instances.size();
        double[][] kernel = KernelUtils.initMatrix(size, size);
        DTGraph<String, String> intGraph = toIntGraph(singleDTGraph.getGraph(), instances);

        long start = System.currentTimeMillis();
        for (int row = 0; row < size; row++) {
            for (int col = row; col < size; col++) {
                Tree tree = computeIntersectionTree(intGraph, instances.get(row), instances.get(col));
                double score = subTreeScore(tree.getRoot(), this.discountFactor);
                kernel[row][col] = score;
                kernel[col][row] = score;
            }
        }
        this.compTime = System.currentTimeMillis() - start;

        return this.normalize ? KernelUtils.normalize(kernel) : kernel;
    }

    /** Three-way comparison of two ints without the overflow risk of subtraction. */
    private static int compareInts(int a, int b) {
        if (a < b) {
            return -1;
        }
        return a == b ? 0 : 1;
    }

    /**
     * Returns the id already assigned to {@code key}, or assigns the next free one
     * ({@code ids.size() + 1}, rendered as a decimal string) and returns that.
     */
    private static String idFor(Map<String, String> ids, String key) {
        String id = ids.get(key);
        if (id == null) {
            id = Integer.toString(ids.size() + 1);
            ids.put(key, id);
        }
        return id;
    }

    /**
     * Copies {@code dTGraph} into a new graph in which every node label and link tag is replaced
     * by an integer id (labels and tags share one id space). Each entry of {@code list} that is a
     * node of {@code dTGraph} is replaced in place by its counterpart in the new graph.
     *
     * <p>Links are re-created by node index, so this assumes nodes added to the new graph receive
     * the same indices as the nodes they were copied from.
     *
     * @param dTGraph source graph
     * @param list instance nodes; modified in place
     * @return the integer-labelled copy
     */
    private DTGraph<String, String> toIntGraph(DTGraph<String, String> dTGraph, List<DTNode<String, String>> list) {
        DTGraph<String, String> intGraph = new LightDTGraph<String, String>();
        Map<String, String> ids = new HashMap<String, String>();

        // A node may be listed as an instance more than once; remember every position so that all
        // of them get remapped (a single index per node would leave earlier duplicates untouched).
        Map<DTNode<String, String>, List<Integer>> positionsByInstance =
                new HashMap<DTNode<String, String>, List<Integer>>();
        for (int pos = 0; pos < list.size(); pos++) {
            DTNode<String, String> instance = list.get(pos);
            List<Integer> positions = positionsByInstance.get(instance);
            if (positions == null) {
                positions = new ArrayList<Integer>();
                positionsByInstance.put(instance, positions);
            }
            positions.add(Integer.valueOf(pos));
        }

        for (DTNode<String, String> node : dTGraph.nodes()) {
            DTNode<String, String> copy = intGraph.add(idFor(ids, node.label()));
            List<Integer> positions = positionsByInstance.get(node);
            if (positions != null) {
                for (Integer pos : positions) {
                    list.set(pos.intValue(), copy);
                }
            }
        }

        for (DTLink<String, String> link : dTGraph.links()) {
            String tagId = idFor(ids, link.tag());
            DTNode<String, String> from = intGraph.nodes().get(link.from().index());
            DTNode<String, String> to = intGraph.nodes().get(link.to().index());
            from.connect(to, tagId);
        }
        return intGraph;
    }

    /**
     * Builds the intersection tree for two instance nodes.
     *
     * <p>The root (label {@code "0"}) gets one child per entry of
     * {@link #getCommonChilds(DTGraph, DTNode, DTNode)}. On each further level every tree node
     * that tracks a graph vertex gets one child per outgoing link of that vertex; a link that
     * leads back to either instance yields a root-marker child (label {@code "0"}, no vertex),
     * which on the next level is again expanded with the common children.
     *
     * @param dTGraph the integer-labelled graph (only forwarded to {@code getCommonChilds})
     * @param dTNode first instance node
     * @param dTNode2 second instance node
     * @return the tree after {@code depth} expansion steps
     */
    private Tree computeIntersectionTree(DTGraph<String, String> dTGraph, DTNode<String, String> dTNode, DTNode<String, String> dTNode2) {
        Tree tree = new Tree();
        DGraph<String> treeGraph = tree.getGraph();
        List<DTNode<String, String>> commonChilds = getCommonChilds(dTGraph, dTNode, dTNode2);

        int nextId = 1;
        // Insertion-ordered so that the order in which tree links are created is reproducible.
        Map<VertexTracker, DNode<String>> front = new LinkedHashMap<VertexTracker, DNode<String>>();
        DNode<String> root = treeGraph.add("0");
        front.put(new VertexTracker(null, nextId++), root);
        tree.setRoot(root);

        for (int level = 0; level < this.depth; level++) {
            Map<VertexTracker, DNode<String>> nextFront = new LinkedHashMap<VertexTracker, DNode<String>>();
            for (Map.Entry<VertexTracker, DNode<String>> entry : front.entrySet()) {
                DTNode<String, String> vertex = entry.getKey().getVertex();

                // Graph vertices that become children of this tree node; null = root marker.
                List<DTNode<String, String>> childVertices;
                if (vertex == null) {
                    childVertices = commonChilds;
                } else {
                    childVertices = new ArrayList<DTNode<String, String>>();
                    for (DTLink<String, String> link : vertex.linksOut()) {
                        DTNode<String, String> target = link.to();
                        // equals() rather than ==: SOURCE does not show whether the graph hands
                        // out one canonical object per node, so do not depend on identity.
                        boolean backToInstance = target.equals(dTNode) || target.equals(dTNode2);
                        childVertices.add(backToInstance ? null : target);
                    }
                }

                // First add all child nodes to the tree graph, then link them to the parent.
                Map<VertexTracker, DNode<String>> expansion = new LinkedHashMap<VertexTracker, DNode<String>>();
                for (DTNode<String, String> childVertex : childVertices) {
                    String label = childVertex == null ? "0" : childVertex.label();
                    expansion.put(new VertexTracker(childVertex, nextId++), treeGraph.add(label));
                }
                DNode<String> parent = entry.getValue();
                for (DNode<String> child : expansion.values()) {
                    parent.connect(child);
                }
                nextFront.putAll(expansion);
            }
            front = nextFront;
        }
        return tree;
    }

    /**
     * Adds a (tag id, target label id) pair for every outgoing link of {@code node} to
     * {@code pairs} and records the link target under that pair in {@code targets}
     * (a later link with an equal pair overwrites the recorded target).
     */
    private void collectOutgoing(DTNode<String, String> node, Set<Pair> pairs, Map<Pair, DTNode<String, String>> targets) {
        for (DTLink<String, String> link : node.linksOut()) {
            Pair pair = new Pair(Integer.parseInt(link.tag()), Integer.parseInt(link.to().label()));
            pairs.add(pair);
            targets.put(pair, link.to());
        }
    }

    /**
     * Determines the children the two instance nodes have in common.
     *
     * <p>The result first contains one {@code null} for every tag with which {@code dTNode} links
     * to a node labelled like {@code dTNode2} while {@code dTNode2} links back, with the same tag,
     * to a node labelled like {@code dTNode}. It then contains, in pair order, the target node of
     * every (tag, target label) combination that occurs on the outgoing links of both nodes.
     *
     * <p>All labels and tags involved must be parseable as {@code int}.
     *
     * @param dTGraph unused
     * @param dTNode first instance node
     * @param dTNode2 second instance node
     * @return the common children; {@code null} entries denote the root marker
     * @throws NumberFormatException if a label or tag is not an integer string
     */
    private List<DTNode<String, String>> getCommonChilds(DTGraph<String, String> dTGraph, DTNode<String, String> dTNode, DTNode<String, String> dTNode2) {
        List<DTNode<String, String>> common = new ArrayList<DTNode<String, String>>();
        Set<Pair> pairsA = new TreeSet<Pair>();
        Set<Pair> pairsB = new TreeSet<Pair>();
        Map<Pair, DTNode<String, String>> targets = new TreeMap<Pair, DTNode<String, String>>();

        collectOutgoing(dTNode, pairsA, targets);
        collectOutgoing(dTNode2, pairsB, targets);

        // Loop-invariant, so parse once instead of on every iteration.
        int labelA = Integer.parseInt(dTNode.label());
        int labelB = Integer.parseInt(dTNode2.label());
        for (Pair pair : pairsA) {
            if (pair.getSecond() == labelB && pairsB.contains(new Pair(pair.getFirst(), labelA))) {
                common.add(null);
            }
        }

        pairsA.retainAll(pairsB);
        for (Pair shared : pairsA) {
            common.add(targets.get(shared));
        }
        return common;
    }

    /**
     * Recursively scores a tree: a leaf scores 1, an inner node scores
     * {@code 1 + d * (sum of its children's scores)}.
     *
     * @param dNode root of the (sub)tree to score
     * @param d discount factor applied to the children's total at every level
     * @return the score of the subtree rooted at {@code dNode}
     */
    protected double subTreeScore(DNode<String> dNode, double d) {
        if (dNode.out().isEmpty()) {
            return 1.0d;
        }
        double childTotal = 0.0d;
        for (DNode<String> child : dNode.out()) {
            childTotal += subTreeScore(child, d);
        }
        return 1.0d + (d * childTotal);
    }
}
