package com.quantx1.financial.analytics;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.DoubleArrayDataRow;
import com.rapidminer.example.table.ListDataRowReader;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.features.construction.AbstractFeatureConstruction;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.MetaData;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeAttribute;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/quantx1/financial/analytics/BacktestOperator.class */
/**
 * RapidMiner operator that runs a {@link Backtest} over an example set.
 *
 * <p>The input is read through four attribute parameters (security id, signal,
 * period and return series). The signal and return values are arranged into two
 * matrices with one row per distinct period and one column per distinct
 * security; cells that have no example stay {@code NaN}. The operator then calls
 * {@link Backtest#performBacktest()} and emits a new example set with the
 * columns: period, {@code Q1..Qn}, {@code Qn-Q1}, {@code turnover} and
 * {@code IC}.
 *
 * <p>The operator keeps per-run state in instance fields, which are rebuilt at
 * the start of every {@link #apply(ExampleSet)} call.
 */
public class BacktestOperator extends AbstractFeatureConstruction {
    public static final String PARAMETER_SECURITY_ID = "Security Id";
    public static final String PARAMETER_SIGNAL_COL = "Signal";
    public static final String PARAMETER_PERIOD_COL = "Period";
    public static final String PARAMETER_RETURN_COL = "Return Series";
    public static final String PARAMETER_NUM_BUCKETS = "Number of buckets";
    public static final String SPREAD_FORMAT = "Q%d-Q%d";
    public static final String TURNOVER_ATTR = "turnover";
    public static final String IC_ATTR = "IC";

    /** Name pattern of the per-bucket output columns (Q1, Q2, ...). */
    private static final String QUANTILE_FORMAT = "Q%d";

    /** Pattern used to turn a date period into the key looked up in the period index. */
    private static final String PERIOD_DATE_PATTERN = "yyyy-MM-dd";

    // Value-type codes passed to the RapidMiner attribute APIs.
    // NOTE(review): presumably Ontology.NOMINAL / NUMERICAL / DATE - confirm against the RapidMiner version in use.
    private static final int VALUE_TYPE_NOMINAL = 1;
    private static final int VALUE_TYPE_NUMERICAL = 2;
    private static final int VALUE_TYPE_DATE = 10;

    /** Maps the string form of each distinct period to its row position. */
    private Map<String, Integer> periodPositions;

    /** Period names indexed by row position (the inverse of {@link #periodPositions}). */
    private String[] periods;

    /** Whether the period column of the current run is nominal. */
    private boolean periodIsNominal;

    /** Whether the period column of the current run has the date value type. */
    private boolean periodIsDate;

    /**
     * Creates the operator.
     *
     * @param operatorDescription description handed through to the superclass
     */
    public BacktestOperator(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.periodIsNominal = false;
        this.periodIsDate = false;
    }

    /** Returns the configured name of the period column. */
    private String getPeriodColumn() throws UndefinedParameterError {
        return getParameterAsString(PARAMETER_PERIOD_COL);
    }

    /** Builds the name of the output column for the given 1-based bucket number. */
    private static String quantileName(int bucket) {
        return String.format(QUANTILE_FORMAT, Integer.valueOf(bucket));
    }

    /** Builds the name of the spread column, i.e. highest bucket minus bucket 1. */
    private static String spreadName(int bucketCount) {
        return String.format(SPREAD_FORMAT, Integer.valueOf(bucketCount), Integer.valueOf(1));
    }

