package com.rapidminer.operator.RatingPrediction;

import com.rapidminer.RatingPrediction.IIterativeModel;
import com.rapidminer.RatingPrediction.UserItemBaseline;
import com.rapidminer.data.Matrix;
import com.rapidminer.eval.RatingEval;
import com.rapidminer.matrixUtils.MatrixUtils;
import com.rapidminer.operator.Annotations;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.ProcessingStep;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.LoggingHandler;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:com/rapidminer/operator/RatingPrediction/FactorWisedMatrixFactorization.class */
/**
 * Rating predictor that adds a factor-wise learned matrix factorization on top of a
 * {@link UserItemBaseline}.
 *
 * <p>Each call to {@link #Iterate()} fits one additional factor column to the (shrunk) residuals
 * of the current model by alternating closed-form single-variable least-squares updates of the
 * user column and the item column. A prediction is the baseline prediction plus the scalar
 * product of the user's and the item's factor rows, clamped to {@code [GetMinRating(),
 * GetMaxRating()]}.
 */
public class FactorWisedMatrixFactorization extends RatingPredictor implements IIterativeModel {
    static final long serialVersionUID = 3453434;

    /** User factors: row = user ID, column = factor index. */
    Matrix user_factors;
    /** Item factors: row = item ID, column = factor index. */
    Matrix item_factors;
    /** Average training rating; returned by {@link #Predict} for IDs outside the factor matrices. */
    double global_bias;
    /** Length of {@link #residuals} before the most recent {@link #AddRatings} call. */
    int res_old_size;
    /** Number of factor columns fitted so far; the next {@link #Iterate()} fits this column. */
    int num_learned_factors;
    /** Shrunk residual per training rating, indexed like the rating store. */
    double[] residuals;
    /** Mean passed to the MatrixUtils normal initialisers for new factor columns/rows. */
    public double InitMean;
    /** Item list passed to the last {@link #AddRatings} call (stored only). */
    List<Integer> new_items;
    /** User list passed to the last {@link #AddUsers} call (stored only). */
    List<Integer> new_users;
    private transient LoggingHandler loggingHandler;
    /** Baseline model whose predictions the factor model corrects. */
    UserItemBaseline global_effects = new UserItemBaseline();
    private String source = null;
    private transient LinkedList<ProcessingStep> processingHistory = new LinkedList<>();
    /** Residual shrinkage: each residual is scaled by n / (n + Shrinkage). */
    public double Shrinkage = 25.0d;
    /** Maximum number of factor columns to learn. */
    public int NumFactors = 10;
    /** Number of {@link #Iterate()} calls made by {@link #Train()}. */
    public int NumIter = 10;
    /** Relative RMSE improvement below which fitting of the current column stops. */
    public double Sensibility = 1.0E-5d;
    /** Standard deviation passed to the MatrixUtils normal initialisers. */
    public double InitStdev = 0.1d;

    @Override
    public int GetNumIter() {
        return this.NumIter;
    }

    /**
     * Allocates the factor matrices for the current rating data and wires the baseline model to
     * the same ratings and rating range.
     *
     * <p>NOTE: a decompiler note indicated this was originally protected; it stays public so that
     * existing callers keep compiling.
     */
    @Override
    public void InitModel() {
        super.InitModel();
        this.user_factors = new Matrix(GetRatings().GetMaxUserID() + 1, this.NumFactors);
        this.item_factors = new Matrix(GetRatings().GetMaxItemID() + 1, this.NumFactors);
        this.global_effects.SetRatings(GetRatings());
        this.global_effects.SetMinRating(GetMinRating());
        this.global_effects.SetMaxRating(GetMaxRating());
    }

    /**
     * Trains from scratch: initialises the model, trains the baseline, then performs
     * {@link #NumIter} calls to {@link #Iterate()} (each learns at most one factor column).
     */
    @Override
    public void Train() {
        InitModel();
        this.global_effects.Train();
        this.global_bias = GetRatings().Average();
        this.residuals = new double[GetRatings().Count()];
        this.num_learned_factors = 0;
        for (int iteration = 0; iteration < this.NumIter; iteration++) {
            Iterate();
        }
    }

