package org.data2semantics.mustard.kernels.graphkernels.graphlist;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.FeatureInspector;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.GraphList;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.weisfeilerlehman.ApproxStringLabel;
import org.data2semantics.mustard.weisfeilerlehman.WLUtils;
import org.data2semantics.mustard.weisfeilerlehman.WeisfeilerLehmanApproxDTGraphIterator;
import org.nodes.DTGraph;
import org.nodes.DTLink;
import org.nodes.DTNode;
import org.nodes.LightDTGraph;
import org.nodes.TNode;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/graphlist/WLSubTreeApproxKernel.class */
/**
 * Weisfeiler-Lehman subtree kernel over a list of {@link DTGraph}s whose node labels and link tags are
 * {@link ApproxStringLabel}s, driven by a {@link WeisfeilerLehmanApproxDTGraphIterator}.
 *
 * <p>For every combination of {@code minFreq x maxLabelCard x maxPrevNBH} the input graphs are copied,
 * relabelled for {@code iterations} rounds, and after each round every node/link label is added to the
 * owning graph's sparse feature vector. Each label occupies {@code maxDepth + 1} consecutive feature
 * slots, one per depth, so a feature index is {@code labelId * (maxDepth + 1) + depth}.
 *
 * <p>Intermediate state ({@code maxDepth}, {@code labelFreq}, {@code dict}, {@code compTime}) is kept in
 * instance fields and overwritten by each call, so an instance should not be shared between
 * concurrent computations.
 */
public class WLSubTreeApproxKernel implements GraphKernel<GraphList<DTGraph<ApproxStringLabel, ApproxStringLabel>>>, FeatureVectorKernel<GraphList<DTGraph<ApproxStringLabel, ApproxStringLabel>>>, ComputationTimeTracker, FeatureInspector {
    private final int iterations;
    protected boolean normalize;
    private final boolean reverse;
    private final boolean noDuplicateNBH;
    private final int[] minFreqs;
    private final int[] maxLabelCards;
    private final int[] maxPrevNBHs;
    private final double depthWeight;
    private final double depthDiffWeight;
    /** Largest label/tag depth seen by the most recent {@link #copyGraphs} call. */
    private int maxDepth;
    private long compTime;
    /** Maps compressed label ids back to their full labels; null until features have been computed. */
    private Map<String, String> dict;
    /** Per-label count of graphs containing that label, rebuilt before every WL iteration. */
    private Map<String, Integer> labelFreq;

    /**
     * @param iterations      number of WL relabelling rounds per parameter combination
     * @param reverse         forwarded to the WL iterator (presumably selects the neighbourhood direction — confirm in
     *                        {@code WeisfeilerLehmanApproxDTGraphIterator})
     * @param noDuplicateNBH  when true, labels whose {@code getSameAsPrev()} is non-zero are not counted
     * @param depthWeight     a contribution to depth slot {@code k} is divided by {@code depthWeight^k}
     * @param depthDiffWeight a contribution to depth slot {@code k} is divided by
     *                        {@code depthDiffWeight^|k - labelDepth|}
     * @param maxPrevNBHs     values swept for {@code setMaxPrevNBH} on the iterator
     * @param maxLabelCards   values swept for {@code setMaxLabelCard} on the iterator
     * @param minFreqs        values swept for {@code setMinFreq} on the iterator
     * @param normalize       whether to normalize the resulting feature vectors
     */
    public WLSubTreeApproxKernel(int iterations, boolean reverse, boolean noDuplicateNBH, double depthWeight, double depthDiffWeight, int[] maxPrevNBHs, int[] maxLabelCards, int[] minFreqs, boolean normalize) {
        this.reverse = reverse;
        this.noDuplicateNBH = noDuplicateNBH;
        this.normalize = normalize;
        this.iterations = iterations;
        this.maxPrevNBHs = maxPrevNBHs;
        this.maxLabelCards = maxLabelCards;
        this.minFreqs = minFreqs;
        this.depthWeight = depthWeight;
        this.depthDiffWeight = depthDiffWeight;
    }

    @Override
    public String getLabel() {
        return KernelUtils.createLabel(this);
    }