    /**
     * Replaces the incoming meta data with the columns this operator produces:
     * the period column, one column per bucket, the spread, turnover and IC.
     *
     * @param exampleSetMetaData meta data of the input; cleared and refilled in place
     * @return the same meta data object
     */
    protected MetaData modifyMetaData(ExampleSetMetaData exampleSetMetaData) {
        try {
            exampleSetMetaData.clear();
            int bucketCount = getParameterAsInt(PARAMETER_NUM_BUCKETS);
            exampleSetMetaData.addAttribute(new AttributeMetaData(getPeriodColumn(), VALUE_TYPE_NOMINAL));
            for (int bucket = 0; bucket < bucketCount; bucket++) {
                exampleSetMetaData.addAttribute(new AttributeMetaData(quantileName(bucket + 1), VALUE_TYPE_NUMERICAL));
            }
            exampleSetMetaData.addAttribute(new AttributeMetaData(spreadName(bucketCount), VALUE_TYPE_NUMERICAL));
            exampleSetMetaData.addAttribute(new AttributeMetaData(TURNOVER_ATTR, VALUE_TYPE_NUMERICAL));
            exampleSetMetaData.addAttribute(new AttributeMetaData(IC_ATTR, VALUE_TYPE_NUMERICAL));
        } catch (UndefinedParameterError ignored) {
            // Deliberately best-effort: while a parameter is still undefined the
            // meta data is returned as far as it could be built, instead of
            // failing the meta data transformation.
        }
        return exampleSetMetaData;
    }

    /**
     * Converts the results of a finished backtest into an example set.
     *
     * <p>Column layout: 0 = period, 1..n = bucket returns, n+1 = spread,
     * n+2 = turnover, n+3 = IC. One example is created per row of the
     * bucket-return matrix and labelled with the period name at the same index.
     *
     * @param backtest backtest whose results are read
     * @return the newly created example set
     * @throws UndefinedParameterError if the period column parameter is not set
     */
    private ExampleSet convertBacktestToExampleSet(Backtest backtest) throws UndefinedParameterError {
        DoubleMatrix2D bucketReturns = backtest.getBucketReturns();
        int nBuckets = backtest.getNBuckets();

        final int firstQuantileColumn = 1;
        final int spreadColumn = firstQuantileColumn + nBuckets;
        final int turnoverColumn = spreadColumn + 1;
        final int icColumn = turnoverColumn + 1;
        final int columnCount = icColumn + 1;

        // Attributes are appended in column order, so list index == column index.
        List<Attribute> attributes = new ArrayList<Attribute>(columnCount);
        Attribute periodAttribute = AttributeFactory.createAttribute(getPeriodColumn(), VALUE_TYPE_NOMINAL);
        attributes.add(periodAttribute);
        for (int bucket = 1; bucket <= nBuckets; bucket++) {
            attributes.add(AttributeFactory.createAttribute(quantileName(bucket), VALUE_TYPE_NUMERICAL));
        }
        attributes.add(AttributeFactory.createAttribute(spreadName(nBuckets), VALUE_TYPE_NUMERICAL));
        attributes.add(AttributeFactory.createAttribute(TURNOVER_ATTR, VALUE_TYPE_NUMERICAL));
        attributes.add(AttributeFactory.createAttribute(IC_ATTR, VALUE_TYPE_NUMERICAL));

        double[] turnover = backtest.calculateTurnover();
        double[] ics = backtest.calculateICs().toArray();

        MemoryExampleTable table = new MemoryExampleTable(attributes);
        int rowCount = bucketReturns.rows();
        for (int row = 0; row < rowCount; row++) {
            double[] values = new double[columnCount];
            for (int bucket = 0; bucket < nBuckets; bucket++) {
                values[firstQuantileColumn + bucket] = bucketReturns.get(row, bucket);
            }
            // Spread = return of the last bucket minus return of the first bucket.
            values[spreadColumn] = values[firstQuantileColumn + nBuckets - 1] - values[firstQuantileColumn];
            values[turnoverColumn] = turnover[row];
            values[icColumn] = ics[row];
            table.addDataRow(new DoubleArrayDataRow(values));
        }

        // The period is nominal, so it is written through the example API
        // (which takes the string value) once the example set exists.
        ExampleSet result = table.createExampleSet();
        int row = 0;
        for (Example example : result) {
            example.setValue(periodAttribute, this.periods[row]);
            row++;
        }
        return result;
    }