    /**
     * Learns the next factor column, or does nothing once {@link #NumFactors} columns exist.
     *
     * <p>Residuals of the current model are recomputed and shrunk. The new column is then
     * initialised randomly and refined by alternating user/item least-squares sweeps for as long
     * as each sweep lowers {@link #ComputeFit()} by more than the relative {@link #Sensibility}.
     */
    @Override
    public void Iterate() {
        if (this.num_learned_factors >= this.NumFactors) {
            return;
        }
        final int factor = this.num_learned_factors;

        for (int index = 0; index < GetRatings().Count(); index++) {
            this.residuals[index] = shrunkResidual(index);
        }

        MatrixUtils.ColumnInitNormal(this.user_factors, this.InitMean, this.InitStdev, factor);
        MatrixUtils.ColumnInitNormal(this.item_factors, this.InitMean, this.InitStdev, factor);

        // Sentinel ratio 0.5 forces at least one sweep (unless Sensibility >= 0.5).
        double fit = Double.MAX_VALUE / 2;
        double previousFit = Double.MAX_VALUE;
        while (fit / previousFit < 1.0d - this.Sensibility) {
            solveUserFactorColumn(factor);
            solveItemFactorColumn(factor);
            previousFit = fit;
            fit = ComputeFit();
        }
        this.num_learned_factors++;
    }

    /**
     * Returns the residual of rating {@code index} under the current model, shrunk by
     * n / (n + Shrinkage). Here n is the smaller of the sizes of the user's {@code ByUser()} entry
     * and the item's {@code ByItem()} entry (presumably their rating counts).
     */
    private double shrunkResidual(int index) {
        int user = GetRatings().GetUsers().get(index).intValue();
        int item = GetRatings().GetItems().get(index).intValue();
        double residual = GetRatings().GetValues(index) - Predict(user, item);
        int support = Math.min(GetRatings().ByUser().get(user).size(), GetRatings().ByItem().get(item).size());
        return residual * (support / (support + this.Shrinkage));
    }

    /** Sets each user's entry in column {@code factor} to its least-squares optimum given the item column. */
    private void solveUserFactorColumn(int factor) {
        double[] numerator = new double[this.MaxUserID + 1];
        double[] denominator = new double[this.MaxUserID + 1];
        for (int index = 0; index < GetRatings().Count(); index++) {
            int user = GetRatings().GetUsers().get(index).intValue();
            int item = GetRatings().GetItems().get(index).intValue();
            double itemValue = this.item_factors.getLocation(item, factor);
            numerator[user] += this.residuals[index] * itemValue;
            denominator[user] += itemValue * itemValue;
        }
        for (int user = 0; user <= this.MaxUserID; user++) {
            // A zero numerator (e.g. a user without ratings) leaves the current value untouched.
            if (numerator[user] != 0.0d) {
                this.user_factors.setLocation(user, factor, numerator[user] / denominator[user]);
            }
        }
    }

    /** Sets each item's entry in column {@code factor} to its least-squares optimum given the user column. */
    private void solveItemFactorColumn(int factor) {
        double[] numerator = new double[this.MaxItemID + 1];
        double[] denominator = new double[this.MaxItemID + 1];
        for (int index = 0; index < GetRatings().Count(); index++) {
            int user = GetRatings().GetUsers().get(index).intValue();
            int item = GetRatings().GetItems().get(index).intValue();
            double userValue = this.user_factors.getLocation(user, factor);
            numerator[item] += this.residuals[index] * userValue;
            denominator[item] += userValue * userValue;
        }
        for (int item = 0; item <= this.MaxItemID; item++) {
            // A zero numerator (e.g. an item without ratings) leaves the current value untouched.
            if (numerator[item] != 0.0d) {
                this.item_factors.setLocation(item, factor, numerator[item] / denominator[item]);
            }
        }
    }

    /**
     * Predicts the rating of {@code i} (user ID) for {@code i2} (item ID).
     *
     * @return the baseline prediction plus the factor scalar product, clamped to the rating range;
     *     or {@link #global_bias} when either ID is beyond the rows of its factor matrix
     */
    @Override
    public double Predict(int i, int i2) {
        if (i >= this.user_factors.dim1 || i2 >= this.item_factors.dim1) {
            return this.global_bias;
        }
        double prediction = this.global_effects.Predict(i, i2)
                + MatrixUtils.RowScalarProduct(this.user_factors, i, this.item_factors, i2);
        if (prediction > GetMaxRating()) {
            return GetMaxRating();
        }
        if (prediction < GetMinRating()) {
            return GetMinRating();
        }
        return prediction;
    }

    @Override
    public void AddItems(List<Integer> list) {
        super.AddItems(list);
        this.global_effects.AddItems(list);
    }

    /** Registers new users and grows/initialises the user factor rows up to {@code MaxUserID}. */
    @Override
    public void AddUsers(List<Integer> list) {
        super.AddUsers(list);
        this.global_effects.AddUsers(list);
        int oldUserRows = this.user_factors.dim1;
        this.user_factors.AddRows(this.MaxUserID + 1);
        this.new_users = list;
        MatrixUtils.RowInitNormal(this.user_factors, oldUserRows, this.MaxUserID + 1, this.InitMean, this.InitStdev);
    }

