zoukankan      html  css  js  c++  java
  • 使用 LSH(局部敏感哈希)结合词向量快速检索语义相似词

    """
        test
    """
    import os
    import gensim
    import pickle
    import time
    import numpy as np
    
    # Absolute directory containing this script; every data file lives in its
    # ``resource`` sub-directory.
    DIR_PATH = os.path.dirname(os.path.abspath(__file__))
    _RESOURCE_DIR = os.path.join(DIR_PATH, 'resource')

    # Pickle holding the hash tables and their random projection matrices.
    HASHTABLES = os.path.join(_RESOURCE_DIR, 'hashtables.pkl')
    # Embedding file in word2vec format, read through gensim.
    WORD2VEC = os.path.join(_RESOURCE_DIR, 'sgns.weibo.word')
    # Pickle holding the (feature matrix, word list) pair.
    RESOURCES = os.path.join(_RESOURCE_DIR, 'resources.pkl')
    
    
    
    class MyClass(object):
        """Approximate similar-word search over word2vec vectors.

        The class implements locality-sensitive hashing with random
        projections. Each step persists its result with ``pickle`` so that
        the later steps can be run on their own:

        1. ``load_traindata`` dumps the vocabulary and its vectors to
           ``self.resources``.
        2. ``create_hashtables`` buckets every vector into ``table_num`` hash
           tables and dumps the tables together with their projection
           matrices to ``self.hashtables``.
        3. ``cal_similarity`` hashes a query word, gathers the words sharing
           a bucket with it and ranks them by cosine similarity.

        NOTE: the pickle files are loaded with ``pickle.load``; only use
        files produced by this class (never load untrusted pickles).
        """

        def __init__(self, Table_num=5, Hashcode_fun=5):
            """Store file locations and hashing parameters.

            Args:
                Table_num: number of independent hash tables.
                Hashcode_fun: number of random projections (bits) per hash
                    code; every table has ``2 ** Hashcode_fun`` buckets.
            """
            self.hashtables = HASHTABLES
            self.word2vec = WORD2VEC
            self.resources = RESOURCES
            self.table_num = Table_num
            self.Hashcode_fun = Hashcode_fun

        @staticmethod
        def _bucket_index(vector, projections):
            """Return the bucket number of ``vector`` for one hash table.

            Each row of ``projections`` contributes one bit (1 when its dot
            product with ``vector`` is strictly positive); the first row is
            the most significant bit. With no rows the result is 0.
            """
            index = 0
            for row in projections:
                index = (index << 1) | int(np.dot(vector, row) > 0)
            return index

        def load_traindata(self):
            """Load the word2vec file and pickle ``(features, words)``.

            Sets ``self.features`` (2-D array, one row per word) and
            ``self.data`` (list of words in the same order) and writes both
            to ``self.resources``.
            """
            model = gensim.models.KeyedVectors.load_word2vec_format(self.word2vec, unicode_errors='ignore')
            # gensim >= 4 removed ``KeyedVectors.vocab`` in favour of
            # ``index_to_key``; fall back to ``vocab`` on older releases.
            if hasattr(model, 'index_to_key'):
                words = model.index_to_key
            else:
                words = model.vocab

            data = []
            features = []
            for word, vector in zip(words, model.vectors):
                features.append(vector)
                data.append(word)
                print(word)
            self.features = np.array(features)
            self.data = data
            with open(self.resources, 'wb') as fw:
                pickle.dump((self.features, self.data), fw)
            print('词向量序列化完毕,当前词向量数量:{}'.format(len(self.data)))

        def create_hashtables(self):
            """Build the hash tables from the pickled features and save them.

            Reads the feature matrix from ``self.resources``, draws one
            ``(Hashcode_fun, dim)`` uniform(-1, 1) projection matrix per
            table, stores every row index in the bucket selected by its hash
            code, and pickles ``(hashtables, random_matrixes)`` to
            ``self.hashtables``.
            """
            with open(self.resources, 'rb') as fr:
                features, _ = pickle.load(fr)
            print('特征加载完毕,当前词向量数量:{}'.format(len(features)))

            _, dim = features.shape
            # A hash code has Hashcode_fun binary digits, hence 2 ** n buckets.
            bucket_count = 1 << self.Hashcode_fun
            hashtables = [[[] for _ in range(bucket_count)] for _ in range(self.table_num)]
            random_matrixes = [
                np.random.uniform(-1, 1, (self.Hashcode_fun, dim))
                for _ in range(self.table_num)
            ]

            for i, vector in enumerate(features):
                for table, projections in zip(hashtables, random_matrixes):
                    table[self._bucket_index(vector, projections)].append(i)

            with open(self.hashtables, 'wb') as fw:
                pickle.dump((hashtables, random_matrixes), fw)
            print('hash表存储完毕')

        def cal_similarity(self):
            """Look up words similar to a fixed query word and print them.

            Loads the word list and the hash tables from disk, prints
            gensim's own ``most_similar`` result for comparison, then
            collects every word that shares a bucket with the query in any
            table and prints the 20 candidates with the highest cosine
            similarity, together with timing information.
            """
            with open(self.resources, 'rb') as fr:
                _, data = pickle.load(fr)

            with open(self.hashtables, 'rb') as fr:
                hashtables, random_matrixes = pickle.load(fr)

            model = gensim.models.KeyedVectors.load_word2vec_format(self.word2vec, unicode_errors='ignore')
            search_data = '中国'  # query word

            search_feature_vec = np.array(model.get_vector(search_data))
            sim = model.most_similar(search_data)
            print('word2vec 找出的相似词:{}'.format(sim))
            print('{}-莱雅,相似度:{}'.format(search_data, model.similarity(search_data, '莱雅')))
            print('{}-触网,相似度:{}'.format(search_data, model.similarity(search_data, '触网')))

            # Words noted by the original author (presumably sample
            # candidates from an earlier run):
            # '莱雅', '真材实料', '触网', '@Sophia', '汕尾'
            similar_users = set()
            t1 = time.time()
            # Hash the query with the stored projections of every table so
            # that it lands in the same bucket as at build time.
            for hashtable, projections in zip(hashtables, random_matrixes):
                target_index = self._bucket_index(search_feature_vec, projections)
                similar_users.update(hashtable[target_index])
            t2 = time.time()
            print('查找相似性用户耗时:{:.4f}'.format(t2 - t1))

            t3 = time.time()
            res = {}
            for i in similar_users:
                res[data[i]] = cosine_similarity2(search_feature_vec, model.get_vector(data[i]))
            a = sorted(res.items(), key=lambda x: x[1], reverse=True)
            t4 = time.time()
            # time.time() measures seconds; convert to match the "ms" label.
            print('计算余弦相似度及排序耗时:{:.4f}ms'.format((t4 - t3) * 1000))
            print(a[:20])
    
    
    def cosine_similarity(x, y):
        """Return the cosine similarity of two equal-length vectors.

        Args:
            x: sequence or array of numbers.
            y: sequence or array of numbers with the same shape as ``x``.

        Returns:
            ``dot(x, y) / (|x| * |y|)``. When either vector has zero norm
            (including empty input) the similarity is undefined and ``0.0``
            is returned instead of ``nan``.

        Raises:
            ValueError: if ``x`` and ``y`` do not have the same shape.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.shape != y.shape:
            raise ValueError(
                'x and y must have the same shape, got {} and {}'.format(x.shape, y.shape)
            )

        denom = np.linalg.norm(x) * np.linalg.norm(y)
        if denom == 0:
            # Avoid a 0/0 division, which would yield nan and a warning.
            return np.float64(0.0)

        return np.dot(x, y) / denom
    
    def cosine_similarity2(x,y):
        """Return the cosine similarity of two NumPy vectors.

        The dot product of ``x`` and ``y`` is divided by the product of
        their Euclidean norms. Both arguments must be NumPy arrays, because
        ``x.dot`` and ``y.T`` are used.
        """
        dot_product = x.dot(y.T)
        return dot_product / (np.linalg.norm(x) * np.linalg.norm(y))
    
    
    if __name__ == '__main__':
        # Script entry point: run a similarity lookup. The lookup reads the
        # pickles written by MyClass.load_traindata() and
        # MyClass.create_hashtables(), so call those two (in that order)
        # once beforehand to create the files.
        searcher = MyClass()
        searcher.cal_similarity()

    这样就能够快速检索出一组与查询词相似的数据。

  • 相关阅读:
    test
    flash链接需要后台调用时的插入flash方法
    js验证码倒计时
    设置Cookie
    用in判断input中的placeholder属性是否在这个对象里
    常用的正则表达式规则
    webApp添加到iOS桌面
    .substr()在字符串每个字母前面加上一个1
    PAT 甲级1001 A+B Format (20)(C++ -思路)
    PAT 1012 数字分类 (20)(代码+测试点)
  • 原文地址:https://www.cnblogs.com/demo-deng/p/12795872.html
Copyright © 2011-2022 走看看