package org.data2semantics.mustard.kernels.graphkernels.singledtgraph;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.utils.Pair;
import org.data2semantics.mustard.weisfeilerlehman.ApproxStringLabel;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanApproxDTGraphIterator;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/singledtgraph/DTGraphTreeWLSubTreeIDEQApproxKernel.class */
/**
 * Weisfeiler-Lehman subtree kernel over a single {@link SingleDTGraph}, using approximate
 * ({@link ApproxStringLabel}) relabeling.
 *
 * <p>For every combination of the configured {@code minFreqs}, {@code maxLabelCards} and
 * {@code maxPrevNBHs} this kernel:
 * <ol>
 *   <li>copies the part of the graph reachable from each instance via out-links (at most
 *       {@code depth} hops) into a fresh {@link LightDTGraph}. An original vertex or link
 *       reached from several instances is mapped onto one shared copy;</li>
 *   <li>runs {@code iterations} relabeling steps of a {@link WeisfeilerLehmanApproxDTGraphIterator};</li>
 *   <li>after the initial labeling and after every iteration, adds the labels of each
 *       instance's vertices and links into that instance's {@link SparseVector}.</li>
 * </ol>
 * All combinations accumulate into the same feature vectors.
 */
public class DTGraphTreeWLSubTreeIDEQApproxKernel implements GraphKernel<SingleDTGraph>, FeatureVectorKernel<SingleDTGraph>, ComputationTimeTracker {
    // Per instance root copy: every vertex copy in its tree, paired with the level counter at
    // which it was reached (depth for the root, decreasing by one per out-link hop).
    private Map<DTNode<ApproxStringLabel, ApproxStringLabel>, List<Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer>>> instanceVertexIndexMap;
    // Per instance root copy: every link copy in its tree, paired with the level counter of its target.
    private Map<DTNode<ApproxStringLabel, ApproxStringLabel>, List<Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer>>> instanceEdgeIndexMap;
    // Working copy of the instance trees; rebuilt by init() for every parameter combination.
    private DTGraph<ApproxStringLabel, ApproxStringLabel> rdfGraph;
    // Root copies, in the same order as SingleDTGraph#getInstances().
    private List<DTNode<ApproxStringLabel, ApproxStringLabel>> instanceVertices;
    // Label -> number of instances whose tree contains it; recomputed before every WL iteration.
    private Map<String, Integer> labelFreq;
    private int depth;
    private int iterations;
    private boolean normalize;
    private boolean reverse;
    // NOTE(review): stored by the constructor but not read by any code in this class.
    private boolean iterationWeighting;
    private boolean noDuplicateNBH;
    private boolean noSubGraphs;
    private int[] maxLabelCards;
    private int[] minFreqs;
    private int[] maxPrevNBHs;
    private double depthWeight;
    private double depthDiffWeight;
    private long compTime;

    /**
     * Creates the kernel.
     *
     * @param iterations         number of WL relabeling iterations per parameter combination
     * @param depth              maximum number of out-link hops followed from each instance
     * @param reverse            passed to the {@link WeisfeilerLehmanApproxDTGraphIterator} constructor
     * @param iterationWeighting stored only; see the NOTE on the field
     * @param noDuplicateNBH     if true, skip vertices/links whose label reports {@code getSameAsPrev() != 0}
     * @param noSubGraphs        if true, disable the iteration cut-off applied in feature computation
     * @param depthWeight        divisor base applied per depth slot of a feature
     * @param depthDiffWeight    divisor base applied per unit of depth difference
     * @param maxPrevNBHs        values tried for {@code setMaxPrevNBH}
     * @param maxLabelCards      values tried for {@code setMaxLabelCard}
     * @param minFreqs           values tried for {@code setMinFreq}
     * @param normalize          whether the resulting feature vectors are normalized
     */
    public DTGraphTreeWLSubTreeIDEQApproxKernel(int iterations, int depth, boolean reverse,
            boolean iterationWeighting, boolean noDuplicateNBH, boolean noSubGraphs,
            double depthWeight, double depthDiffWeight, int[] maxPrevNBHs, int[] maxLabelCards,
            int[] minFreqs, boolean normalize) {
        this.reverse = reverse;
        this.iterationWeighting = iterationWeighting;
        this.noDuplicateNBH = noDuplicateNBH;
        this.noSubGraphs = noSubGraphs;
        this.normalize = normalize;
        this.depth = depth;
        this.iterations = iterations;
        this.maxLabelCards = maxLabelCards;
        this.minFreqs = minFreqs;
        this.maxPrevNBHs = maxPrevNBHs;
        this.depthWeight = depthWeight;
        this.depthDiffWeight = depthDiffWeight;
    }