    /**
     * Adds ratings incrementally without retraining the factors.
     *
     * <p>The method grows and initialises the item factor rows and keeps the residuals of existing
     * ratings. It then computes shrunk residuals for the ratings that were added.
     *
     * @return always 1
     */
    @Override
    public int AddRatings(List<Integer> list, List<Integer> list2, List<Double> list3) {
        if (list == null) {
            return 1;
        }
        super.AddRatings(list, list2, list3);
        int oldItemRows = this.item_factors.dim1;
        this.new_items = list2;
        this.item_factors.AddRows(GetRatings().GetMaxItemID() + 1);
        this.res_old_size = this.residuals.length;
        this.global_effects.SetRatings(GetRatings());

        // Grow the residual array while keeping the residuals of the existing ratings.
        double[] grown = new double[GetRatings().Count()];
        System.arraycopy(this.residuals, 0, grown, 0, Math.min(this.res_old_size, grown.length));
        this.residuals = grown;
        MatrixUtils.RowInitNormal(this.item_factors, oldItemRows, GetRatings().GetMaxItemID() + 1, this.InitMean, this.InitStdev);

        // NOTE(review): assumes super.AddRatings appends the new ratings after the existing ones,
        // i.e. they occupy indices [res_old_size, Count()).
        for (int index = this.res_old_size; index < GetRatings().Count(); index++) {
            this.residuals[index] = shrunkResidual(index);
        }
        return 1;
    }

    @Override
    public void RetrainItems(List<Integer> list) {
        super.RetrainItems(list);
        this.global_effects.RetrainItems(list);
    }

    /** Retrains the baseline for the given users and refreshes {@link #global_bias}. */
    @Override
    public void RetrainUsers(List<Integer> list) {
        super.RetrainUsers(list);
        this.global_effects.RetrainUsers(list);
        this.global_bias = GetRatings().Average();
    }

    /** Not implemented: this model is not persisted by this method (no-op). */
    @Override
    public void SaveModel(String str) {
    }

    /** Not implemented: no model state is loaded by this method (no-op). */
    @Override
    public void LoadModel(String str) {
    }

    /** Returns the RMSE of this model on its training ratings, as computed by {@link RatingEval}. */
    @Override
    public double ComputeFit() {
        return RatingEval.Evaluate(this, this.ratings).get("RMSE").doubleValue();
    }

    /** Returns a one-line description of the model and its hyper-parameters. */
    @Override
    public String ToString() {
        return String.format(
                "FactorWiseMatrixFactorization num_factors=%d shrinkage=%s sensibility=%s  init_mean=%s init_stdev=%s num_iter=%d",
                Integer.valueOf(this.NumFactors), Double.valueOf(this.Shrinkage), Double.valueOf(this.Sensibility),
                Double.valueOf(this.InitMean), Double.valueOf(this.InitStdev), Integer.valueOf(this.NumIter));
    }

    public void setSource(String str) {
        this.source = str;
    }

    public String getSource() {
        return this.source;
    }

    /**
     * Records that {@code operator} produced this object at {@code outputPort}. Consecutive
     * duplicate steps (by {@code equals}) are not recorded, and nothing is recorded for operators
     * without a process.
     */
    public void appendOperatorToHistory(Operator operator, OutputPort outputPort) {
        if (this.processingHistory == null) {
            // The history is transient, so it can be null after deserialization.
            this.processingHistory = new LinkedList<>();
            if (operator.getProcess() != null) {
                this.processingHistory.add(new ProcessingStep(operator, outputPort));
            }
        }
        ProcessingStep processingStep = new ProcessingStep(operator, outputPort);
        if (operator.getProcess() != null) {
            if (this.processingHistory.isEmpty() || !this.processingHistory.getLast().equals(processingStep)) {
                this.processingHistory.add(processingStep);
            }
        }
    }

    /** Returns the live (mutable) processing history, creating it if it is absent. */
    public List<ProcessingStep> getProcessingHistory() {
        if (this.processingHistory == null) {
            this.processingHistory = new LinkedList<>();
        }
        return this.processingHistory;
    }

    /** Returns the handler set via {@link #setLoggingHandler}, or the global log service otherwise. */
    public LoggingHandler getLog() {
        return this.loggingHandler != null ? this.loggingHandler : LogService.getGlobal();
    }

    public void setLoggingHandler(LoggingHandler loggingHandler) {
        this.loggingHandler = loggingHandler;
    }

    /** NOTE(review): returns this same instance, not an independent copy; callers share state. */
    public IOObject copy() {
        return this;
    }

    protected void initWriting() {
    }

    /** NOTE(review): returns a fresh, empty {@link Annotations} on every call; annotations are not retained. */
    public Annotations getAnnotations() {
        return new Annotations();
    }
}