    /**
     * Looks up the attribute named by the given parameter.
     *
     * @param exampleSet   example set to search
     * @param parameterKey key of the parameter holding the attribute name
     * @return the attribute, never {@code null}
     * @throws OperatorException if the parameter is undefined or the example
     *                           set has no attribute with that name
     */
    private Attribute requireAttribute(ExampleSet exampleSet, String parameterKey) throws OperatorException {
        String attributeName = getParameterAsString(parameterKey);
        Attribute attribute = exampleSet.getAttributes().get(attributeName);
        if (attribute == null) {
            throw new OperatorException("Attribute '" + attributeName + "' (parameter '" + parameterKey
                    + "') was not found in the input example set.");
        }
        return attribute;
    }

    /**
     * Collects the distinct periods of the example set and builds the period
     * index ({@link #periodPositions}) and its inverse ({@link #periods}).
     *
     * <p>The period column may be nominal, of the date value type, or anything
     * else, in which case it is read as a plain double. The positions themselves
     * are assigned by {@code ArrayIndexComparator}.
     *
     * @param exampleSet input example set
     * @throws OperatorException if the period attribute cannot be resolved
     */
    private void computePeriods(ExampleSet exampleSet) throws OperatorException {
        Attribute periodAttribute = requireAttribute(exampleSet, PARAMETER_PERIOD_COL);

        // Recompute both flags on every run so that a previous run on a
        // differently typed period column cannot leak into this one.
        this.periodIsNominal = periodAttribute.isNominal();
        this.periodIsDate = !this.periodIsNominal && periodAttribute.getValueType() == VALUE_TYPE_DATE;

        if (this.periodIsNominal) {
            HashSet<String> distinct = new HashSet<String>();
            for (Example example : exampleSet) {
                distinct.add(example.getNominalValue(periodAttribute));
            }
            this.periodPositions = ArrayIndexComparator.createPositionsMap(distinct.toArray(new String[0]));
        } else if (this.periodIsDate) {
            HashSet<Date> distinct = new HashSet<Date>();
            for (Example example : exampleSet) {
                distinct.add(example.getDateValue(periodAttribute));
            }
            this.periodPositions = ArrayIndexComparator.createStringPositionsMap(distinct.toArray(new Date[0]));
        } else {
            HashSet<Double> distinct = new HashSet<Double>();
            for (Example example : exampleSet) {
                distinct.add(Double.valueOf(example.getValue(periodAttribute)));
            }
            this.periodPositions = ArrayIndexComparator.createStringPositionsMap(distinct.toArray(new Double[0]));
        }

        this.periods = new String[this.periodPositions.size()];
        for (Map.Entry<String, Integer> entry : this.periodPositions.entrySet()) {
            this.periods[entry.getValue().intValue()] = entry.getKey();
        }
    }

    /**
     * Resolves the row position of the period an example belongs to.
     *
     * @param example         example to read the period from
     * @param periodAttribute the period attribute
     * @param dateFormat      formatter used when the period column is a date
     * @return row position of the example's period
     * @throws OperatorException if the period key is not part of the period index
     */
    private int periodPosition(Example example, Attribute periodAttribute, SimpleDateFormat dateFormat)
            throws OperatorException {
        String key;
        if (this.periodIsNominal) {
            key = example.getNominalValue(periodAttribute);
        } else if (this.periodIsDate) {
            key = dateFormat.format(example.getDateValue(periodAttribute));
        } else {
            key = Double.toString(example.getValue(periodAttribute));
        }
        Integer position = this.periodPositions.get(key);
        if (position == null) {
            throw new OperatorException("Period value '" + key + "' has no position in the period index.");
        }
        return position.intValue();
    }

