zoukankan      html  css  js  c++  java
  • 02-03 感知机对偶形式(鸢尾花分类)


    更新、更全的《机器学习》的更新网站,更有python、go、数据结构与算法、爬虫、人工智能教学等着你:https://www.cnblogs.com/nickchen121/p/11686958.html

    感知机对偶形式(鸢尾花分类)

    一、导入模块

    from matplotlib.font_manager import FontProperties
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import random
    %matplotlib inline
    font = FontProperties(fname='/Library/Fonts/Heiti.ttc')
    

    二、获取数据

    def get_data():
        df = pd.read_csv(
            'http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data', header=None)
        X = df.iloc[0:100, [0, 2]].values
        train_data_p = df.iloc[0:50, [0, 2, 4]].values
        train_data_n = df.iloc[50:100, [0, 2, 4]].values
        train_data_p[:, [2]], train_data_n[:, [2]] = -1, 1
        train_data = train_data_p.tolist() + train_data_n.tolist()
    
        return train_data, X
    

    三、训练模型

    def train(num_iter, train_data, learning_rate):
        w = 0.0
        b = 0
        data_length = len(train_data)
        alpha = [0 for _ in range(data_length)]
        train_data = np.array(train_data)
        gram = np.matmul(train_data[:, 0:-1], train_data[:, 0:-1].T)
        for i in range(num_iter):
            count = 0
            i = random.randint(0, data_length - 1)
            yi = train_data[i, -1]
            for j in range(data_length):
                count += alpha[j] * train_data[j, -1] * gram[i, j]
            count += b
            if (yi * count <= 0):
                alpha[i] = alpha[i] + learning_rate
                b = b + learning_rate * yi
        for i in range(data_length):
            w += alpha[i] * train_data[i, 0:-1] * train_data[i, -1]
        return w, b, alpha, gram
    

    四、可视化

    def plot_points(w, b, X):
        plt.figure()
        x1 = np.linspace(4, 7, 100)
        x2 = (-b - w[0] * x1) / (w[1] + 1e-10)
        plt.plot(x1, x2, color='k')
        plt.scatter(X[:50, 0], X[:50, 1], color='r', s=50, marker='o', label='山鸢尾')
        plt.scatter(X[50:100, 0], X[50:100, 1], color='b',
                    s=50, marker='x', label='变色鸢尾')
        plt.xlabel('萼片长度(cm)', fontproperties=font)
        plt.ylabel('花瓣长度(cm)', fontproperties=font)
        plt.legend(prop=font)
        plt.show()
    

    五、运行

    train_data, X = get_data()
    w, b, alpha, gram = train(
        num_iter=1000, train_data=train_data, learning_rate=0.1)
    plot_points(w, b, X)
    

    png

  • 相关阅读:
    MFC中CDialog与其对话框资源的绑定 dll中资源的切换
    DirectDraw 显示 YUV
    ClipCursor与GetClipCursor 用法
    MFC消息处理流程概述 .
    HTML5 WebSocket 技术介绍
    NodepartySZ1 深圳聚会活动回顾总结[2012.01.08] CNode
    index QuickWeb文档
    Node.js Manual
    An innovative way to replace AJAX and JSONP using node.js and socket.io
    RequireJS
  • 原文地址:https://www.cnblogs.com/nickchen121/p/11686753.html
Copyright © 2011-2022 走看看