zoukankan      html  css  js  c++  java
  • KNN 约会网站配对

    # coding=utf-8
    # kNN-约会网站约友分类
    from numpy import *
    import matplotlib.pyplot as plt
    import matplotlib.font_manager as font
    import operator
     
     
    # 【1】获取数据
    def init_data():
        """Load the dating training set from ``datingTestSet2.txt``.

        Every non-blank line is expected to hold tab-separated fields: the first
        three are numeric features and the last one is an integer class label.
        Blank lines (for example a trailing empty line) are ignored.

        Returns:
            tuple: ``(return_mat, class_label_vec)`` where ``return_mat`` is a
            float array with one row per sample and three columns, and
            ``class_label_vec`` is the list of integer labels in file order.

        Raises:
            ValueError: if a feature or label field cannot be parsed as a number.
        """
        features = []
        class_label_vec = []
        # The context manager closes the file even when a line fails to parse.
        with open(r"datingTestSet2.txt", "r") as f:
            for line in f:
                # Skip empty lines so a trailing newline does not break parsing
                if not line.strip():
                    continue
                fields = line.split("\t")
                features.append([float(value) for value in fields[0:3]])  # first three columns
                class_label_vec.append(int(fields[-1]))  # last column is the label
        # Allocate after reading so the row count only includes real samples
        return_mat = zeros((len(features), 3))
        if features:
            return_mat[:, :] = features
        return return_mat, class_label_vec
     
     
    # 【2】特征缩放 X:=[X-mean(X)]/std(X) || X:=[X-min(X)]/[max(X)-min(X)] ;
    def feature_scaling(data_set):
        """Min-max scale each column of ``data_set``: ``(X - min) / (max - min)``.

        Args:
            data_set: 2-D numpy array with one sample per row and one feature
                per column.

        Returns:
            tuple: ``(norm_data_set, diff_value, min_value)`` where
            ``norm_data_set`` is the scaled copy of ``data_set``, ``diff_value``
            is the per-column divisor that was applied and ``min_value`` is the
            per-column minimum. A new sample can be scaled the same way with
            ``(sample - min_value) / diff_value``.

        Note:
            A column whose values are all equal has a range of zero. Its divisor
            is reported as 1 so the scaled column is 0 instead of NaN/inf.
        """
        min_value = data_set.min(0)
        diff_value = data_set.max(0) - min_value
        # Guard against division by zero for constant columns
        diff_value = where(diff_value == 0, 1, diff_value)
        # Broadcasting applies the per-column offset and divisor to every row
        norm_data_set = (data_set - min_value) / diff_value
        return norm_data_set, diff_value, min_value
     
     
    # 【3】kNN实现 input_set:输入集 data_set:训练集
    def classify0(input_set, data_set, labels, k):
        """Classify ``input_set`` by majority vote among its k nearest neighbours.

        Args:
            input_set: 1-D feature vector to classify.
            data_set: 2-D array of training samples, one row per sample.
            labels: sequence of class labels, ``labels[i]`` belongs to row ``i``
                of ``data_set``.
            k: number of nearest neighbours that vote. When ``k`` exceeds the
                number of training samples, every sample votes.

        Returns:
            The label with the most votes among the nearest neighbours. On a
            tie, the label that was encountered first (closest neighbour first)
            wins.

        Raises:
            IndexError: if no neighbour can vote (empty ``data_set`` or
                ``k`` < 1).
        """
        data_set_size = data_set.shape[0]
        # Euclidean distance from input_set to every training row (broadcast)
        diff_mat = input_set - data_set
        distances = (diff_mat ** 2).sum(axis=1) ** 0.5
        # argsort returns the row indices ordered by increasing distance
        sorted_dist_indicies = distances.argsort()
        # Never look past the available samples; avoids IndexError for large k.
        # A conditional is used instead of min(), which a star-import from numpy
        # may shadow.
        neighbour_count = k if k < data_set_size else data_set_size
        class_count = {}
        for i in range(neighbour_count):
            vote_ilabel = labels[sorted_dist_indicies[i]]
            # Count how often each class occurs among the nearest neighbours
            class_count[vote_ilabel] = class_count.get(vote_ilabel, 0) + 1
        # Most frequent class first; the stable sort keeps first-seen order on ties
        sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
        return sorted_class_count[0][0]
     
     
    # 【4】测试kNN
    def dating_class_test(ho_ratio=0.1, k=4):
        """Evaluate the kNN classifier on a hold-out slice of the data set.

        The first ``ho_ratio`` share of the scaled samples is used as test data
        and the remaining samples as training data. Each prediction and the
        overall accuracy are printed.

        Args:
            ho_ratio: fraction of the samples (taken from the start of the data
                set) used for testing. Defaults to 0.1.
            k: number of neighbours passed to ``classify0``. Defaults to 4.

        Returns:
            None. Results are printed to standard output.
        """
        dating_data_mat, dating_labels = init_data()
        norm_mat, _, _ = feature_scaling(dating_data_mat)
        m = norm_mat.shape[0]
        num_test_vecs = int(m * ho_ratio)  # number of test samples
        if num_test_vecs == 0:
            # Too few rows to hold any out; avoids dividing by zero below
            print("not enough samples to run the test")
            return
        # Training data is everything after the hold-out slice
        training_mat = norm_mat[num_test_vecs:m, :]
        training_labels = dating_labels[num_test_vecs:m]
        error_count = 0.0
        for i in range(num_test_vecs):
            classifier_result = classify0(norm_mat[i, :], training_mat, training_labels, k)
            print("the classifier came back with:%d , the real answer is:%d" % (classifier_result, dating_labels[i]))
            if classifier_result != dating_labels[i]:
                error_count += 1.0
        right_ratio = 1 - error_count / float(num_test_vecs)
        print("the total right rate is :%f %%" % (right_ratio * 100))
     
     
    # 【5】样本数据绘图
    def make_plot():
        """Show a scatter plot of the first two scaled features of the data set.

        Points are sized and coloured by their class label. The axis labels are
        Chinese text, so the SimSun font is used when it is installed at the
        usual Windows location; otherwise matplotlib's default font is used.

        Returns:
            None. The figure is displayed with ``plt.show()``.
        """
        import os

        # Load the samples and bring every feature into the same value range
        x, y = init_data()
        norm_mat, _, _ = feature_scaling(x)

        fig = plt.figure()
        ax = fig.add_subplot(111)  # single plot: 1 row, 1 column, first cell

        # Raw string keeps the backslashes of the Windows path intact
        font_path = r'C:\Windows\Fonts\simsun.ttc'
        if os.path.exists(font_path):
            simsun = font.FontProperties(fname=font_path)
        else:
            # Font file not available (e.g. non-Windows system): use the default
            simsun = font.FontProperties()

        # Alternative view using the unscaled columns 2 and 3:
        # ax.scatter(x[:, 1], x[:, 2], 15.0*array(y), 15.0*array(y))
        # plt.xlabel("玩视频耗时百分比", fontproperties=simsun)
        # plt.ylabel("周消耗冰激凌公升数", fontproperties=simsun)

        # Columns 1 and 2; marker size and colour both follow the class label
        label_scale = 15.0 * array(y)
        ax.scatter(norm_mat[:, 0], norm_mat[:, 1], label_scale, label_scale)
        plt.xlabel("飞行常客里程数", fontproperties=simsun)
        plt.ylabel("玩视频耗时百分比", fontproperties=simsun)
        plt.show()
     
     
    # 预测函数
    def classify_main():
        """Ask for one person's three feature values and print the predicted class.

        The answers are read from standard input, scaled with the same minimum
        and range as the training data, and classified with the 3 nearest
        neighbours.
        """
        questions = (
            "frequent flier miles earned per year?",
            "percentage of time spent playing video games?",
            "liters of ice cream consumed per year?",
        )
        # Prompt in order: flier miles, video game share, ice cream litres
        answers = [float(input(question)) for question in questions]

        # Load the training samples and scale them
        samples, sample_labels = init_data()
        scaled_samples, value_range, lowest = feature_scaling(samples)

        # Scale the new sample exactly like the training data before classifying
        scaled_input = (array(answers) - lowest) / value_range
        predicted = classify0(scaled_input, scaled_samples, sample_labels, 3)

        # Class labels start at 1, list positions at 0
        result_list = ['not at all', 'in small doses', 'in large doses']
        print("You will probably like this person:", result_list[predicted - 1])
     
    # Script entry point: runs only when the file is executed directly
    if __name__ == "__main__":
     
        # Plot the samples
        make_plot()
        # kNN test run (disabled)
        # dating_class_test()
        # Interactive prediction
        classify_main()
  • 相关阅读:
    Apache 配置多站点访问「为项目分配二级域名」
    php封装的mysqli类完整实例
    PHP实现链式操作的三种方法详解
    php实现简单链式操作mysql数据库类
    PHP PDO_MYSQL 链式操作 非链式操作类
    23个数据库常用查询语句
    微信小程序表单弹窗实例
    ES6 && ECMAScript2015 新特性
    ES6新语法概览
    sql将两个日期之间的日子全列出来
  • 原文地址:https://www.cnblogs.com/roscangjie/p/10802172.html
Copyright © 2011-2022 走看看