swing 算法的简介

4,614次阅读
没有评论

共计 3374 个字符,预计需要花费 9 分钟才能阅读完成。

又是一年年底,年终述职总结刚刚完成了,这一年有不足也有进步,新的一年继续加油。先从整理资料开始吧!

之前的工作在排序那块花费的时间还是很多,而且一般的场景都是召回和精排都是分开的,所以各自迭代自己的算法,上一周做召回的同事做了一下现有召回方案的分享,趁着机会自己也学习学习。

这篇文章 swing 是 i2i 算法中一个,基于图的思想,不过算法本身比较简单,而且乍看一下跟CF还有点像。

开局一张图来看下swing吧

swing 算法的简介

这张图就是同时购买芒果和西瓜的一个图,同时购买这物品的有三个人。

实现 i2i 的链接体肯定就是用户,也就是这样实现物料之间的相似度的判定。

相似度的计算公式如下所示:

sim(i, j) = \sum_{u \in U_i \cap U_j} \sum_{u \in U_i \cap U_j} \frac{1}{\alpha + |I_u \cap I_v|}

上式中U_i表示物料i对应的用户集合,I_u表示用户u对应的物料集合,所以U_i \cap U_j 表示同事购买的<i,j>物品的用户交集,I_u \cap I_v 表示用户<u,v>购买的物品交集数量。

所以从上面的公式可以这么理解:

如果分母变大,那么就是用户之间购买的物品交集越多,那么这两个物品之间的相似度变低,反向来看,用户购买的东西比较少,但是两个物品出现的频次与之前相同,那么变相的来看这两个物品之间的相似度应该更高。

所以可以一定程度上说买东西太多了,各个物品之间的相似度一定程度上会被稀释。

代码是从别人那边借鉴过来的,原始链接

数据的输入格式是 (user_id,item_id)

/**
  * @ClassName: Swing
  * @Description: 实现Swing算法
  * @author: Thinkgamer
  **/

class SwingModel(spark: SparkSession) extends Serializable{
    var alpha: Option[Double] = Option(0.0)
    var items: Option[ArrayBuffer[String]] = Option(new ArrayBuffer[String]())
    var userIntersectionMap: Option[Map[String, Map[String, Int]]] = Option(Map[String, Map[String, Int]]())

    /*
     * @Description 给参数 alpha赋值
     * @Param double
     * @return cf.SwingModel
     **/
    def setAlpha(alpha: Double): SwingModel = {
        this.alpha = Option(alpha)
        this
    }

    /*
     * @Description 给所有的item进行赋值
     * @Param [array]
     * @return cf.SwingModel
     **/
    def setAllItems(array: Array[String]): SwingModel = {
        this.items = Option(array.toBuffer.asInstanceOf[ArrayBuffer[String]])
        this
    }

    /*
     * @Description 获取两两用户有行为的item交集个数
     * @Param [spark, data]
     * @return scala.collection.immutable.Map<java.lang.String,scala.collection.immutable.Map<java.lang.String,java.lang.Object>>
     **/
    def calUserRateItemIntersection(data: RDD[(String, String, Double)]): Map[String, Map[String, Int]] = {
        val rdd = data.map(l => (l._1, l._2)).groupByKey().map(l => (l._1, l._2.toSet))
        val map = (rdd cartesian rdd).map(l => (l._1._1, (l._2._1, (l._1._2 & l._2._2).toArray.length)))
            .groupByKey()
            .map(l => (l._1, l._2.toMap))
            .collectAsMap().toMap
        map.take(10).foreach(println)
        map
    }

    def fit(data: RDD[(String, String, Double)]): RDD[(String, String, Double)]= {
        this.userIntersectionMap = Option(this.calUserRateItemIntersection(data))
        println(this.userIntersectionMap.take(10))

        // (item,user_set)
        val rdd = data.map(l => (l._2, l._1)).groupByKey().map(l => (l._1, l._2.toSet))


        val result: RDD[(String, String, Double)] = (rdd cartesian rdd).map(l => {
            val item1 = l._1._1
            val item2 = l._2._1
            // intersectionUsers 是任意两个物品对应用户的交集用户
            val intersectionUsers = l._1._2 & l._2._2  
            var score = 0.0
            for(u1 <- intersectionUsers){
                for(u2 <- intersectionUsers){
                    score += 1.0 / (this.userIntersectionMap.get.get(u1).get(u2).toDouble + this.alpha.get)
                }
            }
            (item1, item2, score) // (item1, item2, swingsocre)
        })
        result
    }

    def evalute(test: RDD[(String, String, Double)]) = { }

    def predict(userid: String) = { }

    def predict(userids: Array[String]) = { }

}

object Swing {
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[10]").appName("Swing").enableHiveSupport().getOrCreate()
        Logger.getRootLogger.setLevel(Level.WARN)

        val trainDataPath = "data/ml-100k/ua.base"
        val testDataPath = "data/ml-100k/ua.test"

       import spark.sqlContext.implicits._
        val train: RDD[(String, String, Double)] = spark.sparkContext.textFile(trainDataPath).map(_.split("\t")).map(l => (l(0), l(1), l(2).toDouble))
        val test: RDD[(String, String, Double)] = spark.sparkContext.textFile(testDataPath).map(_.split("\t")).map(l => (l(0), l(1), l(2).toDouble))

        val items: Array[String] = train.map(_._2).collect()

        val swing = new SwingModel(spark).setAlpha(1).setAllItems(items)
        val itemSims: RDD[(String, String, Double)] = swing.fit(train)

        swing.evalute(test)
        swing.predict("")
        swing.predict(Array("", ""))

        spark.close()
    }
}
正文完
请博主喝杯咖啡吧!
post-qrcode
 
admin
版权声明:本站原创文章,由 admin 2022-01-16发表,共计3374字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码