package org.data2semantics.mustard.kernels.graphkernels.rdfdata;

import java.util.List;
import java.util.Set;
import org.data2semantics.mustard.kernels.ComputationTimeTracker;
import org.data2semantics.mustard.kernels.KernelUtils;
import org.data2semantics.mustard.kernels.SparseVector;
import org.data2semantics.mustard.kernels.data.RDFData;
import org.data2semantics.mustard.kernels.data.SingleDTGraph;
import org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel;
import org.data2semantics.mustard.kernels.graphkernels.GraphKernel;
import org.data2semantics.mustard.kernels.graphkernels.singledtgraph.DTGraphTreeWalkCountKernelMkII;
import org.data2semantics.mustard.rdf.RDFDataSet;
import org.data2semantics.mustard.rdf.RDFUtils;
import org.openrdf.model.Resource;
import org.openrdf.model.Statement;

/* loaded from: input_file:org/data2semantics/mustard/kernels/graphkernels/rdfdata/RDFTreeWalkCountKernelMkII.class */
/**
 * Adapter that applies a {@link DTGraphTreeWalkCountKernelMkII} to {@link RDFData}.
 *
 * <p>For every computation the class
 * <ol>
 *   <li>collects the statements around the instances up to the configured depth via
 *       {@link RDFUtils#getStatements4Depth},</li>
 *   <li>removes all blacklisted statements,</li>
 *   <li>converts what is left into a {@link SingleDTGraph} via
 *       {@link RDFUtils#statements2Graph}, and</li>
 *   <li>delegates the actual kernel / feature-vector computation to the wrapped kernel.</li>
 * </ol>
 *
 * <p>The graph built in step 3 is kept in an unsynchronized instance field that is overwritten
 * on every call, so a single instance should not be used for concurrent computations.
 *
 * <p>NOTE(review): {@link #getLabel()} hands this object to {@code KernelUtils.createLabel}. If
 * that helper inspects the declared fields reflectively (not visible from this file), renaming,
 * reordering or removing fields would change the generated label -- confirm before refactoring
 * the fields below.
 */
public class RDFTreeWalkCountKernelMkII implements GraphKernel<RDFData>, FeatureVectorKernel<RDFData>, ComputationTimeTracker {
    /** Extraction depth passed to {@link RDFUtils#getStatements4Depth}. */
    private int depth;
    /** Inference flag passed to {@link RDFUtils#getStatements4Depth}. */
    private boolean inference;
    /** Wrapped graph kernel that performs the actual computation. */
    private DTGraphTreeWalkCountKernelMkII kernel;
    /** Graph produced by the most recent call to {@code init}; {@code null} before the first call. */
    private SingleDTGraph graph;

    /**
     * Creates the kernel and the wrapped {@link DTGraphTreeWalkCountKernelMkII}.
     *
     * @param i  first constructor argument of the wrapped kernel
     *           (presumably the path length -- TODO confirm against that class)
     * @param i2 depth; used for statement extraction and also passed as the second
     *           constructor argument of the wrapped kernel
     * @param z  inference flag used during statement extraction
     * @param z2 third constructor argument of the wrapped kernel
     *           (presumably the normalize flag -- TODO confirm against that class)
     */
    public RDFTreeWalkCountKernelMkII(int i, int i2, boolean z, boolean z2) {
        this.depth = i2;
        this.inference = z;
        this.kernel = new DTGraphTreeWalkCountKernelMkII(i, i2, z2);
    }

    /**
     * Returns the label of this kernel: the label created by {@code KernelUtils.createLabel}
     * for this object, an underscore, and the label of the wrapped kernel.
     */
    @Override // org.data2semantics.mustard.kernels.Kernel
    public String getLabel() {
        String ownLabel = String.valueOf(KernelUtils.createLabel(this));
        return ownLabel + "_" + this.kernel.getLabel();
    }

    /**
     * Forwards the normalization setting to the wrapped kernel.
     *
     * @param z value handed to the wrapped kernel's {@code setNormalize}
     */
    @Override // org.data2semantics.mustard.kernels.Kernel
    public void setNormalize(boolean z) {
        this.kernel.setNormalize(z);
    }

    /**
     * Returns the computation time reported by the wrapped kernel. Work done in this class
     * itself (statement extraction, graph construction) is not added to that value here.
     */
    @Override // org.data2semantics.mustard.kernels.ComputationTimeTracker
    public long getComputationTime() {
        return this.kernel.getComputationTime();
    }

    /**
     * Builds the graph for the given data and returns the wrapped kernel's feature vectors.
     *
     * @param rDFData data set, instances and blacklist to compute on
     * @return the result of the wrapped kernel's {@code computeFeatureVectors} on the built graph
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.FeatureVectorKernel
    public SparseVector[] computeFeatureVectors(RDFData rDFData) {
        init(rDFData.getDataset(), rDFData.getInstances(), rDFData.getBlackList());
        return this.kernel.computeFeatureVectors(this.graph);
    }

    /**
     * Builds the graph for the given data and returns the wrapped kernel's kernel matrix.
     *
     * @param rDFData data set, instances and blacklist to compute on
     * @return the result of the wrapped kernel's {@code compute} on the built graph
     */
    @Override // org.data2semantics.mustard.kernels.graphkernels.GraphKernel
    public double[][] compute(RDFData rDFData) {
        init(rDFData.getDataset(), rDFData.getInstances(), rDFData.getBlackList());
        return this.kernel.compute(this.graph);
    }

    /**
     * Extracts the statements for the instances, drops the blacklisted ones and stores the
     * resulting graph in {@link #graph}.
     *
     * @param dataset   RDF data set to extract statements from
     * @param instances instance resources; used for extraction and for graph construction
     * @param blackList statements that must not end up in the graph
     */
    private void init(RDFDataSet dataset, List<Resource> instances, List<Statement> blackList) {
        Set<Statement> statements = RDFUtils.getStatements4Depth(dataset, instances, this.depth, this.inference);

        // Remove entry by entry instead of statements.removeAll(blackList): the stock
        // AbstractSet.removeAll switches to scanning the set and calling List.contains per
        // element whenever the set is not larger than the list, which is quadratic. Removing
        // each blacklisted statement directly costs one set lookup per entry and leaves the
        // same contents. (Assumes the returned set inherits that removeAll -- not visible here.)
        for (Statement blackListed : blackList) {
            statements.remove(blackListed);
        }

        // NOTE(review): the meaning of the literal arguments 3 and true is defined by
        // RDFUtils.statements2Graph and is not visible from this file.
        this.graph = RDFUtils.statements2Graph(statements, 3, instances, true);
    }
}
