package org.data2semantics.mustard.kernels.graphkernels.singledtgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.weisfeilerlehman.ApproxStringLabel;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanApproxDTGraphIterator;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/singledtgraph/DTGraphWLSubTreeGeoProbApproxKernel.class */
/**
 * Approximate Weisfeiler-Lehman subtree kernel computed on a single directed, tagged graph.
 *
 * <p>For every instance node the subgraph reachable over at most {@code depth} outgoing links is
 * copied into an internal graph labelled with {@link ApproxStringLabel}s. That graph is relabelled
 * {@code iterations} times by a {@link WeisfeilerLehmanApproxDTGraphIterator}, and each label found
 * around an instance contributes a weight to that instance's feature vector. The weight is
 * {@code (1 - p)^k * p} with {@code p = 1 / (mean + 1)}, divided by
 * {@code depthDiffWeight^|d - pos/2|} for every depth slot {@code d} in {@code 0..depth}.
 *
 * <p>The computation is repeated for every combination of the {@code minFreqs},
 * {@code maxLabelCards} and {@code maxPrevNBHs} settings, and all runs accumulate into the same
 * feature vectors.
 *
 * <p>Instances keep mutable working state between calls; no synchronization is performed here.
 */
public class DTGraphWLSubTreeGeoProbApproxKernel implements GraphKernel<SingleDTGraph>, FeatureVectorKernel<SingleDTGraph>, ComputationTimeTracker {
    /** Working copy of the instance neighbourhoods, rebuilt by {@link #init}. */
    private DTGraph<ApproxStringLabel, ApproxStringLabel> rdfGraph;
    /** Nodes of {@link #rdfGraph} that correspond to the instances, in instance order. */
    private List<DTNode<ApproxStringLabel, ApproxStringLabel>> instanceVertices;
    private int depth;
    private int iterations;
    private boolean normalize;
    private int[] maxPrevNBHs;
    private int[] maxLabelCards;
    private int[] minFreqs;
    /** Success probability derived from {@link #mean}; set in {@link #computeFeatureVectors}. */
    private double p;
    private double mean;
    /** Cache of {@link #getProb} results, reset on every feature vector computation. */
    private Map<Integer, Double> probs;
    private double depthDiffWeight;
    /** Per label string, the number of distinct instance indices recorded on it. */
    private Map<String, Integer> labelFreq;
    private long compTime;

    /**
     * Creates the kernel.
     *
     * @param iterations number of Weisfeiler-Lehman relabelling rounds
     * @param depth number of outgoing hops followed from each instance node
     * @param mean value used to derive the weighting probability {@code p = 1 / (mean + 1)}
     * @param depthDiffWeight base of the penalty applied per unit of depth difference
     * @param maxPrevNBHs values passed to the iterator's {@code setMaxPrevNBH}, one run per value
     * @param maxLabelCards values passed to the iterator's {@code setMaxLabelCard}, one run per value
     * @param minFreqs values passed to the iterator's {@code setMinFreq}, one run per value
     * @param normalize whether the resulting feature vectors are normalized
     */
    public DTGraphWLSubTreeGeoProbApproxKernel(int iterations, int depth, double mean, double depthDiffWeight, int[] maxPrevNBHs, int[] maxLabelCards, int[] minFreqs, boolean normalize) {
        this.normalize = normalize;
        this.depth = depth;
        this.iterations = iterations;
        this.maxPrevNBHs = maxPrevNBHs;
        this.maxLabelCards = maxLabelCards;
        this.minFreqs = minFreqs;
        this.mean = mean;
        this.depthDiffWeight = depthDiffWeight;
    }

    /** Returns the label produced by {@link KernelUtils#createLabel} for this kernel. */
    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    /**
     * Sets whether feature vectors are normalized before being returned.
     *
     * @param normalize {@code true} to normalize
     */
    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** Returns the time in milliseconds measured during the most recent computation. */
    @Override
    public long getComputationTime() {
        return this.compTime;
    }

