zoukankan      html  css  js  c++  java
  • 因子分解机

    因子分解机

    Factorization Machines

    因子分解机(FM)[Rendle,2010]由Steffen Rendle于2010年提出,是一种可用于分类、回归和排序任务的监督算法。它很快就引起了人们的注意,并成为一种流行而有影响力的预测和推荐方法。特别地,它是线性回归模型和矩阵分解模型的推广。此外,它让人想起具有多项式核的支持向量机。与线性回归和矩阵分解相比,因式分解机的优势在于:(1)它可以建模X-变量交互作用,其中X是多项式阶数,通常设为2。(2) 与因子分解机相结合的快速优化算法可以将多项式的计算时间减少到线性复杂度,特别是对于高维稀疏输入,它是非常有效的。基于这些原因,因子分解机被广泛应用于现代广告和产品推荐中。技术细节和实现如下所述。

    1. 2-Way Factorization Machines

    一些特性交互很容易理解,因此可以由专家设计。然而,大多数其他特性交互都隐藏在数据中,难以识别。因此,特征交互的自动建模可以大大减少特征工程的工作量。很明显,前两项对应于线性回归模型,最后一项是矩阵分解模型的扩展。如果功能i表示项和功能,j表示一个用户,第三项正好是用户和项嵌入之间的点积。值得注意的是,FM也可以推广到更高阶(度>2)。然而,数值稳定性可能会削弱推广。

    2. An Efficient Optimization Criterion

    用直接的方法优化因子分解机会导致O(kd2),因为所有成对的相互作用都需要计算。为了解决这一效率低下的问题,我们可以对FM的第三项进行重组,这样可以大大降低计算成本,从而导致线性时间复杂度(O(kd)O(kd))。两两相互作用项的重新表述如下:

     通过这种重构,大大降低了模型的复杂度。此外,对于稀疏特征,只需要计算非零元素,这样整体复杂度就与非零特征的数量成线性关系。 

    为了学习FM模型,我们可以将MSE损失用于回归任务,交叉熵损失用于分类任务,BPR损失用于排名任务。标准优化器(如SGD和Adam)可用于优化。
    from d2l import mxnet as d2l
    from mxnet import init, gluon, np, npx
    from mxnet.gluon import nn
    import os
    import sys
    npx.set_np()
    3. Model Implementation
    下面的代码实现了因子分解机。很明显,FM由一个线性回归块和一个有效的特征交互块组成。由于我们将CTR预测视为一个分类任务,因此我们对最终得分应用了一个S形函数。
    class FM(nn.Block):
    def __init__(self, field_dims, num_factors):
    super(FM, self).__init__()
    num_inputs = int(sum(field_dims))
    self.embedding = nn.Embedding(num_inputs, num_factors)
    self.fc = nn.Embedding(num_inputs, 1)
    self.linear_layer = nn.Dense(1, use_bias=True)
    def forward(self, x):
    square_of_sum = np.sum(self.embedding(x), axis=1) ** 2
    sum_of_square = np.sum(self.embedding(x) ** 2, axis=1)
    x = self.linear_layer(self.fc(x).sum(1))
    + 0.5 * (square_of_sum - sum_of_square).sum(1, keepdims=True)
    x = npx.sigmoid(x)
    return x
    4. Load the Advertising Dataset
    我们使用最后一节中的CTR数据包装器来加载在线广告数据集。
    batch_size = 2048
    data_dir = d2l.download_extract('ctr')
    train_data = d2l.CTRDataset(os.path.join(data_dir, 'train.csv'))
    test_data = d2l.CTRDataset(os.path.join(data_dir, 'test.csv'),
    feat_mapper=train_data.feat_mapper,
    defaults=train_data.defaults)
    num_workers = 0 if sys.platform.startswith('win') else 4
    train_iter = gluon.data.DataLoader(
    train_data, shuffle=True, last_batch='rollover', batch_size=batch_size,
    num_workers=num_workers)
    test_iter = gluon.data.DataLoader(
    test_data, shuffle=False, last_batch='rollover', batch_size=batch_size,
    num_workers=num_workers)
    5. Train the Model
    然后,我们训练模型。默认情况下,学习率设置为0.01,嵌入大小设置为20。Adam优化器和SigmoidBinaryCrossEntropyLoss loss用于模型训练。
    ctx = d2l.try_all_gpus()
    net = FM(train_data.field_dims, num_factors=20)
    net.initialize(init.Xavier(), ctx=ctx)
    lr, num_epochs, optimizer = 0.02, 30, 'adam'
    trainer = gluon.Trainer(net.collect_params(), optimizer,
    {'learning_rate': lr})
    loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, ctx)
    loss 0.505, train acc 0.761, test acc 0.759
    151228.5 examples/sec on [gpu(0), gpu(1)]

    6. Summary

    · FM is a general framework that can be applied on a variety of tasks such as regression, classification, and ranking.
    · Feature interaction/crossing is important for prediction tasks and the 2-way interaction can be efficiently modeled with FM.

  • 相关阅读:
    Django的rest_framework的视图之基于通用类编写视图源码解析
    Django的rest_framework的视图之Mixin类编写视图源码解析
    Django1.0和2.0中的rest_framework的序列化组件之超链接字段的处理
    Django的restframework的序列化组件之对单条数据的处理
    Django2.0的path方法无法使用正则表达式的解决办法
    算法的时间复杂度和空间复杂度简单理解
    回归后端分页本质,理清后端分页思路
    SQL Server 2008R2向表中添加字段
    Asp.net IIS 服务器配置远程访问
    linux write 命令
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/13224850.html
Copyright © 2011-2022 走看看