zoukankan      html  css  js  c++  java
  • 跟 Google 学 machineLearning [1] -- hello sklearn

    时至今日,我才发现 machineLearning 的应用门槛已经被降到了这么低,简直唾手可得。我实在找不到任何理由不对它进入深入了解。如标题,感谢 Google 为这项技术发展作出的贡献。当然,可能其他人做了 99%, Google 只做了 1%,我想说,真是漂亮的 1%。

    切入正题,今天从 Youtube 上跟随 Google 的工程师完成了第一个 machineLearning 的小程序。作为学习这项技能的 hello world 吧。

    是为记录。

     1 from scipy.spatial import distance
     2 def euc(a,b):
     3     return distance.(a,b)
     4 
     5 class knnClassifier():
     6     def fit(self, x_train, y_train):
     7         self.x_train = x_train
     8         self.y_train = y_train
     9 
    10     def predict(self, x_test):
    11         predictions = []
    12         for row in x_test:
    13             label = self.closest(row)
    14             predictions.append(label)
    15         return predictions
    16 
    17     def closest(self, row):
    18         best_dist = euc(row, self.x_train[0])
    19         best_index = 0
    20         for i in range(1, len(self.x_train)):
    21             dist = euc(row, self.x_train[i])
    22             if dist < best_dist:
    23                 best_dist = dist
    24                 best_index = i
    25         return self.y_train[best_index]
    26 
    27 from sklearn import datasets
    28 iris = datasets.load_iris()
    29 x = iris.data
    30 y = iris.target
    31 
    32 from sklearn.cross_validation import train_test_split
    33 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size= .5)
    34 print x_train
    35 print y_train
    36 
    37 my_classifier = knnClassifier()
    38 my_classifier.fit(x_train, y_train)
    39 predictions = my_classifier.predict(x_test)
    40 
    41 from sklearn.metrics import accuracy_score
    42 print accuracy_score(y_test, predictions)

    对上面的代码进行简单解释:

    1. 1-3 行是引用 scipy 的 distance 类中计算欧氏距离的函数,并进行了简单封装。(欧氏距离:N 维空间中,两个点之间的真实距离)

    2. 5-25 中,定义了自己的 classifier 类,关键方法包括了 fit 和 predict。fit 主要是将喂进来的数据赋值给内部变量;predict 是根据送进来的 row,返回我们预期的 Label。这里的 classifier 是我们 hand code 的,并不是训练出来的。事实上并不算是真正意义上的 machineLearning,但是很好的解释了其内部的原理。machineLearning 中,我们定义的 closet 函数,将通过训练的到,即 model

    3. 27-30, 在入了 sklearn 库中的 iris 花的数据库,作为我们后面实验的数据来源。iris_data 是三种花的原始数据,是一个三维数组。数组中每个元素代表一朵花的三个参数,分别是花的xx长度,花的xx宽度,和xx长度(我并不关系他是什么数据,反正是花的数据);iris_target 是 data 相对应的花的种类,大概就是0表示红玫瑰,1表示蓝玫瑰,2表示粉玫瑰之类。

    4. 32-35, 把载入的花朵数据 split 为两组,一组用做 train,作为预测的凭据,另一组作为检验 classifier 准确性的待测数据。验证时,因为验证组的数据对应的结果也是已知的,所以拿 classifier 出来的结果与真实值比较,便可知 classifier 是否合理。使用上面代码进行判定的成功率已经达到 >90%,事实上拿它来对未知新数据判定,结果可信度已经很高。

    5. 37-39 ,应用了在 2 中定义的 classifier,将 4 中分割出来的 x_train, y_train 喂给 classifier。然后,使用 classifier 根据 x_test 中的花的数据,预测花的种类,得到对应的预测结果数组 predictions。

    6. 41-42,比较真实的花的种类 y_test 与 预测结果 predictions 之间的符合度。可以看到并不是 100%,信息总是会有遗漏的,哪怕是人眼来判断也一样。

    因为载入的数据在 split 时,是随机的。所以,因为 train 组和 test 组数据的不同,预测的准确度也会稍有不同。

    虽然这里的 classifier 已经有了很高的准确度,但是,不能回避的是,这样的计算比对,运算量是非常大的。同时,因为我们数据属性的关系,我们可以直接通过找最接近数据来进行预测,在其他一些应用中,某些属性并不是线性分布的,或者,并不是凭人眼能发现规律的。这时候,就需要真正的 train 了。

  • 相关阅读:
    【转】centos7升级git版本
    小程序购物车抛物线动画(通用)
    IDEA高级操作
    JAVA获取各种路径
    这些SpringBoot天生自带Buff工具类你都用过哪些?
    搜狗输入法简繁问题
    Java8 Stream流递归,几行代码搞定遍历树形结构
    SpringBoot 启动时实现自动执行代码的几种方式讲解
    公司用的 MySQL 团队开发规范,非常详细,建议收藏!
    Springboot整合websocket全面解析
  • 原文地址:https://www.cnblogs.com/pied/p/8092727.html
Copyright © 2011-2022 走看看