zoukankan      html  css  js  c++  java
  • [Scala] NDCG 的 Scala 实现

    一、关于 NDCG

    [LTR] 信息检索评价指标(RP/MAP/DCG/NDCG/RR/ERR)

    二、代码实现

    1、训练数据的加载解析

    import scala.io.Source
    
    /*
    * 训练行数据
    * */
    case class TrainDataRow(target: Int, qid: Int, features: Array[Double])
    
    object TrainDataRow {
      // 加载文件数据
      // 格式:
      // <line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info>
      // <target> .=. <positive integer>
      // <qid> .=. <positive integer>
      // <feature> .=. <positive integer>
      // <value> .=. <float>
      // <info> .=. <string>
      def loadFile(file: String): List[TrainDataRow] = {
        Source.fromFile(file).getLines.toList.par.map(x => {
          val strArray = x.split(' ')
          val label = strArray(0).toInt
          val qid = strArray(1).split(':')(1).toInt
          val fValArray = strArray.drop(2).map(x => x.split(':')(1).toDouble)
          new TrainDataRow(label, qid, fValArray)
        }).toList
      }
    }

    2、NDCG 的实现

    object NDCG {
      /*
      * 计算 NDCG 分值
      * */
      def score(rows: List[TrainDataRow], k: Int): Double = {
        val size = k.min(rows.length - 1)
        // 理想 DCG
        var idealDcg: Double = 0
        val sortedList = rows.sortWith((x, y) => x.target > y.target)
        for (i <- 0 to size) {
          // 计算累计效益
          val gain = (1 << sortedList(i).target) - 1
          // 计算折扣因子
          val discount = 1.0 / (Math.log(i + 2) / Math.log(2))
          idealDcg += gain * discount
        }
        if (idealDcg > 0) {
          var dcg: Double = 0
          for (i <- 0 to size) {
            // 计算累计效益
            val gain = (1 << rows(i).target) - 1
            // 计算折扣因子
            val discount = 1.0 / (Math.log(i + 2) / Math.log(2))
            dcg += gain * discount
          }
          dcg / idealDcg
        }
        else 0
      }
    }

    3、训练数据集的 NDCG 计算

    def calcNDCG(trainDataFile: String, k: Int): Double = {
      println("开始计算...")
      val start = System.nanoTime()
      val data = TrainDataRow.loadFile(trainDataFile) // 加载训练数据文件
      println("数据量:" + data.length + ",用时:" + (System.nanoTime() - start) / 1000000 + " ms")
      val grpData: Map[Int, List[TrainDataRow]] = data.groupBy(_.qid) // 根据 qid 分组
      val resultNDCG = grpData.map(x => NDCG.score(x._2, k)).sum / grpData.size
      println(s"NDCG@$k: $resultNDCG")
      val end = System.nanoTime()
      println("计算运行时间:" + (end - start) / 1000000 + " ms")
      resultNDCG
    }

     

    by. Memento

  • 相关阅读:
    PDO事务处理不能保持一致性
    Android开发中的SQLite事务处理
    Mysql安装
    IIS下https配置及安全整改
    exchang2010OWA主界面添加修改密码选项
    查阅文件技巧
    RHEL yum
    CentOS之——CentOS7安装iptables防火墙
    Linux修改主机名称
    Vmware虚拟机设置静态IP地址
  • 原文地址:https://www.cnblogs.com/memento/p/8675800.html
Copyright © 2011-2022 走看看