    /** Returns a descriptive label for this kernel instance, built by {@link KernelUtils#createLabel}. */
    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    /** Sets whether computed feature vectors are normalized. */
    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** Returns the wall-clock milliseconds measured by the last computation. */
    @Override
    public long getComputationTime() {
        return this.compTime;
    }

    /**
     * Computes one sparse feature vector per instance of {@code data}.
     *
     * <p>The relabeling loop is timed and stored as the computation time. Normalization, if
     * enabled, happens after timing stops.
     *
     * @param data graph plus instance vertices
     * @return one feature vector per instance, in instance order
     */
    @Override
    public SparseVector[] computeFeatureVectors(SingleDTGraph data) {
        SparseVector[] featureVectors = new SparseVector[data.getInstances().size()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }

        WeisfeilerLehmanApproxDTGraphIterator wl = new WeisfeilerLehmanApproxDTGraphIterator(this.reverse, 1, 1, 1);
        // Number of parameter combinations (int product, then widened); scales the per-iteration weights.
        double numCombinations = this.minFreqs.length * this.maxLabelCards.length * this.maxPrevNBHs.length;

        long start = System.currentTimeMillis();
        for (int minFreq : this.minFreqs) {
            for (int maxLabelCard : this.maxLabelCards) {
                for (int maxPrevNBH : this.maxPrevNBHs) {
                    // Labels are mutated by the iterator, so every combination starts from a fresh copy.
                    init(data.getInstances());
                    List<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphs =
                            new ArrayList<DTGraph<ApproxStringLabel, ApproxStringLabel>>();
                    graphs.add(this.rdfGraph);

                    wl.wlInitialize(graphs);
                    wl.setMaxLabelCard(maxLabelCard);
                    wl.setMinFreq(minFreq);
                    wl.setMaxPrevNBH(maxPrevNBH);

                    computeFVs(this.instanceVertices, 1.0 / numCombinations, featureVectors,
                            wl.getLabelDict().size() - 1, 0);
                    for (int it = 0; it < this.iterations; it++) {
                        computeLabelFreqs(this.instanceVertices);
                        wl.wlIterate(graphs, this.labelFreq);
                        // Weight grows linearly from 1/numCombinations to 1 over the iterations.
                        double weight = (1.0 + ((it + 1) * ((numCombinations - 1.0) / this.iterations))) / numCombinations;
                        computeFVs(this.instanceVertices, weight, featureVectors,
                                wl.getLabelDict().size() - 1, it + 1);
                    }
                }
            }
        }
        this.compTime = System.currentTimeMillis() - start;

        if (this.normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /**
     * Computes the instance-by-instance kernel matrix from the feature vectors.
     * The matrix computation time is added to the feature-vector computation time.
     */
    @Override
    public double[][] compute(SingleDTGraph data) {
        SparseVector[] featureVectors = computeFeatureVectors(data);
        int numInstances = data.getInstances().size();
        double[][] kernel = KernelUtils.initMatrix(numInstances, numInstances);

        long start = System.currentTimeMillis();
        kernel = KernelUtils.computeKernelMatrix(featureVectors, kernel);
        this.compTime += System.currentTimeMillis() - start;
        return kernel;
    }

    /**
     * Rebuilds {@link #rdfGraph} and the per-instance vertex/link indices.
     *
     * <p>Starting at each instance, out-links are followed breadth-first for {@code depth} levels.
     * Every visited original vertex and link gets exactly one copy, shared across instances,
     * whose label is reset to the original label on every visit.
     *
     * @param instances the instance vertices of the original graph
     */
    private void init(List<DTNode<String, String>> instances) {
        Map<DTNode<String, String>, DTNode<ApproxStringLabel, ApproxStringLabel>> nodeCopies =
                new HashMap<DTNode<String, String>, DTNode<ApproxStringLabel, ApproxStringLabel>>();
        Map<DTLink<String, String>, DTLink<ApproxStringLabel, ApproxStringLabel>> linkCopies =
                new HashMap<DTLink<String, String>, DTLink<ApproxStringLabel, ApproxStringLabel>>();

        this.instanceVertices = new ArrayList<DTNode<ApproxStringLabel, ApproxStringLabel>>();
        this.instanceVertexIndexMap = new HashMap<DTNode<ApproxStringLabel, ApproxStringLabel>, List<Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer>>>();
        this.instanceEdgeIndexMap = new HashMap<DTNode<ApproxStringLabel, ApproxStringLabel>, List<Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer>>>();
        this.rdfGraph = new LightDTGraph<ApproxStringLabel, ApproxStringLabel>();

        for (DTNode<String, String> instance : instances) {
            List<Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer>> vertexIndex =
                    new ArrayList<Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer>>();
            List<Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer>> edgeIndex =
                    new ArrayList<Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer>>();

            DTNode<ApproxStringLabel, ApproxStringLabel> root = copyOf(instance, nodeCopies);
            resetLabel(root.label(), instance.label());
            this.instanceVertices.add(root);
            this.instanceVertexIndexMap.put(root, vertexIndex);
            this.instanceEdgeIndexMap.put(root, edgeIndex);
            vertexIndex.add(new Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer>(root, this.depth));

            List<DTNode<String, String>> frontier = new ArrayList<DTNode<String, String>>();
            frontier.add(instance);
            for (int level = this.depth - 1; level >= 0; level--) {
                List<DTNode<String, String>> next = new ArrayList<DTNode<String, String>>();
                for (DTNode<String, String> node : frontier) {
                    // Every frontier node was copied before it was enqueued, and copies are never replaced.
                    DTNode<ApproxStringLabel, ApproxStringLabel> nodeCopy = nodeCopies.get(node);
                    for (DTLink<String, String> link : node.linksOut()) {
                        DTNode<String, String> target = link.to();

                        DTNode<ApproxStringLabel, ApproxStringLabel> targetCopy = copyOf(target, nodeCopies);
                        resetLabel(targetCopy.label(), target.label());
                        vertexIndex.add(new Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer>(targetCopy, level));

                        DTLink<ApproxStringLabel, ApproxStringLabel> linkCopy = linkCopies.get(link);
                        if (linkCopy == null) {
                            linkCopy = nodeCopy.connect(targetCopy, new ApproxStringLabel());
                            linkCopies.put(link, linkCopy);
                        }
                        resetLabel(linkCopy.tag(), link.tag());
                        edgeIndex.add(new Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer>(linkCopy, level));

                        if (level > 0) {
                            next.add(target);
                        }
                    }
                }
                frontier = next;
            }
        }
    }

    /** Returns the copy of {@code original} in {@link #rdfGraph}, adding a new vertex if none exists yet. */
    private DTNode<ApproxStringLabel, ApproxStringLabel> copyOf(DTNode<String, String> original,
            Map<DTNode<String, String>, DTNode<ApproxStringLabel, ApproxStringLabel>> nodeCopies) {
        DTNode<ApproxStringLabel, ApproxStringLabel> copy = nodeCopies.get(original);
        if (copy == null) {
            copy = this.rdfGraph.add(new ApproxStringLabel());
            nodeCopies.put(original, copy);
        }
        return copy;
    }

    /** Replaces the contents of {@code label} with {@code text}. */
    private static void resetLabel(ApproxStringLabel label, String text) {
        label.clear();
        label.append(text);
    }

    /**
     * Adds the current labels of each instance's vertices and links to its feature vector.
     *
     * <p>NOTE(review): the level counter stored with each vertex/link by {@code init}
     * ({@code Pair#getSecond()}) is never read here. The depth-difference weight and the
     * sub-graph cut-off both use the field {@code depth}, so every label contributes the same
     * pattern across its depth slots. Confirm whether the per-pair depth was meant to be used.
     *
     * @param instances      root copies, aligned with {@code featureVectors}
     * @param weight         base weight for this iteration
     * @param featureVectors accumulators, updated in place
     * @param lastLabel      highest label index in the iterator's dictionary (dictionary size - 1);
     *                       assumes labels parse to integers below the dictionary size — TODO confirm
     * @param iteration      0 for the initial labeling, otherwise the 1-based WL iteration
     */
    private void computeFVs(List<DTNode<ApproxStringLabel, ApproxStringLabel>> instances, double weight,
            SparseVector[] featureVectors, int lastLabel, int iteration) {
        for (int i = 0; i < instances.size(); i++) {
            DTNode<ApproxStringLabel, ApproxStringLabel> root = instances.get(i);
            SparseVector fv = featureVectors[i];
            fv.setLastIndex((lastLabel * (this.depth + 1)) + this.depth);

            for (Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer> vertex : this.instanceVertexIndexMap.get(root)) {
                ApproxStringLabel label = vertex.getFirst().label();
                // Unless noSubGraphs is set, vertex labels stop counting after iteration 2 * depth.
                if ((!this.noDuplicateNBH || label.getSameAsPrev() == 0)
                        && (this.noSubGraphs || this.depth * 2 >= iteration)) {
                    addLabel(fv, Integer.parseInt(label.toString()), weight);
                }
            }

            for (Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer> edge : this.instanceEdgeIndexMap.get(root)) {
                ApproxStringLabel tag = edge.getFirst().tag();
                // Link labels get one extra iteration before the cut-off.
                if ((!this.noDuplicateNBH || tag.getSameAsPrev() == 0)
                        && (this.noSubGraphs || (this.depth * 2) + 1 >= iteration)) {
                    addLabel(fv, Integer.parseInt(tag.toString()), weight);
                }
            }
        }
    }

    /**
     * Adds {@code weight} to the {@code depth + 1} consecutive slots reserved for {@code label}.
     * Slot {@code s} receives {@code weight / depthDiffWeight^|s - depth| / depthWeight^s}.
     */
    private void addLabel(SparseVector fv, int label, double weight) {
        int base = label * (this.depth + 1);
        for (int slot = 0; slot <= this.depth; slot++) {
            int index = base + slot;
            double contribution = (weight / Math.pow(this.depthDiffWeight, Math.abs(slot - this.depth)))
                    / Math.pow(this.depthWeight, slot);
            fv.setValue(index, fv.getValue(index) + contribution);
        }
    }

    /**
     * Recomputes {@link #labelFreq}: for every label, the number of instances whose vertex or
     * link index contains it at least once.
     */
    private void computeLabelFreqs(List<DTNode<ApproxStringLabel, ApproxStringLabel>> instances) {
        this.labelFreq = new HashMap<String, Integer>();
        for (DTNode<ApproxStringLabel, ApproxStringLabel> root : instances) {
            // Vertex and link labels share one "seen" set, so each label counts once per instance.
            Set<String> seen = new HashSet<String>();
            for (Pair<DTNode<ApproxStringLabel, ApproxStringLabel>, Integer> vertex : this.instanceVertexIndexMap.get(root)) {
                countOnce(vertex.getFirst().label().toString(), seen);
            }
            for (Pair<DTLink<ApproxStringLabel, ApproxStringLabel>, Integer> edge : this.instanceEdgeIndexMap.get(root)) {
                countOnce(edge.getFirst().tag().toString(), seen);
            }
        }
    }

    /** Increments the frequency of {@code label} unless it is already in {@code seen}. */
    private void countOnce(String label, Set<String> seen) {
        if (seen.add(label)) {
            Integer count = this.labelFreq.get(label);
            this.labelFreq.put(label, count == null ? 1 : count + 1);
        }
    }
}
