package com.rapidminer.ItemRecommendation;

import com.rapidminer.RatingPrediction.IIterativeModel;
import com.rapidminer.eval.ItemPrediction;
import com.rapidminer.matrixUtils.MatrixUtils;
import com.rapidminer.matrixUtils.VectorUtils;
import com.rapidminer.operator.Annotations;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.ProcessingStep;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.LoggingHandler;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * Matrix-factorization item recommender trained with pairwise, BPR-style stochastic gradient
 * descent.
 *
 * <p>Each update samples a user {@code u}, an item {@code i} the user has feedback for and an item
 * {@code j} the user has no feedback for. It then moves the model parameters so that
 * {@code Predict(u, i)} grows relative to {@code Predict(u, j)}. A score is the item bias plus
 * the dot product of the user's and the item's factor rows.
 *
 * <p>Fast sampling is enabled when the estimated size of a users-by-items table (4 bytes per
 * entry) fits into {@link #fast_sampling_memory_limit} megabytes. In that mode, per-user lists of
 * positive and negative item ids are precomputed.
 *
 * <p>Originally recovered from the compiled {@code BPRMF.class}.
 */
public class BPRMF extends MF implements IIterativeModel {
    static final long serialVersionUID = 3232342;

    /** Additive per-item bias, indexed by item id. */
    protected double[] item_bias;
    /** Regularization weight applied to the item biases. */
    public double BiasReg;
    /** Fast-sampling cache: per user, ids of items the user has feedback for. */
    protected List<List<Integer>> user_pos_items;
    /** Fast-sampling cache: per user, ids of items the user has no feedback for. */
    protected List<List<Integer>> user_neg_items;
    /** When true, the learning rate is adapted after every iteration from a sampled loss. */
    public boolean BoldDriver;
    // Fixed sample of (user, positive item, negative item) triples evaluated by ComputeLoss().
    int[] loss_sample_u;
    int[] loss_sample_i;
    int[] loss_sample_j;
    protected com.rapidminer.utils.Random random;
    private transient LoggingHandler loggingHandler;
    protected boolean fast_sampling = false;
    /** Memory budget in megabytes for fast sampling; 0 disables fast sampling. */
    protected int fast_sampling_memory_limit = 1024;
    protected double learn_rate = 0.05d;
    protected double reg_u = 0.0025d;
    protected double reg_i = 0.0025d;
    protected double reg_j = 2.5E-4d;
    /** Loss measured after the previous iteration; compared against by the bold-driver rule. */
    double last_loss = Double.NEGATIVE_INFINITY;
    private String source = null;
    private transient LinkedList<ProcessingStep> processingHistory = new LinkedList<>();

    /** Returns the fast-sampling memory budget in megabytes. */
    public int GetFastSamplingMemoryLimit() {
        return this.fast_sampling_memory_limit;
    }

    /** Sets the fast-sampling memory budget in megabytes (0 disables fast sampling). */
    public void SetFastSamplingMemoryLimit(int i) {
        this.fast_sampling_memory_limit = i;
    }

    /** Returns the SGD learning rate. */
    public double GetLearnRate() {
        return this.learn_rate;
    }

    /** Sets the SGD learning rate. */
    public void SetLearnRate(double d) {
        this.learn_rate = d;
    }

    /** Returns the regularization weight for user factors. */
    public double GetRegU() {
        return this.reg_u;
    }

    /** Sets the regularization weight for user factors. */
    public void SetRegU(double d) {
        this.reg_u = d;
    }

    /** Returns the regularization weight for positive-item factors. */
    public double GetRegI() {
        return this.reg_i;
    }

    /** Sets the regularization weight for positive-item factors. */
    public void SetRegI(double d) {
        this.reg_i = d;
    }

    /** Returns the regularization weight for negative-item factors. */
    public double GetRegJ() {
        return this.reg_j;
    }

    /** Sets the regularization weight for negative-item factors. */
    public void SetRegJ(double d) {
        this.reg_j = d;
    }

    /** Initializes the factor matrices (via MF), zeroes the item biases and configures sampling. */
    @Override
    public void InitModel() {
        super.InitModel();
        this.item_bias = new double[this.MaxItemID + 1];
        CheckSampling();
    }

    /**
     * Trains the model from scratch.
     *
     * <p>When {@link #BoldDriver} is set, a fixed sample of triples is drawn first so that the
     * loss can be compared across iterations. Then {@code NumIter} iterations are run.
     */
    @Override
    public void Train() {
        InitModel();
        this.random = com.rapidminer.utils.Random.GetInstance();
        if (this.BoldDriver) {
            // Sample size scales with the square root of the largest user id.
            int numSampleTriples = ((int) Math.sqrt(GetFeedback().GetMaxUserID())) * 100;
            this.loss_sample_u = new int[numSampleTriples];
            this.loss_sample_i = new int[numSampleTriples];
            this.loss_sample_j = new int[numSampleTriples];
            for (int c = 0; c < numSampleTriples; c++) {
                // Keep the sampled triple; storing zeros here would evaluate the loss on a single
                // degenerate (0, 0, 0) triple.
                int[] triple = SampleTriple(0, 0, 0);
                this.loss_sample_u[c] = triple[0];
                this.loss_sample_i[c] = triple[1];
                this.loss_sample_j[c] = triple[2];
            }
            this.last_loss = ComputeLoss();
        }
        for (int iteration = 0; iteration < this.NumIter; iteration++) {
            Iterate();
        }
    }

    /**
     * Runs one training epoch: one triple update per feedback entry.
     *
     * <p>With {@link #BoldDriver}, the learning rate is then halved if the sampled loss grew, or
     * increased by 10% if it shrank.
     */
    @Override
    public void Iterate() {
        int numUpdates = GetFeedback().Count();
        int user = 0;
        int posItem = 0;
        int negItem = 0;
        for (int n = 0; n < numUpdates; n++) {
            int[] triple = SampleTriple(user, posItem, negItem);
            user = triple[0];
            posItem = triple[1];
            negItem = triple[2];
            UpdateFactors(user, posItem, negItem, true, true, true);
        }
        if (this.BoldDriver) {
            double loss = ComputeLoss();
            if (loss > this.last_loss) {
                SetLearnRate(GetLearnRate() * 0.5d);
            } else if (loss < this.last_loss) {
                SetLearnRate(GetLearnRate() * 1.1d);
            }
            this.last_loss = loss;
        }
    }

    /**
     * Reports whether {@code user} has feedback for {@code item}, and draws an item of the
     * opposite status.
     *
     * <p>The drawn id cannot be returned through this signature. Internal callers that need it
     * use {@link #sampleOtherItemId(int, boolean)}.
     *
     * @param user user id
     * @param item item id whose feedback status is checked
     * @param unused ignored (kept for signature compatibility)
     * @return true if {@code user} has feedback for {@code item}
     */
    protected boolean SampleOtherItem(int user, int item, int unused) {
        boolean itemIsPositive = GetFeedback().GetUserMatrix().getLocation(user, item);
        sampleOtherItemId(user, itemIsPositive);
        return itemIsPositive;
    }

    /**
     * Draws a random item whose feedback status for {@code user} is the opposite of
     * {@code itemIsPositive}: a negative item if it is true, a positive item otherwise.
     */
    private int sampleOtherItemId(int user, boolean itemIsPositive) {
        if (this.fast_sampling) {
            List<Integer> candidates =
                    itemIsPositive ? this.user_neg_items.get(user) : this.user_pos_items.get(user);
            return candidates.get(this.random.nextInt(candidates.size())).intValue();
        }
        int other;
        do {
            other = this.random.nextInt(this.MaxItemID + 1);
        } while (GetFeedback().GetUserMatrix().getLocation(user, other) == itemIsPositive);
        return other;
    }

    /**
     * Samples one positive and one negative item for {@code user}.
     *
     * @param user user id; must have at least one positive and one negative item
     * @param unusedPos ignored (kept for signature compatibility)
     * @param unusedNeg ignored (kept for signature compatibility)
     * @return {@code {positiveItem, negativeItem}}
     */
    protected int[] SampleItemPair(int user, int unusedPos, int unusedNeg) {
        int posItem;
        int negItem;
        if (this.fast_sampling) {
            List<Integer> positives = this.user_pos_items.get(user);
            posItem = positives.get(this.random.nextInt(positives.size())).intValue();
            List<Integer> negatives = this.user_neg_items.get(user);
            negItem = negatives.get(this.random.nextInt(negatives.size())).intValue();
        } else {
            List<Integer> positives = GetFeedback().GetUserMatrix().getLocation(user);
            posItem = positives.get(this.random.nextInt(positives.size())).intValue();
            negItem = sampleOtherItemId(user, true);
        }
        return new int[] {posItem, negItem};
    }

    /**
     * Draws a random user that has at least one positive and at least one negative item.
     *
     * <p>Loops until such a user is found, so at least one must exist.
     */
    protected int SampleUser() {
        int numItems = this.MaxItemID + 1;
        while (true) {
            int user = this.random.nextInt(this.MaxUserID + 1);
            int numPositive = GetFeedback().GetUserMatrix().getLocation(user).size();
            if (numPositive != 0 && numPositive != numItems) {
                return user;
            }
        }
    }

    /**
     * Samples a (user, positive item, negative item) triple.
     *
     * @return {@code {user, positiveItem, negativeItem}}
     */
    protected int[] SampleTriple(int user, int posItem, int negItem) {
        int sampledUser = SampleUser();
        int[] pair = SampleItemPair(sampledUser, posItem, negItem);
        return new int[] {sampledUser, pair[0], pair[1]};
    }

    /**
     * Performs one regularized gradient step on the pairwise objective for triple (u, i, j).
     *
     * @param u user id
     * @param i positive item id
     * @param j negative item id
     * @param updateU whether to update the user factors
     * @param updateI whether to update the bias and factors of item {@code i}
     * @param updateJ whether to update the bias and factors of item {@code j}
     */
    protected void UpdateFactors(int u, int i, int j, boolean updateU, boolean updateI, boolean updateJ) {
        // 1 / (1 + e^x) with x = score(u, i) - score(u, j); this is the derivative of ln(sigmoid(x)).
        double weight = 1.0d / (1.0d + Math.exp(Predict(u, i) - Predict(u, j)));
        if (updateI) {
            double gradient = weight - (this.BiasReg * this.item_bias[i]);
            this.item_bias[i] += this.learn_rate * gradient;
        }
        if (updateJ) {
            double gradient = (-weight) - (this.BiasReg * this.item_bias[j]);
            this.item_bias[j] += this.learn_rate * gradient;
        }
        for (int f = 0; f < this.num_factors; f++) {
            // Read all three values first so every update uses the pre-step parameters.
            double wUf = this.user_factors.getLocation(u, f);
            double hIf = this.item_factors.getLocation(i, f);
            double hJf = this.item_factors.getLocation(j, f);
            if (updateU) {
                this.user_factors.setLocation(u, f, wUf + (this.learn_rate * (((hIf - hJf) * weight) - (this.reg_u * wUf))));
            }
            if (updateI) {
                this.item_factors.setLocation(i, f, hIf + (this.learn_rate * ((wUf * weight) - (this.reg_i * hIf))));
            }
            if (updateJ) {
                this.item_factors.setLocation(j, f, hJf + (this.learn_rate * (((-wUf) * weight) - (this.reg_j * hJf))));
            }
        }
    }

    /** Records new feedback and retrains the affected user and item. */
    @Override
    public void AddFeedback(int user, int item) {
        super.AddFeedback(user, item);
        if (this.fast_sampling) {
            CreateFastSamplingData(user);
        }
        RetrainUser(user);
        RetrainItem(item);
    }

    /** Removes feedback and retrains the affected user and item. */
    @Override
    public void RemoveFeedback(int user, int item) {
        super.RemoveFeedback(user, item);
        if (this.fast_sampling) {
            CreateFastSamplingData(user);
        }
        RetrainUser(user);
        RetrainItem(item);
    }

    /** Intentionally empty: users are added in bulk via {@link #AddUsers(List)}. */
    @Override
    protected void AddUser(int i) {
    }

    /** Registers users, grows the user factor matrix and randomly initializes the new rows. */
    @Override
    public void AddUsers(List<Integer> list) {
        super.AddUsers(list);
        if (list.isEmpty()) {
            return;
        }
        // Size by the largest id, not the last one, so unsorted input is handled.
        this.user_factors.AddRows(maxId(list) + 1);
        for (int user : list) {
            MatrixUtils.RowInitNormal(this.user_factors, this.InitMean, this.InitStdev, user);
            if (this.fast_sampling) {
                CreateFastSamplingData(user);
            }
        }
    }

    /** Intentionally empty: items are added in bulk via {@link #AddItems(List)}. */
    @Override
    public void AddItem(int i) {
    }

    /**
     * Registers items, grows the item factor matrix and the item-bias array, and randomly
     * initializes the new factor rows.
     */
    @Override
    public void AddItems(List<Integer> list) {
        super.AddItems(list);
        if (list.isEmpty()) {
            return;
        }
        int maxItem = maxId(list);
        this.item_factors.AddRows(maxItem + 1);
        // Without this, Predict/UpdateFactors would index past the end of item_bias.
        ensureItemBiasCapacity(Math.max(this.MaxItemID, maxItem) + 1);
        for (int item : list) {
            MatrixUtils.RowInitNormal(this.item_factors, this.InitMean, this.InitStdev, item);
        }
    }

    /** Removes a user, drops its fast-sampling lists and zeroes its factor row. */
    @Override
    public void RemoveUser(int i) {
        super.RemoveUser(i);
        if (this.fast_sampling) {
            if (i < this.user_pos_items.size()) {
                this.user_pos_items.set(i, null);
            }
            if (i < this.user_neg_items.size()) {
                this.user_neg_items.set(i, null);
            }
        }
        this.user_factors.SetRowToOneValue(i, 0.0d);
    }

    /** Removes an item and zeroes its factor row and bias. */
    @Override
    public void RemoveItem(int i) {
        super.RemoveItem(i);
        this.item_factors.SetRowToOneValue(i, 0.0d);
        if (this.item_bias != null && i >= 0 && i < this.item_bias.length) {
            this.item_bias[i] = 0.0d;
        }
    }

    /**
     * Re-initializes a user's factors, then runs one user-only update per positive item.
     *
     * <p>Users with no positive item, or with no negative item, only get re-initialized.
     */
    protected void RetrainUser(int i) {
        MatrixUtils.RowInitNormal(this.user_factors, this.InitMean, this.InitStdev, i);
        int numPositive = GetFeedback().GetUserMatrix().getLocation(i).size();
        if (numPositive >= this.MaxItemID + 1) {
            // No negative item exists, so no pair can be sampled (sampling would never finish).
            return;
        }
        for (int n = 0; n < numPositive; n++) {
            int[] pair = SampleItemPair(i, 0, 0);
            UpdateFactors(i, pair[0], pair[1], true, false, false);
        }
    }

    /** Retrains each listed user with {@link #RetrainUser(int)}. */
    @Override
    public void RetrainUsers(List<Integer> list) {
        for (int user : list) {
            RetrainUser(user);
        }
    }

    /**
     * Re-initializes an item's factors, then runs item-only updates against randomly sampled
     * users.
     *
     * <p>The number of updates equals the number of feedback entries divided by the number of
     * items. For each sampled user, the item is paired with an item of the opposite feedback
     * status.
     */
    protected void RetrainItem(int i) {
        ensureItemBiasCapacity(Math.max(this.MaxItemID, i) + 1);
        MatrixUtils.RowInitNormal(this.item_factors, this.InitMean, this.InitStdev, i);
        int numIterations = GetFeedback().GetUserMatrix().NumberOfEntries() / (this.MaxItemID + 1);
        for (int n = 0; n < numIterations; n++) {
            int user = SampleUser();
            boolean itemIsPositive = GetFeedback().GetUserMatrix().getLocation(user, i);
            int other = sampleOtherItemId(user, itemIsPositive);
            if (itemIsPositive) {
                UpdateFactors(user, i, other, false, true, false);
            } else {
                UpdateFactors(user, other, i, false, false, true);
            }
        }
    }

    /** Grows the item-bias array to cover all items, then retrains each listed item. */
    @Override
    public void RetrainItems(List<Integer> list) {
        ensureItemBiasCapacity(this.MaxItemID + 1);
        for (int item : list) {
            RetrainItem(item);
        }
    }

    /**
     * Computes the training loss on the fixed triple sample drawn in {@link #Train()}.
     *
     * <p>The loss is the sum of {@code 1 / (1 + e^(score(u,i) - score(u,j)))} over the sample,
     * plus half the regularization term (squared factor norms and squared biases). It requires
     * {@link #BoldDriver} to have been set during training; otherwise there is no sample.
     */
    public double ComputeLoss() {
        double rankingLoss = 0.0d;
        double complexity = 0.0d;
        for (int c = 0; c < this.loss_sample_u.length; c++) {
            int u = this.loss_sample_u[c];
            int i = this.loss_sample_i[c];
            int j = this.loss_sample_j[c];
            rankingLoss += 1.0d / (1.0d + Math.exp(Predict(u, i) - Predict(u, j)));
            complexity = complexity
                    + (GetRegU() * Math.pow(VectorUtils.EuclideanNorm(this.user_factors.GetRow(u)), 2.0d))
                    + (GetRegI() * Math.pow(VectorUtils.EuclideanNorm(this.item_factors.GetRow(i)), 2.0d))
                    + (GetRegJ() * Math.pow(VectorUtils.EuclideanNorm(this.item_factors.GetRow(j)), 2.0d))
                    + (this.BiasReg * Math.pow(this.item_bias[i], 2.0d))
                    + (this.BiasReg * Math.pow(this.item_bias[j], 2.0d));
        }
        return rankingLoss + (0.5d * complexity);
    }

    /**
     * Returns the mean per-user AUC on the training data.
     *
     * <p>Users with no feedback, or with feedback for every item, are skipped because their AUC is
     * undefined. Returns NaN if no user qualifies.
     */
    @Override
    public double ComputeFit() {
        double sumAuc = 0.0d;
        int numUsers = 0;
        int numItems = this.MaxItemID + 1;
        for (int user = 0; user < this.MaxUserID + 1; user++) {
            int numPositive = GetFeedback().GetUserMatrix().getLocation(user).size();
            long numEvalPairs = (long) (numItems - numPositive) * numPositive;
            if (numPositive == 0 || numEvalPairs <= 0) {
                continue;
            }
            int[] ranking = ItemPrediction.PredictItems(this, user, this.MaxItemID);
            long numCorrectPairs = 0;
            int numPositiveAbove = 0;
            for (int item : ranking) {
                if (GetFeedback().GetUserMatrix().getLocation(user, item)) {
                    numPositiveAbove++;
                } else {
                    // Every positive ranked above this negative is a correctly ordered pair.
                    numCorrectPairs += numPositiveAbove;
                }
            }
            // Floating-point division: integer division would truncate each AUC to 0 or 1.
            sumAuc += (double) numCorrectPairs / numEvalPairs;
            numUsers++;
        }
        return sumAuc / numUsers;
    }

    /** (Re)builds the cached positive and negative item lists for user {@code i}. */
    private void CreateFastSamplingData(int i) {
        while (i >= this.user_pos_items.size()) {
            this.user_pos_items.add(null);
        }
        while (i >= this.user_neg_items.size()) {
            this.user_neg_items.add(null);
        }
        this.user_pos_items.set(i, new ArrayList<>(GetFeedback().GetUserMatrix().getLocation(i)));
        List<Integer> negatives = new ArrayList<>();
        for (int item = 0; item <= this.MaxItemID; item++) {
            if (!GetFeedback().GetUserMatrix().getLocation(i, item)) {
                negatives.add(Integer.valueOf(item));
            }
        }
        this.user_neg_items.set(i, negatives);
    }

    /**
     * Enables fast sampling, and builds the per-user caches, when a users-by-items table with
     * 4 bytes per entry fits into {@link #fast_sampling_memory_limit} megabytes. A limit of 0
     * disables fast sampling.
     */
    protected void CheckSampling() {
        this.fast_sampling = false;
        if (this.fast_sampling_memory_limit == 0) {
            return;
        }
        // Use long arithmetic: the int product overflows for large catalogs.
        long estimatedMegabytes = ((long) (this.MaxUserID + 1) * (this.MaxItemID + 1) * 4L) / 1048576L;
        if (estimatedMegabytes > this.fast_sampling_memory_limit) {
            return;
        }
        this.fast_sampling = true;
        this.user_pos_items = new ArrayList<>(this.MaxUserID + 1);
        this.user_neg_items = new ArrayList<>(this.MaxUserID + 1);
        for (int user = 0; user < this.MaxUserID + 1; user++) {
            CreateFastSamplingData(user);
        }
    }

    /** Grows {@link #item_bias} (zero-padded) to at least {@code size} entries; never shrinks it. */
    private void ensureItemBiasCapacity(int size) {
        if (this.item_bias == null) {
            this.item_bias = new double[size];
        } else if (this.item_bias.length < size) {
            this.item_bias = Arrays.copyOf(this.item_bias, size);
        }
    }

    /** Returns the largest id in {@code ids}, or -1 if it is empty. */
    private static int maxId(List<Integer> ids) {
        int max = -1;
        for (int id : ids) {
            max = Math.max(max, id);
        }
        return max;
    }

    /** Returns the configured number of training iterations. */
    @Override
    public int GetNumIter() {
        return this.NumIter;
    }

    /**
     * Returns the score of item {@code i2} for user {@code i}: the item bias plus the dot product
     * of the factor rows.
     *
     * <p>Prints a message and returns 0 for ids outside the factor matrices.
     */
    @Override
    public double Predict(int i, int i2) {
        if (i < 0 || i >= this.user_factors.dim1) {
            System.out.println("user is unknown: " + i);
            return 0.0d;
        }
        if (i2 >= 0 && i2 < this.item_factors.dim1) {
            return this.item_bias[i2] + MatrixUtils.RowScalarProduct(this.user_factors, i, this.item_factors, i2);
        }
        System.out.println("item is unknown: " + i2);
        return 0.0d;
    }

    /** Not implemented: does nothing. */
    @Override
    public void SaveModel(String str) {
    }

    /** Not implemented: does nothing. */
    @Override
    public void LoadModel(String str) {
    }

    /** Returns a one-line description of the hyperparameters. */
    public String ToString() {
        // String.format uses %-style specifiers; C#-style {0} placeholders would be printed literally.
        return String.format(
                "BPRMF num_factors=%s bias_reg=%s reg_u=%s reg_i=%s reg_j=%s num_iter=%s learn_rate=%s"
                        + " bold_driver=%s fast_sampling_memory_limit=%s init_mean=%s init_stdev=%s",
                Integer.valueOf(this.num_factors), Double.valueOf(this.BiasReg), Double.valueOf(this.reg_u),
                Double.valueOf(this.reg_i), Double.valueOf(this.reg_j), Integer.valueOf(this.NumIter),
                Double.valueOf(this.learn_rate), Boolean.valueOf(this.BoldDriver),
                Integer.valueOf(this.fast_sampling_memory_limit), Double.valueOf(this.InitMean),
                Double.valueOf(this.InitStdev));
    }

    public void setSource(String str) {
        this.source = str;
    }

    public String getSource() {
        return this.source;
    }

    /**
     * Appends a processing step for {@code operator}/{@code outputPort}, unless the operator has
     * no process or the step equals the most recent one.
     */
    public void appendOperatorToHistory(Operator operator, OutputPort outputPort) {
        if (this.processingHistory == null) {
            // The field is transient, so it is null after deserialization.
            this.processingHistory = new LinkedList<>();
        }
        if (operator.getProcess() == null) {
            return;
        }
        ProcessingStep step = new ProcessingStep(operator, outputPort);
        if (this.processingHistory.isEmpty() || !this.processingHistory.getLast().equals(step)) {
            this.processingHistory.add(step);
        }
    }

    /** Returns the live processing-history list, creating it if necessary. */
    public List<ProcessingStep> getProcessingHistory() {
        if (this.processingHistory == null) {
            this.processingHistory = new LinkedList<>();
        }
        return this.processingHistory;
    }

    /** Returns the handler set via {@link #setLoggingHandler}, or the global log service. */
    public LoggingHandler getLog() {
        return this.loggingHandler != null ? this.loggingHandler : LogService.getGlobal();
    }

    public void setLoggingHandler(LoggingHandler loggingHandler) {
        this.loggingHandler = loggingHandler;
    }

    /** Returns this instance itself; no copy is made. */
    public IOObject copy() {
        return this;
    }

    protected void initWriting() {
    }

    /** Returns a new, empty {@link Annotations} on every call; annotations are not stored. */
    public Annotations getAnnotations() {
        return new Annotations();
    }
}
