package com.rapidminer.operator.RatingPrediction;

import com.rapidminer.matrixUtils.MatrixUtils;
import com.rapidminer.matrixUtils.VectorUtils;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/RatingPrediction/BiasedMatrixFactorization.class */
public class BiasedMatrixFactorization extends MatrixFactorization {
    static final long serialVersionUID = 3453434;

    /** Regularization constant applied to the user factor vectors. */
    public double RegU;
    /** Regularization constant applied to the item factor vectors. */
    public double RegI;
    /** When true, the learning rate is adapted after every iteration ("bold driver"). */
    public boolean BoldDriver;
    /** Per-user bias terms, indexed by user ID. */
    protected double[] user_bias;
    /** Per-item bias terms, indexed by item ID. */
    protected double[] item_bias;
    /** Loss measured after the previous iteration; only maintained when {@link #BoldDriver} is set. */
    double last_loss = Double.NEGATIVE_INFINITY;
    /** Regularization constant applied to the user and item bias terms. */
    public double BiasReg = 1.0E-4d;

    /**
     * Sets the shared regularization constant and copies it into both {@link #RegU} and {@link #RegI}.
     *
     * @param d the regularization value
     */
    public void SetRegularization(double d) {
        this.Regularization = d;
        this.RegU = d;
        this.RegI = d;
    }

    /**
     * Initializes the factor matrices (via the superclass) and allocates zeroed bias vectors sized
     * {@code MaxUserID + 1} and {@code MaxItemID + 1}. If {@link #BoldDriver} is enabled, the initial
     * loss is recorded as the baseline for learning-rate adaptation.
     */
    @Override
    public void InitModel() {
        super.InitModel();
        // Newly allocated Java arrays are zero-filled, so every bias starts at 0.0.
        this.user_bias = new double[this.MaxUserID + 1];
        this.item_bias = new double[this.MaxItemID + 1];
        if (this.BoldDriver) {
            this.last_loss = ComputeLoss();
        }
    }

    /**
     * Trains the model from scratch: initializes it, stores the average rating as the fallback
     * prediction ({@code global_bias}), then runs {@code NumIter} iterations.
     */
    @Override
    public void Train() {
        InitModel();
        this.global_bias = this.ratings.Average();
        for (int iter = 0; iter < this.NumIter; iter++) {
            Iterate();
        }
    }

    /**
     * Runs one training iteration (delegated to the superclass). With {@link #BoldDriver} enabled the
     * learning rate is halved when the loss grew and multiplied by 1.05 when it shrank.
     */
    @Override
    public void Iterate() {
        super.Iterate();
        if (!this.BoldDriver) {
            return;
        }
        double loss = ComputeLoss();
        if (loss > this.last_loss) {
            this.LearnRate *= 0.5d;
        } else if (loss < this.last_loss) {
            this.LearnRate *= 1.05d;
        }
        this.last_loss = loss;
    }

    /**
     * Performs one stochastic-gradient pass over the given rating indices.
     *
     * @param ratingIndices indices into {@code ratings}, visited in list order
     * @param updateUser    whether user biases and user factors are updated
     * @param updateItem    whether item biases and item factors are updated
     */
    @Override
    protected void Iterate(List<Integer> ratingIndices, boolean updateUser, boolean updateItem) {
        double minRating = GetMinRating();
        double ratingRange = GetMaxRating() - GetMinRating();
        for (int idx : ratingIndices) {
            int userId = this.ratings.GetUsers().get(idx).intValue();
            int itemId = this.ratings.GetItems().get(idx).intValue();

            double sig = squashedScore(userId, itemId);
            double predicted = minRating + sig * ratingRange;
            // Prediction error scaled by d(prediction)/d(score) = sig * (1 - sig) * range.
            double gradient = (this.ratings.GetValues(idx) - predicted) * sig * (1.0d - sig) * ratingRange;

            if (updateUser) {
                this.user_bias[userId] += this.LearnRate * (gradient - this.BiasReg * this.user_bias[userId]);
            }
            if (updateItem) {
                this.item_bias[itemId] += this.LearnRate * (gradient - this.BiasReg * this.item_bias[itemId]);
            }

            for (int f = 0; f < this.NumFactors; f++) {
                // Read both factors before either is updated so each update uses the old values.
                double userFactor = this.user_factors.getLocation(userId, f);
                double itemFactor = this.item_factors.getLocation(itemId, f);
                if (updateUser) {
                    MatrixUtils.Inc(this.user_factors, userId, f,
                            this.LearnRate * (gradient * itemFactor - this.RegU * userFactor));
                }
                if (updateItem) {
                    MatrixUtils.Inc(this.item_factors, itemId, f,
                            this.LearnRate * (gradient * userFactor - this.RegI * itemFactor));
                }
            }
        }
    }

