package com.rapidminer.operator.RatingPrediction;

import com.rapidminer.RatingPrediction.IIterativeModel;
import com.rapidminer.data.Matrix;
import com.rapidminer.eval.RatingEval;
import com.rapidminer.matrixUtils.MatrixUtils;
import com.rapidminer.matrixUtils.VectorUtils;
import com.rapidminer.operator.Annotations;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.ProcessingStep;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.LoggingHandler;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/RatingPrediction/MatrixFactorization.class */
/**
 * Rating predictor based on plain matrix factorization.
 *
 * <p>The rating of a (user, item) pair is predicted as {@link #global_bias} plus the dot product
 * of the user's row in {@link #user_factors} and the item's row in {@link #item_factors}. The
 * factor rows are learned with regularized stochastic gradient descent; see
 * {@link #Iterate(List, boolean, boolean)}.
 *
 * <p>The public fields are the hyper-parameters. They are read while training, so set them
 * before calling {@link #Train()}.
 */
public class MatrixFactorization extends RatingPredictor implements IIterativeModel {
    static final long serialVersionUID = 3453434;

    /** Latent user factors: one row per user ID, {@link #NumFactors} columns. */
    protected Matrix user_factors;
    /** Latent item factors: one row per item ID, {@link #NumFactors} columns. */
    protected Matrix item_factors;
    /** Offset added to every prediction; {@link #Train()} sets it to the average training rating. */
    protected double global_bias;
    /** Mean passed to {@code MatrixUtils.RowInitNormal} when (re)initialising factor rows. */
    public double InitMean;
    private transient LoggingHandler loggingHandler;
    private String source = null;
    private transient LinkedList<ProcessingStep> processingHistory = new LinkedList<>();
    /** Weight of the L2 penalty applied to the factors in every gradient step. */
    public double Regularization = 0.015d;
    /** Step size of every gradient update. */
    public double LearnRate = 0.01d;
    /** Number of passes over the ratings performed by a full training or a user/item retrain. */
    public int NumIter = 30;
    /** Standard deviation passed to {@code MatrixUtils.RowInitNormal} when initialising factor rows. */
    public double InitStdev = 0.1d;
    /** Number of latent factors, i.e. the column count of both factor matrices. */
    public int NumFactors = 10;

    /** @return the number of training passes ({@link #NumIter}) */
    @Override // com.rapidminer.RatingPrediction.IIterativeModel
    public int GetNumIter() {
        return this.NumIter;
    }

    /**
     * Allocates both factor matrices and initialises them.
     *
     * <p>The matrices are sized to the largest user and item ID of the current ratings, plus one.
     * They are filled via {@code MatrixUtils.RowInitNormal} using {@link #InitMean} and
     * {@link #InitStdev}. The decompiler notes that this method's access was widened from
     * {@code protected}.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void InitModel() {
        super.InitModel();
        this.user_factors = new Matrix(GetRatings().GetMaxUserID() + 1, this.NumFactors);
        this.item_factors = new Matrix(GetRatings().GetMaxItemID() + 1, this.NumFactors);
        MatrixUtils.RowInitNormal(this.user_factors, this.InitMean, this.InitStdev);
        MatrixUtils.RowInitNormal(this.item_factors, this.InitMean, this.InitStdev);
    }

    /**
     * Trains the model from scratch.
     *
     * <p>This re-initialises the factors, sets {@link #global_bias} to the average rating, and
     * runs {@link #NumIter} gradient passes that update both user and item factors.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRecommender
    public void Train() {
        InitModel();
        this.global_bias = GetRatings().Average();
        LearnFactors(GetRatings().RandomIndex(), true, true);
    }

    /**
     * Runs one additional gradient pass over all ratings, updating user and item factors.
     *
     * <p>The pass follows the order of {@code GetRatings().RandomIndex()}, presumably a shuffled
     * permutation of the rating indices. TODO confirm.
     */
    public void Iterate() {
        Iterate(GetRatings().RandomIndex(), true, true);
    }

