zoukankan      html  css  js  c++  java
  • 机器学习学习中-->>>knn手动实现

    kNN 学习笔记:手动实现的测试版,用于演示 kNN 的基本思想

      # Labelled sample points: test_data1 is assumed to be type 1,
      # test_data2 type 2 and test_data3 type 3.
      test_data1 = [
          (1, 2), (2, 5), (3, 3), (5, 9), (6, 8), (8, 3), (4, 3),
      ]
      test_data2 = [
          (15, 6), (2, 6), (8, 6), (3, 1), (4, 5), (2, 1), (3, 6),
      ]
      test_data3 = [
          (25, 4), (3, 2), (8, 4), (2, 53), (6, 18), (13, 3), (25, 8),
      ]
      5 
      6 
      def knn(k: int, new_data: tuple, *test_data: list):
          """
          Classify ``new_data`` by a majority vote among its k nearest samples.

          Every positional argument after ``new_data`` holds the samples of one
          class: the first list is type one, the second type two and the third
          type three.  The distance from ``new_data`` to every sample is
          obtained from ``computed_range`` with its default formula, the samples
          are ordered from nearest to farthest, and the class that owns most of
          the first ``k`` samples wins.  When several classes share the highest
          count, the one owning the nearest of those samples is returned.

          The full distance table, ordered nearest first, is printed before the
          vote; its keys are ``j<n>``, ``k<n>`` and ``l<n>`` for sample ``n`` of
          class one, two and three respectively.

          :param k: number of nearest neighbours taking part in the vote (>= 1);
                    a value larger than the sample count uses every sample
          :param new_data: the new record to classify
          :param test_data: one list of sample points per class, at most three
          :return: "类型一", "类型二" or "类型三"
          :raises ValueError: if ``k`` is smaller than 1, more than three sample
                              lists are given, or there are no samples at all
          """
          labels = ("类型一", "类型二", "类型三")
          prefixes = ("j", "k", "l")

          if k < 1:
              raise ValueError("k must be at least 1")
          if len(test_data) > len(labels):
              raise ValueError("at most %d sample lists are supported" % len(labels))

          # One (distance, class index, sample index) entry per sample, collected
          # class by class so that equal distances keep their original order.
          neighbours = []
          for class_index, samples in enumerate(test_data):
              distances = computed_range(new_data=new_data, test_data=samples)
              for sample_index, distance in enumerate(distances):
                  neighbours.append((distance, class_index, sample_index))
          if not neighbours:
              raise ValueError("at least one sample is required")

          # list.sort is stable, so ties stay in class/sample order.
          neighbours.sort(key=lambda entry: entry[0])

          print({
              prefixes[class_index] + str(sample_index): distance
              for distance, class_index, sample_index in neighbours
          })

          nearest = neighbours[:k]
          votes = [0] * len(labels)
          for _, class_index, _ in nearest:
              votes[class_index] += 1

          # Walk nearest-first so a tie goes to the class with the closest sample.
          top_count = max(votes)
          for _, class_index, _ in nearest:
              if votes[class_index] == top_count:
                  return labels[class_index]
     68 
     69 
     70 
     71 
     72 # 计算距离
     def computed_range(new_data: tuple, test_data: list, formula_mode=1) -> list:
         """
         Compare every sample in ``test_data`` with ``new_data``.

         :param new_data: the new record that the samples are compared against
         :param test_data: the sample records
         :param formula_mode: selects the formula: 1 = Euclidean distance,
                              2 = Manhattan distance, 3 = cosine similarity
         :return: a list holding one value per sample, in sample order;
                  None when ``formula_mode`` is not 1, 2 or 3
         """
         # Any mode other than 1, 2 or 3 yields None straight away.
         if formula_mode not in (1, 2, 3):
             return None

         if formula_mode == 1:
             metric = Euclidean_distance
         elif formula_mode == 2:
             metric = Manhattan_distance
         else:
             metric = cosine_measure

         # The sample is passed first and the new record second.
         return [metric(sample, new_data) for sample in test_data]
     97 
     98 # 用于计算欧氏距离
     def Euclidean_distance(data1, data2):
         """
         Return the Euclidean distance between two points.

         Works for points with any number of coordinates, as long as both
         points have the same number of them.

         :param data1: first point, a sequence of numbers
         :param data2: second point, a sequence of numbers
         :return: the straight-line distance as a float
         :raises ValueError: if the points differ in dimension
         """
         if len(data1) != len(data2):
             raise ValueError("points must have the same number of coordinates")
         return sum((a - b) ** 2 for a, b in zip(data1, data2)) ** 0.5
    101 
    102 # 用于计算曼哈顿距离
    def Manhattan_distance(data1, data2):
        """
        Return the Manhattan (city-block) distance between two points.

        The distance is the sum of the absolute coordinate differences.

        :param data1: first point, a sequence of numbers
        :param data2: second point, a sequence of numbers
        :return: the sum of absolute differences over all coordinates
        :raises ValueError: if the points differ in dimension
        """
        if len(data1) != len(data2):
            raise ValueError("points must have the same number of coordinates")
        return sum(abs(a - b) for a, b in zip(data1, data2))
    105 
    106 # 用于计算余弦相似度
    def cosine_measure(data1, data2):
        """
        Return the cosine similarity between two vectors.

        The result lies between -1 and 1 and grows as the vectors point in
        more similar directions, so unlike a distance a larger value means
        "closer".

        :param data1: first vector, a sequence of numbers
        :param data2: second vector, a sequence of numbers
        :return: dot(data1, data2) / (|data1| * |data2|)
        :raises ValueError: if the vectors differ in dimension or either one
                            has zero length (the angle is undefined)
        """
        if len(data1) != len(data2):
            raise ValueError("vectors must have the same number of coordinates")
        dot = sum(a * b for a, b in zip(data1, data2))
        norm1 = sum(a * a for a in data1) ** 0.5
        norm2 = sum(b * b for b in data2) ** 0.5
        if norm1 == 0 or norm2 == 0:
            raise ValueError("cosine similarity is undefined for a zero vector")
        return dot / (norm1 * norm2)
    109 
    # Demo: classify the point (5, 55) by its 3 nearest neighbours among the
    # three labelled sample sets, then print the resulting label.
    string = knn(3, (5, 55), *[test_data1, test_data2, test_data3])
    print(string)

    运行结果(先打印按距离排序后的字典,再打印分类结果)

    {'l3': 3.605551275463989, 'l4': 37.013511046643494, 'j3': 46.0, 'j4': 47.01063709417264, 'k6': 49.040799340956916, 'k1': 49.09175083453431, 'k2': 49.09175083453431, 'k0': 50.00999900019995, 'k4': 50.00999900019995, 'j1': 50.08991914547278, 'l6': 51.07837115648854, 'l2': 51.088159097779204, 'j6': 52.009614495783374, 'j2': 52.03844732503075, 'j5': 52.08646657242167, 'l5': 52.61178575186362, 'l1': 53.03772242470448, 'j0': 53.150729063673246, 'k3': 54.037024344425184, 'k5': 54.08326913195984, 'l0': 54.78138369920935}
    类型三
  • 相关阅读:
    用mstsc連接服務器收到超出最大連接數的問題
    DIV+CSS自適就高度解決方案
    与佛关于婚外情的经典对白[轉]
    [轉]修改计算机名,sharepoint站点打不开了
    今天測試服務器壞了
    Microsoft.Jet.Oledb.4.0 找不到提供者或未安裝問題
    今天開始學習silverlight了
    FreeTextBox運行錯誤解決
    使用AJAXControlToolkit出現的問題
    子类如何调用被子类override了的方法?
  • 原文地址:https://www.cnblogs.com/luminous-Xin/p/14720178.html
Copyright © 2011-2022 走看看