package org.data2semantics.mustard.kernels.graphkernels.singledtgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.FeatureInspector;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.weisfeilerlehman.StringLabel;
import org.data2semantics.mustard.weisfeilerlehman.WLUtils;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanDTGraphIterator;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/singledtgraph/DTGraphWLSubTreeIDEQKernel.class */
/**
 * Weisfeiler-Lehman subtree kernel computed on the outgoing-link neighbourhoods of instance
 * vertices in a single {@link DTGraph}.
 *
 * <p>For every instance, the vertices and links reachable within {@code depth} hops along
 * outgoing links are copied into one shared {@link StringLabel}-labelled graph. A vertex or link
 * of the input graph reached from several instances is copied only once. Per instance, the
 * remaining depth at which each copied vertex/link was recorded (root = {@code depth}, one hop
 * away = {@code depth - 1}, ..., farthest = 0) is stored. {@link #computeFVs} uses it to decide
 * which labels contribute to that instance's feature vector in each WL iteration.
 */
public class DTGraphWLSubTreeIDEQKernel implements GraphKernel<SingleDTGraph>, FeatureVectorKernel<SingleDTGraph>, ComputationTimeTracker, FeatureInspector {
    /** Per instance root copy: copied vertex -> remaining depth recorded for it during {@link #init}. */
    private Map<DTNode<StringLabel, StringLabel>, Map<DTNode<StringLabel, StringLabel>, Integer>> instanceVertexIndexMap;
    /** Per instance root copy: copied link -> remaining depth recorded for it during {@link #init}. */
    private Map<DTNode<StringLabel, StringLabel>, Map<DTLink<StringLabel, StringLabel>, Integer>> instanceEdgeIndexMap;
    /** Per instance root copy: copied vertex -> whether it no longer contributes to the feature vector. */
    private Map<DTNode<StringLabel, StringLabel>, Map<DTNode<StringLabel, StringLabel>, Boolean>> instanceVertexIgnoreMap;
    /** Per instance root copy: copied link -> whether it no longer contributes to the feature vector. */
    private Map<DTNode<StringLabel, StringLabel>, Map<DTLink<StringLabel, StringLabel>, Boolean>> instanceEdgeIgnoreMap;
    /** Shared graph holding the copies of all instance neighbourhoods; relabelled in place by WL. */
    private DTGraph<StringLabel, StringLabel> rdfGraph;
    /** Copies of the instance vertices, in the same order as the input instance list. */
    private List<DTNode<StringLabel, StringLabel>> instanceVertices;
    private final int depth;
    private final int iterations;
    private boolean normalize;
    private final boolean reverse;
    private final boolean iterationWeighting;
    private final boolean noDuplicateNBH;
    private final boolean noSubGraphs;
    /** Milliseconds spent in the last WL run (plus kernel-matrix computation when {@link #compute} is used). */
    private long compTime;
    /** Inverse of the WL label dictionary from the last {@link #computeFeatureVectors} call. */
    private Map<String, String> dict;

    /**
     * Creates the kernel.
     *
     * @param iterations number of WL relabelling iterations after the initial labelling
     * @param depth number of outgoing-link hops explored from each instance vertex
     * @param reverse passed to the WL iterator. It also affects {@link #init}: when {@code true},
     *        the first depth recorded for a vertex/link of an instance is kept; when {@code false},
     *        later encounters (with smaller remaining depth) overwrite it
     * @param iterationWeighting when {@code true}, counts of iteration {@code k} are weighted by
     *        {@code sqrt((k + 1) / (iterations + 1))}; otherwise every weight is 1
     * @param noDuplicateNBH passed to the WL iterator. Labels whose {@code isSameAsPrev()} is
     *        {@code true} are then not counted
     * @param noSubGraphs when {@code true}, vertices/links are never excluded from counting based
     *        on their recorded depth
     * @param normalize whether the resulting feature vectors are normalized
     */
    public DTGraphWLSubTreeIDEQKernel(int iterations, int depth, boolean reverse, boolean iterationWeighting, boolean noDuplicateNBH, boolean noSubGraphs, boolean normalize) {
        this.reverse = reverse;
        this.iterationWeighting = iterationWeighting;
        this.noDuplicateNBH = noDuplicateNBH;
        this.noSubGraphs = noSubGraphs;
        this.normalize = normalize;
        this.depth = depth;
        this.iterations = iterations;
    }

    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    @Override
    public long getComputationTime() {
        return compTime;
    }

