zoukankan      html  css  js  c++  java
  • KNN cosine 余弦相似度计算

    # coding: utf-8
    import collections
    import numpy as np
    import os
    from sklearn.neighbors import NearestNeighbors
    
    
    def cos(vector1, vector2):
        """Return the cosine similarity of two equal-length numeric vectors.

        :param vector1: iterable of numbers (list, tuple, 1-D numpy array, ...).
        :param vector2: iterable of numbers, same length as ``vector1``.
        :returns: ``dot(v1, v2) / (|v1| * |v2|)``, or ``None`` when either
            vector has zero norm (the similarity is undefined).
        :raises ValueError: if the vectors differ in length. Plain ``zip`` would
            silently truncate the longer one and return a meaningless value
            (e.g. comparing a feature vector with a 1-element index array
            always yields +/-1.0).
        """
        v1 = list(vector1)
        v2 = list(vector2)
        if len(v1) != len(v2):
            raise ValueError("cos(): vectors differ in length (%d vs %d)"
                             % (len(v1), len(v2)))
        dot_product = 0.0
        norm_a = 0.0
        norm_b = 0.0
        for a, b in zip(v1, v2):
            dot_product += a * b
            norm_a += a ** 2
            norm_b += b ** 2
        if norm_a == 0.0 or norm_b == 0.0:
            return None
        return dot_product / ((norm_a * norm_b) ** 0.5)
    
    
    def iterbrowse(path):
        """Recursively yield the full path of every file found under *path*.

        Directory entries themselves are not yielded; paths come out in the
        order os.walk visits them.
        """
        for root, _subdirs, names in os.walk(path):
            for full_path in (os.path.join(root, name) for name in names):
                yield full_path
    
    
    def get_data(filename, num_fields=78, num_meta=3):
        """Read a tab-separated feature file into a list of float rows.

        Every non-blank line must contain exactly ``num_fields`` tab-separated
        fields. The first ``num_meta`` fields are skipped (never converted);
        the remaining ones are parsed with ``float``.

        :param filename: path of the text file to read.
        :param num_fields: required number of fields per line (default 78).
        :param num_meta: number of leading fields to drop (default 3).
        :returns: list of rows, each a list of ``num_fields - num_meta`` floats.
        :raises ValueError: if a line has the wrong number of fields (the
            message names the file, line number and offending line), or if a
            feature value is not a valid float.
        """
        rows = []
        with open(filename) as f:
            # Iterate lazily rather than readlines() so the raw text of a
            # large file is not held in memory all at once.
            for lineno, line in enumerate(f, 1):
                record = line.rstrip("\r\n")
                if not record.strip():
                    # Blank lines (e.g. a trailing newline at EOF) hold no sample.
                    continue
                fields = record.split("\t")
                if len(fields) != num_fields:
                    raise ValueError(
                        "%s:%d: expected %d tab-separated fields, got %d: %r"
                        % (filename, lineno, num_fields, len(fields), record))
                rows.append([float(n) for n in fields[num_meta:]])
        return rows
    
    # Feature ids earmarked for exclusion from the KNN features.
    # NOTE: the default path of get_wanted_data does not apply this filter.
    unwanted_features = {6, 7, 8} | {41, 42, 43} | set(range(67, 76))
    
    def get_wanted_data(x, unwanted=None, offset=6):
        """Select the feature columns used for KNN.

        By default this is a pass-through and ``x`` itself is returned. The
        column filter is opt-in: pass ``unwanted`` (e.g. the module-level
        ``unwanted_features``) to drop columns.

        :param x: sequence of feature rows.
        :param unwanted: optional collection of feature ids to drop. Column
            ``i`` of a row has id ``i + offset``.
        :param offset: id of the first column of each row. The default 6 is the
            offset used by the filter this function originally carried;
            presumably it maps column positions onto the numbering used in
            ``unwanted_features`` -- TODO confirm.
        :returns: ``x`` unchanged when ``unwanted`` is None, otherwise a new
            list of rows without the unwanted columns.
        """
        if unwanted is None:
            return x
        drop = set(unwanted)
        return [[value for i, value in enumerate(row) if i + offset not in drop]
                for row in x]
    
    
    def _load_matrix(path):
        """Load a 2-D sample matrix from a whitespace-delimited .txt or a .npy file.

        :raises ValueError: for any other file extension.
        """
        if path.endswith('.txt'):
            matrix = np.genfromtxt(path)
        elif path.endswith('.npy'):
            matrix = np.load(path)
        else:
            raise ValueError("unsupported sample file (expected .txt or .npy): %s" % path)
        # genfromtxt returns a 1-D array for a single-line file; keep one row per sample.
        return np.atleast_2d(matrix)


    def _report_neighbors(neigh, reference, queries, title):
        """Print the nearest reference sample(s) of every query row.

        ``neigh`` must have been fitted on ``reference``. The indices returned by
        ``kneighbors`` point into ``reference``, so the cosine similarity is
        computed against ``reference[j]`` -- the neighbour's feature vector,
        not the index itself.
        """
        print(title)
        if len(queries) == 0:
            print("(no samples)")
            return
        distances, indices = neigh.kneighbors(queries)
        print((distances, indices))
        for i, x in enumerate(queries):
            sims = [cos(x, reference[j]) for j in indices[i]]
            print(i, indices[i], "cosine:", sims)


    if __name__ == "__main__":
        # Only black samples form the KNN reference set; the white samples
        # (cc_data/white/white_all.txt) are not loaded.
        neg_file = "cc_data/black/black_all.txt"
        if not os.path.isfile(neg_file):
            raise SystemExit("reference sample file not found: %s" % neg_file)
        neg_set = _load_matrix(neg_file)
        print("len of X:", len(neg_set))
        # Drop the 3 leading metadata columns (get_data() drops the same 3 fields).
        X = get_wanted_data([row[3:] for row in neg_set])

        black_verify = []
        for path in iterbrowse("todo/top"):
            print(path)
            black_verify += get_data(path)
        black_verify = get_wanted_data(black_verify)

        white_verify = get_wanted_data(get_data("todo/white_verify.txt"))
        unknown_verify = get_wanted_data(get_data("todo/pek_feature74.txt"))
        bd_verify = get_wanted_data(get_data("guzhaoshen_pek_out.txt"))

        # With metric='cosine' the distance reported by kneighbors is
        # 1 - cosine similarity, so it can be cross-checked against cos().
        neigh = NearestNeighbors(n_neighbors=1, metric='cosine')
        neigh.fit(X)

        _report_neighbors(neigh, X, black_verify, "neigh.kneighbors(black_verify)")
        _report_neighbors(neigh, X, white_verify, "neigh.kneighbors(white_verify)")
        _report_neighbors(neigh, X, unknown_verify, "neigh.kneighbors(unknown_verify)")

        # Sanity check: training rows should be their own nearest neighbour.
        print("neigh.kneighbors(self)")
        print(neigh.kneighbors(X[:3]))

        _report_neighbors(neigh, X, bd_verify, "neigh.kneighbors(bd pek)")
    

     输出示例:

    neigh.kneighbors(white_verify)
    (array([[ 0.01140831],
           [ 0.0067373 ],
           [ 0.00198682],
           [ 0.00686728],
           [ 0.00210445],
           [ 0.00061413],
           [ 0.00453888]]), array([[11032],
           [  967],
           [11091],
           [13149],
           [11091],
           [19041],
           [13068]]))
    (0, array([11032]), 'cosine:', 1.0)
    (1, array([967]), 'cosine:', 1.0)
    (2, array([11091]), 'cosine:', 1.0)
    (3, array([13149]), 'cosine:', 1.0)
    (4, array([11091]), 'cosine:', 1.0)
    (5, array([19041]), 'cosine:', 1.0)
    (6, array([13068]), 'cosine:', 1.0)

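    样本质量堪忧啊!!!

    (更正:上面每行的 cosine 恒为 1.0,并不说明样本有问题。生成该输出的代码把近邻的下标 nearest_points[1][i](只有一个元素的数组)传给了 cos(),而不是近邻的特征向量;zip 只取到第一维,结果必然是 ±1.0。正确写法是 cos(x, X[nearest_points[1][i][0]])。真实的余弦相似度等于 1 − kneighbors 返回的余弦距离,即约 0.989–0.9994。白样本与最近的黑样本确实非常接近,所以对样本区分度的担忧依然成立,但应以距离为依据。)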

    注意:如果使用常规 KNN(例如欧氏距离),并且各个维度的数据属性衡量单位不一样,计算距离前记得先做标准化。scaler 只在训练集 X 上 fit,再用同一个 scaler 变换所有待验证数据:

        # Needed for distance metrics such as Euclidean, where features measured
        # in large units would otherwise dominate the distance. Fit the scaler on
        # the reference samples only, then apply the same transform to every
        # query set so all vectors live in the same space.
        from sklearn import preprocessing
        scaler = preprocessing.StandardScaler().fit(X)
        X = scaler.transform(X)
        print("standard X sample:", X[:3])

        black_verify = scaler.transform(black_verify)
        print(black_verify)

        white_verify = scaler.transform(white_verify)
        print(white_verify)

        unknown_verify = scaler.transform(unknown_verify)
        print(unknown_verify)

        # bd_verify is queried against X as well, so it must be scaled too.
        bd_verify = scaler.transform(bd_verify)
        print(bd_verify)
    
  • 相关阅读:
    php 基础------数组过滤
    js或者jq 使用cookie 时在谷歌浏览器不好使
    css3 -阻止元素成为鼠标事件目标 pointer-events
    CSS3-----transform 转换
    css3---过渡
    css3动画----animation
    移动端尺寸适配--媒体查询
    工作一年总结
    关于Jquery.Data()和HTML标签的data-*属性
    android shape
  • 原文地址:https://www.cnblogs.com/bonelee/p/9112077.html
Copyright © 2011-2022 走看看