    /**
     * Re-initialises and retrains the factor row of one user.
     *
     * <p>Only that user's ratings are used, and item factors are left untouched. This does
     * nothing unless {@code UpdateUsers} is set.
     *
     * @param i user ID
     */
    public void RetrainUser(int i) {
        if (this.UpdateUsers) {
            MatrixUtils.RowInitNormal(this.user_factors, this.InitMean, this.InitStdev, i);
            LearnFactors(GetRatings().ByUser().get(i), true, false);
        }
    }

    /**
     * Re-initialises and retrains the factor row of one item.
     *
     * <p>Only that item's ratings are used, and user factors are left untouched. This does
     * nothing unless {@code UpdateItems} is set.
     *
     * @param i item ID
     */
    public void RetrainItem(int i) {
        if (this.UpdateItems) {
            MatrixUtils.RowInitNormal(this.item_factors, this.InitMean, this.InitStdev, i);
            LearnFactors(GetRatings().ByItem().get(i), false, true);
        }
    }

    /**
     * Adds a batch of ratings via the superclass.
     *
     * <p>A {@code null} user list is ignored. No factors are retrained here.
     *
     * @return always {@code 1}
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public int AddRatings(List<Integer> list, List<Integer> list2, List<Double> list3) {
        if (list != null) {
            super.AddRatings(list, list2, list3);
        }
        return 1;
    }

    /** Calls {@link #RetrainUser(int)} for every user ID in {@code list}, in list order. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void RetrainUsers(List<Integer> list) {
        for (int userId : list) {
            RetrainUser(userId);
        }
    }

    /** Calls {@link #RetrainItem(int)} for every item ID in {@code list}, in list order. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void RetrainItems(List<Integer> list) {
        for (int itemId : list) {
            RetrainItem(itemId);
        }
    }

    /**
     * Performs one stochastic-gradient pass over the given ratings.
     *
     * <p>For every rating the error {@code e = r - predict(u, i)} is computed. Each factor column
     * then receives the updates {@code u_f += LearnRate * (e * i_f - Regularization * u_f)} and
     * {@code i_f += LearnRate * (e * u_f - Regularization * i_f)}, both computed from the
     * pre-update values.
     *
     * @param list indices into the rating store, visited in list order
     * @param z whether user factors are updated
     * @param z2 whether item factors are updated
     */
    protected void Iterate(List<Integer> list, boolean z, boolean z2) {
        for (int ratingIndex : list) {
            int user = this.ratings.GetUsers().get(ratingIndex).intValue();
            int item = this.ratings.GetItems().get(ratingIndex).intValue();
            // The error term uses the unclamped prediction.
            double err = this.ratings.GetValues(ratingIndex) - Predict(user, item, false);
            for (int f = 0; f < this.NumFactors; f++) {
                // Read both values first so neither update sees the other's result.
                double userFactor = this.user_factors.getLocation(user, f);
                double itemFactor = this.item_factors.getLocation(item, f);
                double userGradient = (err * itemFactor) - (this.Regularization * userFactor);
                double itemGradient = (err * userFactor) - (this.Regularization * itemFactor);
                if (z) {
                    MatrixUtils.Inc(this.user_factors, user, f, this.LearnRate * userGradient);
                }
                if (z2) {
                    MatrixUtils.Inc(this.item_factors, item, f, this.LearnRate * itemGradient);
                }
            }
        }
    }

    /** Runs {@link #NumIter} passes of {@link #Iterate(List, boolean, boolean)} over {@code list}. */
    private void LearnFactors(List<Integer> list, boolean z, boolean z2) {
        for (int pass = 0; pass < this.NumIter; pass++) {
            Iterate(list, z, z2);
        }
    }

    /**
     * Computes the raw model score for a known user and item.
     *
     * @param user user ID; must be a valid row of {@link #user_factors}
     * @param item item ID; must be a valid row of {@link #item_factors}
     * @param clamp when {@code true}, the score is clamped to [GetMinRating(), GetMaxRating()]
     * @return {@code global_bias + <user row, item row>}, optionally clamped
     */
    protected double Predict(int user, int item, boolean clamp) {
        double score = this.global_bias + MatrixUtils.RowScalarProduct(this.user_factors, user, this.item_factors, item);
        if (clamp) {
            if (score > GetMaxRating()) {
                return GetMaxRating();
            }
            if (score < GetMinRating()) {
                return GetMinRating();
            }
        }
        return score;
    }