    /**
     * Predicts a rating by mapping the biased factor score through a logistic sigmoid into the range
     * [{@code GetMinRating()}, {@code GetMaxRating()}].
     *
     * @param userId user ID
     * @param itemId item ID
     * @return the predicted rating, or {@code global_bias} if either ID is negative or lies beyond
     *         the factor matrices
     */
    @Override
    public double Predict(int userId, int itemId) {
        if (userId < 0 || itemId < 0
                || userId >= this.user_factors.dim1 || itemId >= this.item_factors.dim1) {
            return this.global_bias;
        }
        return GetMinRating() + squashedScore(userId, itemId) * (GetMaxRating() - GetMinRating());
    }

    /**
     * Logistic sigmoid of (user bias + item bias + dot product of the user and item factor rows).
     */
    private double squashedScore(int userId, int itemId) {
        double score = this.user_bias[userId] + this.item_bias[itemId]
                + MatrixUtils.RowScalarProduct(this.user_factors, userId, this.item_factors, itemId);
        return 1.0d / (1.0d + Math.exp(-score));
    }

    /** Not implemented: this method intentionally does nothing. */
    @Override
    public void SaveModel(String str) {
    }

    /** Not implemented: this method intentionally does nothing. */
    @Override
    public void LoadModel(String str) {
    }

    /**
     * Registers a user and makes sure {@link #user_bias} has a slot for {@code i}. The array is only
     * ever grown; new slots are zero.
     */
    @Override
    public void AddUser(int i) {
        super.AddUser(i);
        this.user_bias = ensureBiasCapacity(this.user_bias, i);
    }

    /**
     * Registers users and makes sure {@link #user_bias} has a slot for the largest ID in {@code list}
     * (the list need not be sorted and may be empty).
     */
    @Override
    public void AddUsers(List<Integer> list) {
        super.AddUsers(list);
        this.user_bias = ensureBiasCapacity(this.user_bias, largestId(list));
    }

    /**
     * Registers items and makes sure {@link #item_bias} has a slot for the largest ID in {@code list}
     * (the list need not be sorted and may be empty).
     */
    @Override
    public void AddItems(List<Integer> list) {
        super.AddItems(list);
        this.item_bias = ensureBiasCapacity(this.item_bias, largestId(list));
    }

    /**
     * Registers an item and makes sure {@link #item_bias} has a slot for {@code i}. The array is only
     * ever grown; new slots are zero.
     */
    @Override
    public void AddItem(int i) {
        super.AddItem(i);
        this.item_bias = ensureBiasCapacity(this.item_bias, i);
    }

    /**
     * Returns {@code bias} unchanged if it already has an index {@code maxId}; otherwise returns a
     * zero-padded copy with length {@code maxId + 1}. Never shrinks the array.
     */
    private static double[] ensureBiasCapacity(double[] bias, int maxId) {
        if (bias != null && maxId < bias.length) {
            return bias;
        }
        double[] grown = new double[Math.max(maxId + 1, 0)];
        if (bias != null) {
            System.arraycopy(bias, 0, grown, 0, bias.length);
        }
        return grown;
    }

    /** Largest ID in {@code ids}, or -1 when the list is null or empty. */
    private static int largestId(List<Integer> ids) {
        int max = -1;
        if (ids != null) {
            for (int id : ids) {
                if (id > max) {
                    max = id;
                }
            }
        }
        return max;
    }

