package de.tu_dortmund.sfb876.optimplugin.regularizers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/tu_dortmund/sfb876/optimplugin/regularizers/GroupL1.class */
/**
 * Group-L1 regularizer: the cost is {@code lambda * sum over groups g of ||w_g||_2}, where each
 * group is a list of coefficient indices into the weight vector {@code w}.
 *
 * <p>Coefficients whose index appears in no group contribute nothing to the cost and receive a
 * zero gradient entry.
 */
public class GroupL1 implements Regularizer {
    private static final Logger LOG = LoggerFactory.getLogger(GroupL1.class);

    /** Regularization strength multiplying the summed group norms. */
    private final double lambda;

    /**
     * Group id to the coefficient indices belonging to that group. Held by reference (not
     * copied), so later changes made by the caller are visible here.
     */
    private final HashMap<Integer, ArrayList<Integer>> groupMap;

    /**
     * Creates the regularizer.
     *
     * @param d regularization strength {@code lambda}
     * @param hashMap mapping from group id to the indices of the coefficients in that group
     */
    public GroupL1(double d, HashMap<Integer, ArrayList<Integer>> hashMap) {
        this.lambda = d;
        this.groupMap = hashMap;
    }

    /**
     * Returns the regularization cost {@code lambda * sum over groups of ||w_g||_2}.
     *
     * @param realVector the weight vector {@code w}
     * @return the penalty value
     */
    @Override
    public double addtoCost(RealVector realVector) {
        double normSum = 0.0d;
        for (ArrayList<Integer> group : groupMap.values()) {
            normSum += getGroupL2Norm(group, realVector);
        }
        return lambda * normSum;
    }

    /**
     * Returns the (sub)gradient of {@link #addtoCost(RealVector)}.
     *
     * <p>Component {@code i} is the sum, over every group {@code g} containing {@code i}, of
     * {@code lambda * w_i / ||w_g||_2}. Accumulating (rather than overwriting) keeps the gradient
     * consistent with the cost when groups overlap or list an index more than once. A group whose
     * norm is exactly zero contributes 0, which is a valid element of the subdifferential there.
     *
     * @param realVector the weight vector {@code w}
     * @return a new vector of the same dimension holding the gradient
     */
    @Override
    public RealVector getGradient(RealVector realVector) {
        ArrayRealVector gradient = new ArrayRealVector(realVector.getDimension());
        for (ArrayList<Integer> group : groupMap.values()) {
            double norm = getGroupL2Norm(group, realVector);
            if (norm == 0.0d) {
                // Non-differentiable point: pick the zero subgradient (entries are already 0).
                continue;
            }
            for (Integer index : group) {
                int i = index.intValue();
                gradient.addToEntry(i, (realVector.getEntry(i) * lambda) / norm);
            }
        }
        return gradient;
    }

    /**
     * Returns the Euclidean norm of the entries of {@code realVector} selected by the given
     * indices. An empty group yields 0. A duplicated index is counted once per occurrence.
     *
     * @param arrayList indices of the coefficients in the group
     * @param realVector the weight vector
     * @return {@code sqrt(sum of squared selected entries)}
     */
    public double getGroupL2Norm(ArrayList<Integer> arrayList, RealVector realVector) {
        double sumOfSquares = 0.0d;
        for (Integer index : arrayList) {
            double value = realVector.getEntry(index.intValue());
            sumOfSquares += value * value;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Logs {@code lambda} and, for each group in ascending id order, its members in ascending
     * order as a comma-separated list. Does not modify the stored group lists.
     */
    @Override
    public void logConfiguration() {
        LOG.info("lambda: {}", Double.valueOf(lambda));
        LOG.info("Group configuration");
        ArrayList<Integer> groupIds = new ArrayList<>(groupMap.keySet());
        Collections.sort(groupIds);
        for (Integer groupId : groupIds) {
            // Sort a copy so that logging never reorders the caller's member lists.
            ArrayList<Integer> members = new ArrayList<>(groupMap.get(groupId));
            Collections.sort(members);
            StringBuilder joined = new StringBuilder();
            for (Integer member : members) {
                if (joined.length() > 0) {
                    joined.append(',');
                }
                joined.append(member);
            }
            // An empty group logs an empty member list instead of throwing.
            LOG.info("Members for the Group-{}: {}", groupId, joined.toString());
        }
    }
}