    /**
     * Predicts a clamped rating.
     *
     * <p>Falls back to {@link #global_bias} when either ID is outside the factor matrices,
     * including negative IDs.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRecommender
    public double Predict(int i, int i2) {
        boolean knownUser = i >= 0 && i < this.user_factors.dim1;
        boolean knownItem = i2 >= 0 && i2 < this.item_factors.dim1;
        if (knownUser && knownItem) {
            return Predict(i, i2, true);
        }
        return this.global_bias;
    }

    /** Delegates to the superclass; no factors are retrained here. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRatingPredictor
    public void AddRating(int i, int i2, double d) {
        super.AddRating(i, i2, d);
    }

    /** Updates a rating via the superclass, then retrains the affected user and item. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRatingPredictor
    public void UpdateRating(int i, int i2, double d) {
        super.UpdateRating(i, i2, d);
        RetrainUser(i);
        RetrainItem(i2);
    }

    /** Removes a rating via the superclass, then retrains the affected user and item. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRatingPredictor
    public void RemoveRating(int i, int i2) {
        super.RemoveRating(i, i2);
        RetrainUser(i);
        RetrainItem(i2);
    }

    /**
     * Registers a user ID and grows {@link #user_factors} via {@code AddRows(i + 1)}.
     *
     * <p>The decompiler notes that access was widened from {@code protected}.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void AddUser(int i) {
        super.AddUser(i);
        this.user_factors.AddRows(i + 1);
    }

    /**
     * Registers an item ID and grows {@link #item_factors} via {@code AddRows(i + 1)}.
     *
     * <p>The decompiler notes that access was widened from {@code protected}.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void AddItem(int i) {
        super.AddItem(i);
        this.item_factors.AddRows(i + 1);
    }

    /**
     * Registers user IDs and grows {@link #user_factors} to cover the largest one.
     *
     * <p>The list does not need to be sorted. An empty or {@code null} list does not touch the
     * factor matrix.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void AddUsers(List<Integer> list) {
        super.AddUsers(list);
        if (list != null && !list.isEmpty()) {
            this.user_factors.AddRows(maxId(list) + 1);
        }
    }

    /**
     * Registers item IDs and grows {@link #item_factors} to cover the largest one.
     *
     * <p>The list does not need to be sorted. An empty or {@code null} list does not touch the
     * factor matrix.
     */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor
    public void AddItems(List<Integer> list) {
        super.AddItems(list);
        if (list != null && !list.isEmpty()) {
            this.item_factors.AddRows(maxId(list) + 1);
        }
    }

    /** Returns the largest ID in a non-empty list. */
    private static int maxId(List<Integer> ids) {
        int max = Integer.MIN_VALUE;
        for (int id : ids) {
            max = Math.max(max, id);
        }
        return max;
    }

    /** Removes a user via the superclass and zeroes that user's factor row. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRatingPredictor
    public void RemoveUser(int i) {
        super.RemoveUser(i);
        this.user_factors.SetRowToOneValue(i, 0.0d);
    }

    /** Removes an item via the superclass and zeroes that item's factor row. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRatingPredictor
    public void RemoveItem(int i) {
        super.RemoveItem(i);
        this.item_factors.SetRowToOneValue(i, 0.0d);
    }

    /** Not implemented: intentionally a no-op, nothing is written. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRecommender
    public void SaveModel(String str) {
    }

    /** Not implemented: intentionally a no-op, nothing is read. */
    @Override // com.rapidminer.operator.RatingPrediction.RatingPredictor, com.rapidminer.operator.RatingPrediction.IRecommender
    public void LoadModel(String str) {
    }