    /** Resets the user's bias to zero, then delegates factor retraining to the superclass. */
    @Override
    public void RetrainUser(int i) {
        this.user_bias[i] = 0.0d;
        super.RetrainUser(i);
    }

    /** Resets the item's bias to zero, then delegates factor retraining to the superclass. */
    @Override
    public void RetrainItem(int i) {
        this.item_bias[i] = 0.0d;
        super.RetrainItem(i);
    }

    /** Calls {@link #RetrainUser(int)} for every user ID in the list, in order. */
    @Override
    public void RetrainUsers(List<Integer> list) {
        for (int userId : list) {
            RetrainUser(userId);
        }
    }

    /** Calls {@link #RetrainItem(int)} for every item ID in the list, in order. */
    @Override
    public void RetrainItems(List<Integer> list) {
        for (int itemId : list) {
            RetrainItem(itemId);
        }
    }

    /**
     * Adds ratings via the superclass; a null user list is ignored.
     *
     * @return always 1 (NOTE(review): meaning of the return code is not visible here — confirm with
     *         the superclass contract)
     */
    @Override
    public int AddRatings(List<Integer> list, List<Integer> list2, List<Double> list3) {
        if (list == null) {
            return 1;
        }
        super.AddRatings(list, list2, list3);
        return 1;
    }

    /** Removes the user via the superclass, then zeroes its bias. */
    @Override
    public void RemoveUser(int i) {
        super.RemoveUser(i);
        this.user_bias[i] = 0.0d;
    }

    /** Removes the item via the superclass, then zeroes its bias. */
    @Override
    public void RemoveItem(int i) {
        super.RemoveItem(i);
        this.item_bias[i] = 0.0d;
    }

    /**
     * Training loss: the sum of squared prediction errors over all training ratings, plus
     * regularization of each user/item factor row (squared Euclidean norm) and bias (squared). Each
     * regularization term is weighted by that user's or item's rating count.
     */
    @Override
    public double ComputeLoss() {
        double squaredError = 0.0d;
        for (int r = 0; r < this.ratings.Count(); r++) {
            double err = Predict(this.ratings.GetUsers().get(r).intValue(), this.ratings.GetItems().get(r).intValue())
                    - this.ratings.GetValues(r);
            squaredError += Math.pow(err, 2.0d);
        }

        double penalty = 0.0d;
        for (int u = 0; u <= this.MaxUserID; u++) {
            penalty += this.ratings.CountByUser()[u] * this.RegU
                    * Math.pow(VectorUtils.EuclideanNorm(this.user_factors.GetRow(u)), 2.0d);
            penalty += this.ratings.CountByUser()[u] * this.BiasReg * Math.pow(this.user_bias[u], 2.0d);
        }
        for (int it = 0; it <= this.MaxItemID; it++) {
            penalty += this.ratings.CountByItem()[it] * this.RegI
                    * Math.pow(VectorUtils.EuclideanNorm(this.item_factors.GetRow(it)), 2.0d);
            penalty += this.ratings.CountByItem()[it] * this.BiasReg * Math.pow(this.item_bias[it], 2.0d);
        }
        return squaredError + penalty;
    }

    /**
     * Human-readable description of the hyperparameters. The original used C#-style {@code {0}}
     * placeholders, which Java's {@code String.format} does not substitute; this uses {@code %}
     * conversions instead.
     */
    @Override
    public String ToString() {
        return String.format(
                "BiasedMatrixFactorization num_factors=%d bias_reg=%s reg_u=%s reg_i=%s learn_rate=%s num_iter=%d bold_driver=%s init_mean=%s init_stdev=%s",
                Integer.valueOf(this.NumFactors), Double.valueOf(this.BiasReg), Double.valueOf(this.RegU),
                Double.valueOf(this.RegI), Double.valueOf(this.LearnRate), Integer.valueOf(this.NumIter),
                Boolean.valueOf(this.BoldDriver), Double.valueOf(this.InitMean), Double.valueOf(this.InitStdev));
    }
}
