package org.webdatacommons.webtables.extraction;

import com.google.common.base.Optional;
import java.io.FileNotFoundException;
import java.io.InputStream;
import org.jsoup.nodes.Element;
import org.webdatacommons.webtables.extraction.model.ClassificationResult;
import org.webdatacommons.webtables.extraction.model.FeaturesP1;
import org.webdatacommons.webtables.extraction.model.FeaturesP2;
import org.webdatacommons.webtables.extraction.util.TableConvert;
import org.webdatacommons.webtables.tools.data.TableType;
import weka.classifiers.Classifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.SerializationHelper;

/* loaded from: input_file:org/webdatacommons/webtables/extraction/TableClassification.class */
/**
 * Two-phase web table classifier backed by serialized Weka models.
 *
 * <p>The phase-1 model is consulted first. If it predicts the {@code LAYOUT} class, the table
 * is reported as {@link TableType#LAYOUT}. Otherwise the phase-2 model's prediction
 * ({@code RELATION}, {@code ENTITY}, {@code MATRIX} or {@code NONE}) is mapped to the
 * corresponding {@link TableType}, with {@code NONE} reported as {@link TableType#OTHER}.
 */
public class TableClassification {
    // NOTE(review): the meaning of the (2, 2) arguments is defined by TableConvert; confirm there.
    private final TableConvert tableConvert = new TableConvert(2, 2);
    private final FeaturesP1 phase1Features = new FeaturesP1();
    private final FeaturesP2 phase2Features = new FeaturesP2();
    private final Classifier classifier1;
    private final Classifier classifier2;
    private final Attribute classAttr1;
    private final Attribute classAttr2;
    /** Index of "LAYOUT" in the phase-1 class vector (-1 if absent). */
    private final double layoutVal;
    /** Indices of the phase-2 class labels (-1 if a label is absent). */
    private final double relationVal;
    private final double entityVal;
    private final double matrixVal;
    private final double noneVal;

    /**
     * Creates a classifier whose two models are loaded from classpath resources.
     *
     * <p>Load failures are not propagated. The stack trace is printed and the affected
     * classifier stays {@code null}; if the phase-1 model fails, the phase-2 model is not
     * attempted either. {@link #classifyTable(Element[][])} then returns {@code null} whenever
     * it reaches a missing classifier.
     *
     * @param phase1ModelResource classpath resource name of the phase-1 model
     * @param phase2ModelResource classpath resource name of the phase-2 model
     */
    public TableClassification(String phase1ModelResource, String phase2ModelResource) {
        Classifier loaded1 = null;
        Classifier loaded2 = null;
        try {
            loaded1 = loadModelFromClasspath(phase1ModelResource);
            loaded2 = loadModelFromClasspath(phase2ModelResource);
        } catch (Exception e) {
            // NOTE(review): swallowing this leaves a half-usable instance; failing fast may be preferable.
            e.printStackTrace();
        }
        this.classifier1 = loaded1;
        this.classifier2 = loaded2;

        this.classAttr1 = new Attribute("class", this.phase1Features.getClassVector());
        this.layoutVal = this.classAttr1.indexOfValue("LAYOUT");
        this.classAttr2 = new Attribute("class", this.phase2Features.getClassVector());
        this.relationVal = this.classAttr2.indexOfValue("RELATION");
        this.entityVal = this.classAttr2.indexOfValue("ENTITY");
        this.matrixVal = this.classAttr2.indexOfValue("MATRIX");
        this.noneVal = this.classAttr2.indexOfValue("NONE");
    }

    /**
     * Classifies an HTML table element.
     *
     * <p>If the element cannot be converted to a cell matrix, neither model is consulted. The
     * result is then {@link TableType#LAYOUT} with the fixed phase-1 distribution
     * {@code {1, 0, 0, 0, 0}} and no phase-2 distribution.
     *
     * @param element the table element to classify
     * @return the classification result, or {@code null} if model evaluation threw
     */
    public ClassificationResult classifyTable(Element element) {
        Optional<Element[][]> table = this.tableConvert.toTable(element);
        if (!table.isPresent()) {
            double[] fixedDistribution = new double[] {1.0d, 0.0d, 0.0d, 0.0d, 0.0d};
            return new ClassificationResult(TableType.LAYOUT, fixedDistribution, null);
        }
        return classifyTable(table.get());
    }

    /**
     * Classifies a table given as a cell matrix.
     *
     * <p>If the phase-1 model predicts the LAYOUT class, only the phase-1 distribution is
     * returned. Otherwise phase-2 features are computed, the phase-2 prediction is mapped to a
     * {@link TableType}, and both distributions are returned.
     *
     * @param elementArr the table cells
     * @return the classification result, or {@code null} if anything inside model evaluation
     *     (including phase-2 feature computation) threw; the stack trace is printed
     */
    public ClassificationResult classifyTable(Element[][] elementArr) {
        Instance phase1Instance = this.phase1Features.computeFeatures(elementArr);
        try {
            double phase1Class = this.classifier1.classifyInstance(phase1Instance);
            double[] phase1Distribution = this.classifier1.distributionForInstance(phase1Instance);
            if (phase1Class == this.layoutVal) {
                return new ClassificationResult(TableType.LAYOUT, phase1Distribution, null);
            }
            Instance phase2Instance = this.phase2Features.computeFeatures(elementArr);
            double phase2Class = this.classifier2.classifyInstance(phase2Instance);
            double[] phase2Distribution = this.classifier2.distributionForInstance(phase2Instance);
            return new ClassificationResult(
                    toTableType(phase2Class), phase1Distribution, phase2Distribution);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Maps a phase-2 class index to a table type; unrecognized indices fall back to LAYOUT. */
    private TableType toTableType(double phase2Class) {
        // Check order matters if two labels share an index (e.g. both absent, -1).
        if (phase2Class == this.relationVal) {
            return TableType.RELATION;
        }
        if (phase2Class == this.entityVal) {
            return TableType.ENTITY;
        }
        if (phase2Class == this.matrixVal) {
            return TableType.MATRIX;
        }
        if (phase2Class == this.noneVal) {
            return TableType.OTHER;
        }
        return TableType.LAYOUT;
    }

    /**
     * Deserializes a Weka classifier from a file.
     *
     * <p>This uses Java deserialization, so only load model files from trusted sources.
     *
     * @param path filesystem path of the serialized model
     * @return the deserialized classifier
     * @throws Exception if the file cannot be read or does not contain a {@link Classifier}
     */
    public static Classifier loadModelFromFile(String path) throws Exception {
        return (Classifier) SerializationHelper.read(path);
    }

    /**
     * Deserializes a Weka classifier from a classpath resource.
     *
     * <p>The name is resolved by {@link Class#getResourceAsStream(String)} relative to this
     * class. The stream is always closed, including when deserialization fails. This uses Java
     * deserialization, so only load models from trusted sources.
     *
     * @param resource classpath resource name of the serialized model
     * @return the deserialized classifier
     * @throws FileNotFoundException if no resource with that name exists
     * @throws Exception if the resource cannot be read or does not contain a {@link Classifier}
     */
    public static Classifier loadModelFromClasspath(String resource) throws Exception {
        try (InputStream in = TableClassification.class.getResourceAsStream(resource)) {
            if (in == null) {
                throw new FileNotFoundException("Model resource not found on classpath: " + resource);
            }
            return (Classifier) SerializationHelper.read(in);
        }
    }
}