    /** @return the "RMSE" entry of {@code RatingEval.Evaluate} on the training ratings */
    @Override // com.rapidminer.RatingPrediction.IIterativeModel
    public double ComputeFit() {
        return RatingEval.Evaluate(this, this.ratings).get("RMSE").doubleValue();
    }

    /**
     * Computes the regularized training loss.
     *
     * <p>The loss is the sum of squared errors of the clamped predictions, plus a penalty for
     * every user ID up to {@code MaxUserID} and every item ID up to {@code MaxItemID}. Each
     * penalty is {@code ratingCount * Regularization * ||factorRow||^2}.
     */
    public double ComputeLoss() {
        double loss = 0.0d;
        for (int r = 0; r < this.ratings.Count(); r++) {
            double prediction = Predict(this.ratings.GetUsers().get(r).intValue(), this.ratings.GetItems().get(r).intValue());
            loss += Math.pow(prediction - this.ratings.GetValues(r), 2.0d);
        }
        for (int u = 0; u <= this.MaxUserID; u++) {
            loss += this.ratings.CountByUser()[u] * this.Regularization * Math.pow(VectorUtils.EuclideanNorm(this.user_factors.GetRow(u)), 2.0d);
        }
        for (int it = 0; it <= this.MaxItemID; it++) {
            loss += this.ratings.CountByItem()[it] * this.Regularization * Math.pow(VectorUtils.EuclideanNorm(this.item_factors.GetRow(it)), 2.0d);
        }
        return loss;
    }

    /**
     * Returns a one-line description of the hyper-parameters.
     *
     * <p>The original code used C#-style {@code {0}} placeholders, which Java's
     * {@link String#format} prints literally. Java format specifiers are used instead.
     */
    public String ToString() {
        return String.format("MatrixFactorization num_factors=%d regularization=%s learn_rate=%s num_iter=%d init_mean=%s init_stdev=%s", Integer.valueOf(this.NumFactors), Double.valueOf(this.Regularization), Double.valueOf(this.LearnRate), Integer.valueOf(this.NumIter), Double.valueOf(this.InitMean), Double.valueOf(this.InitStdev));
    }

    public void setSource(String str) {
        this.source = str;
    }

    public String getSource() {
        return this.source;
    }

    /**
     * Appends a processing step for {@code operator}/{@code outputPort} to the history.
     *
     * <p>Nothing is recorded if the operator is not attached to a process. The step is also
     * skipped if it equals the most recently recorded one.
     */
    public void appendOperatorToHistory(Operator operator, OutputPort outputPort) {
        if (this.processingHistory == null) {
            // The field is transient, so it may be null after deserialization.
            this.processingHistory = new LinkedList<>();
        }
        if (operator.getProcess() == null) {
            return;
        }
        ProcessingStep processingStep = new ProcessingStep(operator, outputPort);
        if (this.processingHistory.isEmpty() || !this.processingHistory.getLast().equals(processingStep)) {
            this.processingHistory.add(processingStep);
        }
    }

    /** @return the live (mutable) processing history, created lazily if absent */
    public List<ProcessingStep> getProcessingHistory() {
        if (this.processingHistory == null) {
            this.processingHistory = new LinkedList<>();
        }
        return this.processingHistory;
    }

    /** @return the handler set via {@link #setLoggingHandler}, or {@code LogService.getGlobal()} */
    public LoggingHandler getLog() {
        return this.loggingHandler != null ? this.loggingHandler : LogService.getGlobal();
    }

    public void setLoggingHandler(LoggingHandler loggingHandler) {
        this.loggingHandler = loggingHandler;
    }

    /**
     * Returns this same instance, not an independent copy.
     *
     * <p>NOTE(review): mutating the "copy" mutates the original.
     */
    public IOObject copy() {
        return this;
    }

    /** Intentionally empty hook. */
    protected void initWriting() {
    }

    /**
     * Returns a new, empty {@link Annotations} on every call.
     *
     * <p>NOTE(review): annotations set on the returned object are not retained by this model.
     */
    public Annotations getAnnotations() {
        return new Annotations();
    }
}