    /**
     * Builds the {@link Backtest} input from the example set.
     *
     * <p>Requires {@link #computePeriods(ExampleSet)} to have run on the same
     * example set first, because the period index is taken from instance state.
     *
     * @param exampleSet input example set
     * @return a backtest over the signal and return matrices, not yet executed
     * @throws OperatorException if an attribute or parameter cannot be
     *                           resolved, or a period is missing from the index
     */
    private Backtest createBacktestFromExampleSet(ExampleSet exampleSet) throws OperatorException {
        Attribute securityAttribute = requireAttribute(exampleSet, PARAMETER_SECURITY_ID);
        Attribute periodAttribute = requireAttribute(exampleSet, PARAMETER_PERIOD_COL);
        Attribute signalAttribute = requireAttribute(exampleSet, PARAMETER_SIGNAL_COL);
        Attribute returnAttribute = requireAttribute(exampleSet, PARAMETER_RETURN_COL);

        // Distinct securities; each one gets a column in iteration order of the set.
        HashSet<String> securityIds = new HashSet<String>();
        for (Example example : exampleSet) {
            securityIds.add(example.getNominalValue(securityAttribute));
        }
        Map<String, Integer> securityPositions = new HashMap<String, Integer>();
        int nextColumn = 0;
        for (String securityId : securityIds) {
            securityPositions.put(securityId, Integer.valueOf(nextColumn));
            nextColumn++;
        }

        // Rows are periods, columns are securities; NaN marks "no observation".
        int periodCount = this.periodPositions.size();
        int securityCount = securityPositions.size();
        DoubleFactory2D factory = DoubleFactory2D.dense;
        DoubleMatrix2D signals = factory.make(periodCount, securityCount, Double.NaN);
        DoubleMatrix2D returns = factory.make(periodCount, securityCount, Double.NaN);

        SimpleDateFormat dateFormat = new SimpleDateFormat(PERIOD_DATE_PATTERN);
        for (Example example : exampleSet) {
            int row = periodPosition(example, periodAttribute, dateFormat);
            int column = securityPositions.get(example.getNominalValue(securityAttribute)).intValue();
            signals.set(row, column, example.getNumericalValue(signalAttribute));
            returns.set(row, column, example.getNumericalValue(returnAttribute));
        }

        String[] securityNames = new String[securityCount];
        for (Map.Entry<String, Integer> entry : securityPositions.entrySet()) {
            securityNames[entry.getValue().intValue()] = entry.getKey();
        }
        return new Backtest(signals, returns, getParameterAsInt(PARAMETER_NUM_BUCKETS), securityNames, this.periods);
    }

    /**
     * Runs the backtest on the given example set.
     *
     * @param exampleSet input with security id, period, signal and return columns
     * @return a new example set holding the per-period backtest results
     * @throws OperatorException if parameters or attributes cannot be resolved
     */
    public ExampleSet apply(ExampleSet exampleSet) throws OperatorException {
        computePeriods(exampleSet);
        Backtest backtest = createBacktestFromExampleSet(exampleSet);
        backtest.performBacktest();
        return convertBacktestToExampleSet(backtest);
    }

    /** Creates an attribute-selection parameter bound to the input port with the given default. */
    private ParameterTypeAttribute attributeParameter(String key, String description, String defaultValue) {
        ParameterTypeAttribute type = new ParameterTypeAttribute(key, description, getExampleSetInputPort(), false, false);
        type.setDefaultValue(defaultValue);
        return type;
    }

    /**
     * Returns the inherited parameters followed by this operator's own:
     * the four column selectors and the number of buckets (default 5, minimum 1).
     */
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(attributeParameter(PARAMETER_SECURITY_ID, "The column containing the security identifiers.", "Security ID"));
        parameterTypes.add(attributeParameter(PARAMETER_SIGNAL_COL, "The column containing the expected return.", PARAMETER_SIGNAL_COL));
        parameterTypes.add(attributeParameter(PARAMETER_PERIOD_COL, "The column containing the period (Date).", PARAMETER_PERIOD_COL));
        parameterTypes.add(attributeParameter(PARAMETER_RETURN_COL, "The column containing the forward returns.", "Returns"));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_NUM_BUCKETS, "The number of groups (buckets, phi-quantiles) to divide the securities into each period.", 1, Integer.MAX_VALUE, 5, false));
        return parameterTypes;
    }
}