    /**
     * Computes one sparse WL feature vector per instance of {@code data}.
     *
     * <p>Side effects: rebuilds the shared graph and per-instance maps, updates
     * {@link #getComputationTime()}, and replaces the dictionary used by
     * {@link #getFeatureDescriptions}.
     *
     * @param data graph plus instance vertices
     * @return feature vectors in instance order, normalized if {@code normalize} is set
     */
    @Override
    public SparseVector[] computeFeatureVectors(SingleDTGraph data) {
        SparseVector[] featureVectors = new SparseVector[data.getInstances().size()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }
        init(data.getInstances());

        WeisfeilerLehmanDTGraphIterator wl = new WeisfeilerLehmanDTGraphIterator(reverse, noDuplicateNBH);
        List<DTGraph<StringLabel, StringLabel>> graphs = new ArrayList<>();
        graphs.add(rdfGraph);

        // Only WL relabelling and counting is timed; building the graph in init() is not.
        long start = System.currentTimeMillis();
        wl.wlInitialize(graphs);
        computeFVs(instanceVertices, iterationWeight(0), featureVectors, wl.getLabelDict().size() - 1, 0);
        for (int it = 1; it <= iterations; it++) {
            wl.wlIterate(graphs);
            computeFVs(instanceVertices, iterationWeight(it), featureVectors, wl.getLabelDict().size() - 1, it);
        }
        compTime = System.currentTimeMillis() - start;

        // Invert the label dictionary so feature indices can be mapped back to descriptions.
        dict = new HashMap<>();
        for (Map.Entry<String, String> entry : wl.getLabelDict().entrySet()) {
            dict.put(entry.getValue(), entry.getKey());
        }

        if (normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the instance-by-instance kernel matrix from the feature vectors.
     * The matrix computation time is added to {@link #getComputationTime()}.
     */
    @Override
    public double[][] compute(SingleDTGraph data) {
        SparseVector[] featureVectors = computeFeatureVectors(data);
        int n = data.getInstances().size();
        double[][] matrix = KernelUtils.initMatrix(n, n);
        long start = System.currentTimeMillis();
        double[][] kernel = KernelUtils.computeKernelMatrix(featureVectors, matrix);
        compTime += System.currentTimeMillis() - start;
        return kernel;
    }

    /** Weight applied to the label counts of WL iteration {@code iteration} (0 = initial labelling). */
    private double iterationWeight(int iteration) {
        return iterationWeighting ? Math.sqrt((iteration + 1.0) / (iterations + 1)) : 1.0;
    }

    /**
     * Builds {@link #rdfGraph} and the per-instance index/ignore maps.
     *
     * <p>Each instance's neighbourhood is explored breadth-first along outgoing links for
     * {@code depth} hops. Input vertices and links already copied (for this or an earlier
     * instance) are reused, and their labels are reset to the input label.
     *
     * @param instances instance vertices of the input graph
     */
    private void init(List<DTNode<String, String>> instances) {
        // Input vertex/link -> its copy in rdfGraph, shared across all instances.
        Map<DTNode<String, String>, DTNode<StringLabel, StringLabel>> vertexCopies = new HashMap<>();
        Map<DTLink<String, String>, DTLink<StringLabel, StringLabel>> linkCopies = new HashMap<>();

        rdfGraph = new LightDTGraph<>();
        instanceVertices = new ArrayList<>();
        instanceVertexIndexMap = new HashMap<>();
        instanceEdgeIndexMap = new HashMap<>();
        instanceVertexIgnoreMap = new HashMap<>();
        instanceEdgeIgnoreMap = new HashMap<>();

        for (DTNode<String, String> instance : instances) {
            Map<DTNode<StringLabel, StringLabel>, Integer> vertexIndex = new HashMap<>();
            Map<DTLink<StringLabel, StringLabel>, Integer> linkIndex = new HashMap<>();
            Map<DTNode<StringLabel, StringLabel>, Boolean> vertexIgnore = new HashMap<>();
            Map<DTLink<StringLabel, StringLabel>, Boolean> linkIgnore = new HashMap<>();

            DTNode<StringLabel, StringLabel> root = vertexCopies.get(instance);
            if (root == null) {
                root = rdfGraph.add(new StringLabel());
                vertexCopies.put(instance, root);
            }
            resetLabel(root.label(), instance.label());

            instanceVertices.add(root);
            instanceVertexIndexMap.put(root, vertexIndex);
            instanceEdgeIndexMap.put(root, linkIndex);
            instanceVertexIgnoreMap.put(root, vertexIgnore);
            instanceEdgeIgnoreMap.put(root, linkIgnore);

            vertexIndex.put(root, depth);
            vertexIgnore.put(root, false);

            List<DTNode<String, String>> frontier = new ArrayList<>();
            frontier.add(instance);
            for (int d = depth - 1; d >= 0; d--) {
                List<DTNode<String, String>> nextFrontier = new ArrayList<>();
                for (DTNode<String, String> from : frontier) {
                    for (DTLink<String, String> link : from.linksOut()) {
                        DTNode<String, String> to = link.to();

                        DTNode<StringLabel, StringLabel> toCopy = vertexCopies.get(to);
                        if (toCopy == null) {
                            toCopy = rdfGraph.add(new StringLabel());
                            resetLabel(toCopy.label(), to.label());
                            vertexCopies.put(to, toCopy);
                            vertexIndex.put(toCopy, d);
                            vertexIgnore.put(toCopy, false);
                        } else {
                            // Without reverse, later (shallower remaining depth) encounters win.
                            if (!vertexIndex.containsKey(toCopy) || !reverse) {
                                vertexIndex.put(toCopy, d);
                                vertexIgnore.put(toCopy, false);
                            }
                            resetLabel(toCopy.label(), to.label());
                        }

                        DTLink<StringLabel, StringLabel> linkCopy = linkCopies.get(link);
                        if (linkCopy == null) {
                            linkCopy = vertexCopies.get(from).connect(toCopy, new StringLabel());
                            resetLabel(linkCopy.tag(), link.tag());
                            linkCopies.put(link, linkCopy);
                            linkIndex.put(linkCopy, d);
                            linkIgnore.put(linkCopy, false);
                        } else {
                            if (!linkIndex.containsKey(linkCopy) || !reverse) {
                                linkIndex.put(linkCopy, d);
                                linkIgnore.put(linkCopy, false);
                            }
                            resetLabel(linkCopy.tag(), link.tag());
                        }

                        // Vertices reached at remaining depth 0 are not expanded further.
                        if (d > 0) {
                            nextFrontier.add(to);
                        }
                    }
                }
                frontier = nextFrontier;
            }
        }
    }

    /** Replaces the contents of {@code label} with {@code text}. */
    private static void resetLabel(StringLabel label, String text) {
        label.clear();
        label.append(text);
    }

    /**
     * Adds the current WL label counts of every instance to its feature vector.
     *
     * <p>Unless {@code noSubGraphs} is set, a vertex recorded at remaining depth {@code r} is
     * permanently excluded from iteration {@code 2r + 1} on, and a link from iteration
     * {@code 2r + 2} on.
     *
     * @param instances copied instance vertices, aligned with {@code featureVectors}
     * @param weight amount added per label occurrence
     * @param featureVectors vectors updated in place
     * @param lastIndex passed to {@code SparseVector#setLastIndex}; the caller supplies the
     *        current label-dictionary size minus one
     * @param iteration current WL iteration (0 = initial labelling)
     */
    private void computeFVs(List<DTNode<StringLabel, StringLabel>> instances, double weight, SparseVector[] featureVectors, int lastIndex, int iteration) {
        for (int i = 0; i < instances.size(); i++) {
            DTNode<StringLabel, StringLabel> instance = instances.get(i);
            SparseVector fv = featureVectors[i];
            fv.setLastIndex(lastIndex);

            Map<DTNode<StringLabel, StringLabel>, Boolean> vertexIgnore = instanceVertexIgnoreMap.get(instance);
            for (Map.Entry<DTNode<StringLabel, StringLabel>, Integer> entry : instanceVertexIndexMap.get(instance).entrySet()) {
                DTNode<StringLabel, StringLabel> vertex = entry.getKey();
                if (entry.getValue() * 2 + 1 == iteration && !noSubGraphs) {
                    vertexIgnore.put(vertex, true);
                }
                if ((!noDuplicateNBH || !vertex.label().isSameAsPrev()) && !vertexIgnore.get(vertex)) {
                    addCount(fv, vertex.label(), weight);
                }
            }

            Map<DTLink<StringLabel, StringLabel>, Boolean> linkIgnore = instanceEdgeIgnoreMap.get(instance);
            for (Map.Entry<DTLink<StringLabel, StringLabel>, Integer> entry : instanceEdgeIndexMap.get(instance).entrySet()) {
                DTLink<StringLabel, StringLabel> link = entry.getKey();
                if (entry.getValue() * 2 + 2 == iteration && !noSubGraphs) {
                    linkIgnore.put(link, true);
                }
                if ((!noDuplicateNBH || !link.tag().isSameAsPrev()) && !linkIgnore.get(link)) {
                    addCount(fv, link.tag(), weight);
                }
            }
        }
    }

    /** Adds {@code weight} at the feature index encoded (as a decimal string) in {@code label}. */
    private static void addCount(SparseVector fv, StringLabel label, double weight) {
        int index = Integer.parseInt(label.toString());
        fv.setValue(index, fv.getValue(index) + weight);
    }

    /**
     * Maps feature indices back to human-readable WL label descriptions.
     *
     * @throws IllegalStateException if {@link #computeFeatureVectors} has not been called yet
     */
    @Override
    public List<String> getFeatureDescriptions(List<Integer> indices) {
        if (dict == null) {
            throw new IllegalStateException("Should run computeFeatureVectors first");
        }
        List<String> descriptions = new ArrayList<>(indices.size());
        for (int index : indices) {
            descriptions.add(WLUtils.getFeatureDecription(dict, index));
        }
        return descriptions;
    }
}
