可以考虑使用spark.ml里的ALS模型替换掉mllib中的ALS模型,因为spark.ml中的ALS模型基于数据框,具有更好的性能表现。同时,可以对数据进行缓存,以减少数据读取次数,加快计算速度。
代码示例:
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import when
from pyspark.ml.evaluation import RegressionEvaluator
# 加载数据
data = spark.read.csv("path/to/data.csv", header=True, inferSchema=True)
# 添加rating列
data = data.withColumn("rating", when(data["Rating"] > 3, 1).otherwise(0))
# 将数据集划分为训练集和测试集
(training, test) = data.randomSplit([0.8, 0.2])
# 对数据集进行缓存
training.cache()
test.cache()
# 使用ALS模型进行训练
als = ALS(rank=10, maxIter=10, regParam=0.1, userCol="UserID", itemCol="ProductID", ratingCol="rating")
model = als.fit(training)
# 对测试集进行预测
predictions = model.transform(test)
# 评估模型性能
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
# 打印rmse结果
print("Root-mean-square error = " + str(rmse))