zoukankan      html  css  js  c++  java
  • surprise库官方文档分析(三):搭建自己的预测算法

    1、基础

    创建自己的预测算法非常简单:算法只不过是一个派生自AlgoBase具有estimate 方法的类。这是该方法调用的predict()方法。它接受内部用户ID,内部项ID,并返回估计评级r

    from surprise import AlgoBase
    from surprise import Dataset
    from surprise.model_selection import cross_validate
    
    
    class MyOwnAlgorithm(AlgoBase):
    
        def __init__(self):
    
            # Always call base method before doing anything.
            AlgoBase.__init__(self)
    
        def estimate(self, u, i):
            #   存储有关预测的其他信息,还可以返回包含给定详细信息的字典
            details = {'info1' : 'That was',
                       'info2' : 'easy stuff :)'}
            return 3, details
    
    
    data = Dataset.load_builtin('ml-100k')
    algo = MyOwnAlgorithm()
    
    cross_validate(algo, data, verbose=True)

    以上代码实现了一个最简单的自定义预测方法。

    2、fit()方法

    现在,让我们制作一个稍微聪明的算法来预测列车集的所有评级的平均值。由于这是一个不依赖于当前用户或项目的常量值,我们宁愿一劳永逸地计算它。这可以通过定义fit方法来完成

    class MyOwnAlgorithm(AlgoBase):
    
        def __init__(self):
    
            # Always call base method before doing anything.
            AlgoBase.__init__(self)
    
        def fit(self, trainset):
    
            # Here again: call base method before doing anything.
            AlgoBase.fit(self, trainset)
    
            # Compute the average rating. We might as well use the
            # trainset.global_mean attribute ;)
            self.the_mean = np.mean([r for (_, _, r) in
                                     self.trainset.all_ratings()])
    
            return self
    
        def estimate(self, u, i):
    
            return self.the_mean

    fit方法例如通过cross_validate交叉验证过程的每个折叠处函数调用(也可以自己调用它)。在做任何事情之前,你应该调用基类fit()方法。

    请注意,该fit()方法返回self这允许使用表达式algo.fit(trainset).test(testset)

    3、trainset属性

    fit()返回基类方法后,您需要的有关当前训练集的所有信息(评级值等)都存储在self.trainset属性中。这是一个Trainset具有许多预测属性和方法的对象。

    为了说明它的用法,让我们制作一个算法来预测所有评级的平均值,用户的平均评分和项目的平均评级之间的平均值:

    def estimate(self, u, i):
    
            sum_means = self.trainset.global_mean
            div = 1
    
            if self.trainset.knows_user(u):
                sum_means += np.mean([r for (_, r) in self.trainset.ur[u]])
                div += 1
            if self.trainset.knows_item(i):
                sum_means += np.mean([r for (_, r) in self.trainset.ir[i]])
                div += 1
    
            return sum_means / div

    4、预测不可能

    由算法决定是否能够产生预测。如果预测不可能,则可以提出 PredictionImpossible异常。您需要先导入它:

    from surprise import PredictionImpossible

    该异常将被该predict()方法和估计r捕获^[R^ü一世将根据default_prediction()方法设置,可以覆盖。默认情况下,它返回列车集中所有评级的平均值。

    5、相似性和基线

    如果算法使用相似性度量或基线估计,您将需要接受bsl_optionssim_options作为__init__方法的参数 ,并将它们传递给Base类。

    class MyOwnAlgorithm(AlgoBase):
    
        def __init__(self, sim_options={}, bsl_options={}):
    
            AlgoBase.__init__(self, sim_options=sim_options,
                              bsl_options=bsl_options)
    
        def fit(self, trainset):
    
            AlgoBase.fit(self, trainset)
    
            # Compute baselines and similarities
            self.bu, self.bi = self.compute_baselines()
            self.sim = self.compute_similarities()
    
            return self
    
        def estimate(self, u, i):
    
            if not (self.trainset.knows_user(u) and self.trainset.knows_item(i)):
                raise PredictionImpossible('User and/or item is unkown.')
    
            # Compute similarities between u and v, where v describes all other
            # users that have also rated item i.
            neighbors = [(v, self.sim[u, v]) for (v, r) in self.trainset.ir[i]]
            # Sort these neighbors by similarity
            neighbors = sorted(neighbors, key=lambda x: x[1], reverse=True)
    
            print('The 3 nearest neighbors of user', str(u), 'are:')
            for v, sim_uv in neighbors[:3]:
                print('user {0:} with sim {1:1.2f}'.format(v, sim_uv))
    
            # ... Aaaaand return the baseline estimate anyway ;)
  • 相关阅读:
    线程同步的三种方式(Mutex,Event,Critical Section) 沧海
    VC++多线程下内存操作的优化 沧海
    C++内存对象大会战 沧海
    技术关注:搜索引擎经验 沧海
    jira 3.13.5版 安装 配置 用户权限控制 拂晓风起
    C++ int string 转换 拂晓风起
    C++调用C链接库会出现的问题 拂晓风起
    Windows Server 2003 IIS Service Unavailable 问题解决 拂晓风起
    研究 学术 开发 的好用工具(不包括常见的) 拂晓风起
    SGMarks 问世 (Firefox扩展Gmarks的扩展版) 纯属学习 拂晓风起
  • 原文地址:https://www.cnblogs.com/felixwang2/p/9391241.html
Copyright © 2011-2022 走看看