    @Override
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** Returns the wall-clock milliseconds spent in the WL loops of the last {@link #computeFeatureVectors} call. */
    @Override
    public long getComputationTime() {
        return this.compTime;
    }

    /**
     * Computes one sparse feature vector per graph in {@code graphList}.
     *
     * <p>Label counts from iteration 0 (before any relabelling) get weight {@code 1 / numCombinations}. After
     * relabelling round {@code r} (1-based) the weight is
     * {@code (1 + r * (numCombinations - 1) / iterations) / numCombinations}, which reaches 1.0 on the last
     * round.
     *
     * @param graphList the graphs to featurize; they are copied, not modified
     * @return the feature vectors, normalized when {@code normalize} is set
     */
    @Override
    public SparseVector[] computeFeatureVectors(GraphList<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphList) {
        SparseVector[] featureVectors = new SparseVector[graphList.numInstances()];
        for (int i = 0; i < featureVectors.length; i++) {
            featureVectors[i] = new SparseVector();
        }
        WeisfeilerLehmanApproxDTGraphIterator wl = new WeisfeilerLehmanApproxDTGraphIterator(this.reverse, 1, 1, 1);
        // Number of parameter combinations; also the scale of the per-iteration weights.
        double numCombinations = this.minFreqs.length * this.maxLabelCards.length * this.maxPrevNBHs.length;

        long start = System.currentTimeMillis();
        for (int minFreq : this.minFreqs) {
            for (int maxLabelCard : this.maxLabelCards) {
                for (int maxPrevNBH : this.maxPrevNBHs) {
                    // Every combination starts from an untouched copy of the input graphs.
                    List<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphs = copyGraphs(graphList.getGraphs());
                    wl.wlInitialize(graphs);
                    wl.setMaxLabelCard(maxLabelCard);
                    wl.setMinFreq(minFreq);
                    wl.setMaxPrevNBH(maxPrevNBH);
                    computeFVs(graphs, featureVectors, 1.0d / numCombinations, this.depthWeight, wl.getLabelDict().size() - 1);
                    for (int round = 0; round < this.iterations; round++) {
                        // The iterator needs the current per-label graph frequencies to apply minFreq.
                        computeLabelFreqs(graphs);
                        wl.wlIterate(graphs, this.labelFreq);
                        double weight = (1.0d + ((round + 1) * ((numCombinations - 1.0d) / this.iterations))) / numCombinations;
                        computeFVs(graphs, featureVectors, weight, this.depthWeight, wl.getLabelDict().size() - 1);
                    }
                }
            }
        }
        this.compTime = System.currentTimeMillis() - start;

        // Invert the iterator's dictionary (full label -> id) so feature ids can be described later.
        this.dict = new HashMap<String, String>();
        for (Map.Entry<String, String> entry : wl.getLabelDict().entrySet()) {
            this.dict.put(entry.getValue(), entry.getKey());
        }
        if (this.normalize) {
            featureVectors = KernelUtils.normalize(featureVectors);
        }
        return featureVectors;
    }

    /** Computes the kernel matrix as the pairwise products of the feature vectors. */
    @Override
    public double[][] compute(GraphList<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphList) {
        return KernelUtils.computeKernelMatrix(computeFeatureVectors(graphList), KernelUtils.initMatrix(graphList.getGraphs().size(), graphList.getGraphs().size()));
    }

    /**
     * Adds the current node labels and link tags of every graph to that graph's feature vector.
     *
     * @param graphs         graphs whose labels are numeric label ids (they are parsed with {@code Integer.parseInt})
     * @param featureVectors one vector per graph, same order as {@code graphs}
     * @param weight         base weight of every contribution in this round
     * @param depthWeight    divisor base applied per depth slot
     * @param lastLabel      presumably the highest label id currently in use; sets the vector's last index
     */
    private void computeFVs(List<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphs, SparseVector[] featureVectors, double weight, double depthWeight, int lastLabel) {
        for (int i = 0; i < graphs.size(); i++) {
            SparseVector featureVector = featureVectors[i];
            featureVector.setLastIndex((lastLabel * (this.maxDepth + 1)) + this.maxDepth);
            DTGraph<ApproxStringLabel, ApproxStringLabel> graph = graphs.get(i);
            for (DTNode<ApproxStringLabel, ApproxStringLabel> node : graph.nodes()) {
                addLabelFeatures(featureVector, node.label(), weight, depthWeight);
            }
            for (DTLink<ApproxStringLabel, ApproxStringLabel> link : graph.links()) {
                addLabelFeatures(featureVector, link.tag(), weight, depthWeight);
            }
        }
    }

    /**
     * Adds one label to {@code maxDepth + 1} consecutive slots of {@code featureVector}. Slot {@code k}
     * gets {@code weight / depthDiffWeight^|k - labelDepth| / depthWeight^k}.
     */
    private void addLabelFeatures(SparseVector featureVector, ApproxStringLabel label, double weight, double depthWeight) {
        if (this.noDuplicateNBH && label.getSameAsPrev() != 0) {
            return; // label unchanged from previous round; skipped when duplicates are suppressed
        }
        int base = Integer.parseInt(label.toString()) * (this.maxDepth + 1);
        for (int depth = 0; depth <= this.maxDepth; depth++) {
            int index = base + depth;
            double contribution = (weight / Math.pow(this.depthDiffWeight, Math.abs(depth - label.getDepth()))) / Math.pow(depthWeight, depth);
            featureVector.setValue(index, featureVector.getValue(index) + contribution);
        }
    }

    /**
     * Rebuilds {@code labelFreq}: for every label (node label or link tag), the number of graphs in
     * which it occurs at least once.
     */
    private void computeLabelFreqs(List<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphs) {
        this.labelFreq = new HashMap<String, Integer>();
        for (DTGraph<ApproxStringLabel, ApproxStringLabel> graph : graphs) {
            // Shared between nodes and links so a label is counted at most once per graph.
            Set<String> seenInGraph = new HashSet<String>();
            for (DTNode<ApproxStringLabel, ApproxStringLabel> node : graph.nodes()) {
                countLabelOnce(node.label().toString(), seenInGraph);
            }
            for (DTLink<ApproxStringLabel, ApproxStringLabel> link : graph.links()) {
                countLabelOnce(link.tag().toString(), seenInGraph);
            }
        }
    }

    /** Increments the frequency of {@code label} unless it was already counted for the current graph. */
    private void countLabelOnce(String label, Set<String> seenInGraph) {
        if (seenInGraph.add(label)) {
            Integer count = this.labelFreq.get(label);
            this.labelFreq.put(label, count == null ? 1 : count + 1);
        }
    }

    /**
     * Deep-copies the graphs (structure and labels) so relabelling does not touch the caller's graphs,
     * and records the largest label/tag depth in {@code maxDepth}.
     *
     * <p>Links are re-created by node index, which assumes {@code nodes()} iterates in index order.
     */
    private List<DTGraph<ApproxStringLabel, ApproxStringLabel>> copyGraphs(List<DTGraph<ApproxStringLabel, ApproxStringLabel>> graphs) {
        List<DTGraph<ApproxStringLabel, ApproxStringLabel>> copies = new ArrayList<DTGraph<ApproxStringLabel, ApproxStringLabel>>(graphs.size());
        this.maxDepth = 0;
        for (DTGraph<ApproxStringLabel, ApproxStringLabel> graph : graphs) {
            LightDTGraph<ApproxStringLabel, ApproxStringLabel> copy = new LightDTGraph<ApproxStringLabel, ApproxStringLabel>();
            for (DTNode<ApproxStringLabel, ApproxStringLabel> node : graph.nodes()) {
                ApproxStringLabel label = node.label();
                copy.add(new ApproxStringLabel(label.toString(), label.getDepth()));
                this.maxDepth = Math.max(this.maxDepth, label.getDepth());
            }
            for (DTLink<ApproxStringLabel, ApproxStringLabel> link : graph.links()) {
                DTNode<ApproxStringLabel, ApproxStringLabel> from = copy.nodes().get(link.from().index());
                DTNode<ApproxStringLabel, ApproxStringLabel> to = copy.nodes().get(link.to().index());
                ApproxStringLabel tag = link.tag();
                from.connect(to, new ApproxStringLabel(tag.toString(), tag.getDepth()));
                this.maxDepth = Math.max(this.maxDepth, tag.getDepth());
            }
            copies.add(copy);
        }
        return copies;
    }

    /**
     * Describes the given feature indices using the dictionary from the last
     * {@link #computeFeatureVectors} call.
     *
     * <p>NOTE(review): feature indices encode {@code labelId * (maxDepth + 1) + depth}, but they are passed
     * unchanged to {@code WLUtils.getFeatureDecription}. Confirm that helper expects this encoding rather than a
     * plain label id.
     *
     * @throws IllegalStateException if no feature vectors have been computed yet
     */
    @Override
    public List<String> getFeatureDescriptions(List<Integer> indices) {
        if (this.dict == null) {
            throw new IllegalStateException("Should run computeFeatureVectors() first");
        }
        List<String> descriptions = new ArrayList<String>(indices.size());
        for (int index : indices) {
            descriptions.add(WLUtils.getFeatureDecription(this.dict, index));
        }
        return descriptions;
    }
}