    /**
     * Computes one feature vector per instance of {@code data}.
     *
     * <p>Every combination of the configured iterator settings adds its contribution to the same
     * vectors. Timing of the graph construction and of the feature vector construction is printed
     * to standard output for each combination.
     *
     * @param data graph and instance nodes to compute the features for
     * @return the feature vectors, normalized when normalization is enabled
     */
    @Override
    public SparseVector[] computeFeatureVectors(SingleDTGraph data) {
        SparseVector[] featureVectors = new SparseVector[data.numInstances()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }
        this.probs = new HashMap<Integer, Double>();
        this.p = 1.0d / (this.mean + 1.0d);
        WeisfeilerLehmanApproxDTGraphIterator wl = new WeisfeilerLehmanApproxDTGraphIterator(true, 1, 1, 1);
        long start = System.currentTimeMillis();
        for (int minFreq : this.minFreqs) {
            for (int maxLabelCard : this.maxLabelCards) {
                for (int maxPrevNBH : this.maxPrevNBHs) {
                    // The working graph is rebuilt for every setting because relabelling mutates it.
                    long initStart = System.currentTimeMillis();
                    init(data.getGraph(), data.getInstances());
                    System.out.println("init comp: " + (System.currentTimeMillis() - initStart));

                    List<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphs = new ArrayList<DTGraph<ApproxStringLabel, ApproxStringLabel>>();
                    graphs.add(this.rdfGraph);
                    wl.wlInitialize(graphs);
                    wl.setMaxLabelCard(maxLabelCard);
                    wl.setMinFreq(minFreq);
                    wl.setMaxPrevNBH(maxPrevNBH);
                    for (int round = 0; round < this.iterations; round++) {
                        // Frequencies are recomputed each round since the labels change.
                        computeLabelFreqs(this.rdfGraph, this.instanceVertices);
                        wl.wlIterate(graphs, this.labelFreq);
                    }

                    long fvStart = System.currentTimeMillis();
                    computeFVs(this.rdfGraph, this.instanceVertices, featureVectors, wl.getLabelDict().size() - 1);
                    System.out.println("FV comp: " + (System.currentTimeMillis() - fvStart));
                }
            }
        }
        this.compTime = System.currentTimeMillis() - start;
        if (this.normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the kernel matrix over the instances of {@code data}.
     *
     * <p>The time spent building the matrix is added to the time recorded by
     * {@link #computeFeatureVectors}.
     *
     * @param data graph and instance nodes
     * @return square matrix with one row and column per instance
     */
    @Override
    public double[][] compute(SingleDTGraph data) {
        SparseVector[] featureVectors = computeFeatureVectors(data);
        int size = data.getInstances().size();
        double[][] kernel = KernelUtils.initMatrix(size, size);
        long start = System.currentTimeMillis();
        kernel = KernelUtils.computeKernelMatrix(featureVectors, kernel);
        this.compTime += System.currentTimeMillis() - start;
        return kernel;
    }

    /**
     * Rebuilds {@link #rdfGraph} and {@link #instanceVertices} from the given instances.
     *
     * <p>Starting at each instance, outgoing links are followed for {@code depth} hops. Every node
     * and link encountered is copied once into the working graph, and the index of each instance
     * that reaches it is recorded on its label.
     *
     * @param graph the source graph (not read; traversal starts from {@code instances})
     * @param instances instance nodes of the source graph, in instance order
     */
    private void init(DTGraph<String, String> graph, List<DTNode<String, String>> instances) {
        Map<DTNode<String, String>, DTNode<ApproxStringLabel, ApproxStringLabel>> nodeMap =
                new HashMap<DTNode<String, String>, DTNode<ApproxStringLabel, ApproxStringLabel>>();
        Map<DTLink<String, String>, DTLink<ApproxStringLabel, ApproxStringLabel>> linkMap =
                new HashMap<DTLink<String, String>, DTLink<ApproxStringLabel, ApproxStringLabel>>();
        this.rdfGraph = new LightDTGraph<ApproxStringLabel, ApproxStringLabel>();
        this.instanceVertices = new ArrayList<DTNode<ApproxStringLabel, ApproxStringLabel>>();

        int instanceIndex = 0;
        for (DTNode<String, String> instance : instances) {
            DTNode<ApproxStringLabel, ApproxStringLabel> root = nodeMap.get(instance);
            if (root == null) {
                root = this.rdfGraph.add(new ApproxStringLabel());
                nodeMap.put(instance, root);
            }
            // The instance node's label is reset even when the node was already copied
            // as a neighbour of an earlier instance.
            root.label().clear();
            root.label().append(instance.label());
            root.label().addInstanceIndex(instanceIndex);
            this.instanceVertices.add(root);

            List<DTNode<String, String>> frontier = new ArrayList<DTNode<String, String>>();
            frontier.add(instance);
            for (int remaining = this.depth - 1; remaining >= 0; remaining--) {
                List<DTNode<String, String>> nextFrontier = new ArrayList<DTNode<String, String>>();
                for (DTNode<String, String> source : frontier) {
                    for (DTLink<String, String> link : source.linksOut()) {
                        DTNode<String, String> target = link.to();

                        DTNode<ApproxStringLabel, ApproxStringLabel> newTarget = nodeMap.get(target);
                        if (newTarget == null) {
                            newTarget = this.rdfGraph.add(new ApproxStringLabel());
                            newTarget.label().clear();
                            newTarget.label().append(target.label());
                            nodeMap.put(target, newTarget);
                        }
                        newTarget.label().addInstanceIndex(instanceIndex);

                        DTLink<ApproxStringLabel, ApproxStringLabel> newLink = linkMap.get(link);
                        if (newLink == null) {
                            // Every frontier node was mapped when it was first reached.
                            newLink = nodeMap.get(source).connect(newTarget, new ApproxStringLabel());
                            newLink.tag().clear();
                            newLink.tag().append(link.tag());
                            linkMap.put(link, newLink);
                        }
                        newLink.tag().addInstanceIndex(instanceIndex);

                        // Nodes reached on the last hop are not expanded further.
                        if (remaining > 0) {
                            nextFrontier.add(target);
                        }
                    }
                }
                frontier = nextFrontier;
            }
            instanceIndex++;
        }
    }

    /**
     * Recomputes {@link #labelFreq}: for each distinct label string on a node or link of
     * {@code graph}, the number of distinct instance indices recorded on labels with that string.
     *
     * @param graph the working graph whose labels are counted
     * @param instances instance nodes of the working graph (not read)
     */
    private void computeLabelFreqs(DTGraph<ApproxStringLabel, ApproxStringLabel> graph, List<DTNode<ApproxStringLabel, ApproxStringLabel>> instances) {
        Map<String, Set<Integer>> instancesPerLabel = new HashMap<String, Set<Integer>>();
        for (DTNode<ApproxStringLabel, ApproxStringLabel> node : graph.nodes()) {
            recordLabel(instancesPerLabel, node.label());
        }
        for (DTLink<ApproxStringLabel, ApproxStringLabel> link : graph.links()) {
            recordLabel(instancesPerLabel, link.tag());
        }

        Map<String, Integer> frequencies = new HashMap<String, Integer>();
        for (Map.Entry<String, Set<Integer>> entry : instancesPerLabel.entrySet()) {
            frequencies.put(entry.getKey(), Integer.valueOf(entry.getValue().size()));
        }
        this.labelFreq = frequencies;
    }

    /** Merges the instance indices of {@code label} into the set kept for its string form. */
    private static void recordLabel(Map<String, Set<Integer>> instancesPerLabel, ApproxStringLabel label) {
        String key = label.toString();
        Set<Integer> indices = instancesPerLabel.get(key);
        if (indices == null) {
            indices = new HashSet<Integer>();
            instancesPerLabel.put(key, indices);
        }
        indices.addAll(label.getInstanceIndexSet());
    }

    /**
     * Adds the contribution of the current labelling to every instance's feature vector.
     *
     * <p>From each instance node, outgoing links are followed breadth-first for {@code depth} hops.
     * Each link and each node is counted at most once per instance. The instance node itself has
     * position 0, links on hop {@code h} have position {@code 2h - 1} and nodes reached on hop
     * {@code h} have position {@code 2h}.
     *
     * @param graph the working graph (not read; traversal starts from {@code instances})
     * @param instances instance nodes of the working graph, aligned with {@code featureVectors}
     * @param featureVectors vectors to update, one per instance
     * @param lastLabelIndex highest label index, used to size the vectors
     */
    private void computeFVs(DTGraph<ApproxStringLabel, ApproxStringLabel> graph, List<DTNode<ApproxStringLabel, ApproxStringLabel>> instances, SparseVector[] featureVectors, int lastLabelIndex) {
        for (int i = 0; i < instances.size(); i++) {
            SparseVector featureVector = featureVectors[i];
            DTNode<ApproxStringLabel, ApproxStringLabel> root = instances.get(i);
            Set<DTNode<ApproxStringLabel, ApproxStringLabel>> seenNodes = new HashSet<DTNode<ApproxStringLabel, ApproxStringLabel>>();
            Set<DTLink<ApproxStringLabel, ApproxStringLabel>> seenLinks = new HashSet<DTLink<ApproxStringLabel, ApproxStringLabel>>();

            // Each label owns (depth + 1) consecutive slots, one per depth.
            featureVector.setLastIndex((lastLabelIndex * (this.depth + 1)) + this.depth);
            setFV(featureVector, root.label().getIterations(), 0);

            List<DTNode<ApproxStringLabel, ApproxStringLabel>> frontier = new ArrayList<DTNode<ApproxStringLabel, ApproxStringLabel>>();
            frontier.add(root);
            seenNodes.add(root);
            for (int hop = 1; hop <= this.depth; hop++) {
                List<DTNode<ApproxStringLabel, ApproxStringLabel>> nextFrontier = new ArrayList<DTNode<ApproxStringLabel, ApproxStringLabel>>();
                for (DTNode<ApproxStringLabel, ApproxStringLabel> node : frontier) {
                    for (DTLink<ApproxStringLabel, ApproxStringLabel> link : node.linksOut()) {
                        if (seenLinks.add(link)) {
                            setFV(featureVector, link.tag().getIterations(), (hop * 2) - 1);
                        }
                        DTNode<ApproxStringLabel, ApproxStringLabel> target = link.to();
                        if (seenNodes.add(target)) {
                            setFV(featureVector, target.label().getIterations(), hop * 2);
                            if (hop < this.depth) {
                                nextFrontier.add(target);
                            }
                        }
                    }
                }
                frontier = nextFrontier;
            }
        }
    }

    /**
     * Adds the weights for one node or link to {@code featureVector}.
     *
     * <p>{@code labelIndices} holds one entry per relabelling round; entry {@code k} is weighted by
     * {@code getProb(position + k)}. Empty entries and repeated entries are skipped, and processing
     * stops once {@code position + k} exceeds the number of iterations. Each accepted label adds to
     * all {@code depth + 1} slots of that label, with the weight divided by
     * {@code depthDiffWeight^|slot - position / 2|}.
     *
     * @param featureVector vector to update
     * @param labelIndices label indices as decimal strings, one per relabelling round
     * @param position position of the element on the path from the instance (0 for the instance)
     */
    private void setFV(SparseVector featureVector, List<String> labelIndices, int position) {
        double elementDepth = position / 2.0d;
        Set<String> seen = new HashSet<String>();
        int round = 0;
        for (String labelIndex : labelIndices) {
            if (position + round > this.iterations) {
                return;
            }
            if (!labelIndex.equals("") && !seen.contains(labelIndex)) {
                int label = Integer.parseInt(labelIndex);
                double prob = getProb(position + round);
                for (int slot = 0; slot <= this.depth; slot++) {
                    int index = (label * (this.depth + 1)) + slot;
                    double weight = prob / Math.pow(this.depthDiffWeight, Math.abs(slot - elementDepth));
                    featureVector.setValue(index, featureVector.getValue(index) + weight);
                }
                seen.add(labelIndex);
            }
            round++;
        }
    }

    /**
     * Returns {@code (1 - p)^k * p}, caching the result per {@code k}.
     *
     * @param k exponent applied to {@code 1 - p}
     * @return the cached or newly computed weight
     */
    private double getProb(int k) {
        Integer key = Integer.valueOf(k);
        Double cached = this.probs.get(key);
        if (cached == null) {
            cached = Double.valueOf(Math.pow(1.0d - this.p, k) * this.p);
            this.probs.put(key, cached);
        }
        return cached.doubleValue();
    }

    /**
     * Returns {@code 1 - (1 - p)^(k + 1)}. Not called anywhere in this class.
     *
     * @param k upper bound of the summed range
     * @return the cumulative weight up to and including {@code k}
     */
    private double getCumProb(int k) {
        return 1.0d - Math.pow(1.0d - this.p, k + 1);
    }
}
