zoukankan      html  css  js  c++  java
  • 机器学习-神经网络

    一、

    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from sklearn.neural_network import MLPClassifier
    from sklearn.datasets import load_wine
    from sklearn.model_selection import train_test_split
    import numpy as np
    import matplotlib.pyplot as plt

    line = np.linspace(-5,5,200)
    wine = load_wine()
    X = wine.data[:,:2]
    y = wine.target
    X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=0)

    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00ff00', '#0000FF'])
    x_min, x_max = X_train[:, 0].min() -1, X_train[:, 0].max() + 1
    y_min, y_max = X_train[:, 1].min() -1, X_train[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
    np.arange(y_min, y_max, .02))

    mlp_20=MLPClassifier(solver='lbfgs', hidden_layer_sizes=[10])
    mlp_20.fit(X_train, y_train)
    Z1 = mlp_20.predict(np.c_[xx.ravel(), yy.ravel()])

    Z1 = Z1.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z1, cmap=cmap_light)
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', s=60)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()

    二、使用MNIST数据集训练MLP神经网络

    from sklearn.datasets import fetch_mldata    #导入MNIST数据集获取工具
    from sklearn.model_selection import train_test_split  #数据集随机切分
    from sklearn.neural_network import MLPClassifier    #导入神经网络包

    mnist = fetch_mldata("MNIST original", data_home='MNIST_data/')    #使用工具获取MNIST数据集
    X = mnist.data/255.    #建立数据集并把特征向量的值全部除以255,这样全部数值就会在0和1之间
    y = mnist.target      #建立测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size= 5000, test_size=1000,random_state=62)   #选取5000个训练数据集,1000个作为测试数据集,随机数设置62
    mlp_hw = MLPClassifier(solver='lbfgs',hidden_layer_sizes=[100,100],activation='relu', alpha = 1e-5,random_state=62)  #设置100个隐藏层,使用lbfgs的优化器
    mlp_hw.fit(X_train,y_train)    #训练数据
    print('测试数据集得分: {:.2f}'.format(mlp_hw.score(X_test,y_test)*100))

  • 相关阅读:
    js---查找数组中的最大值(最小值),及相应的下标
    JS数组遍历的几种方法
    在 forEach 中使用 async/await 遇到的问题
    js 事件冒泡和事件捕获
    JS中dom0级事件和dom2级事件的区别介绍
    Vue集成Ueditor
    vue富文本编辑器 Vue-Quill-Editor
    Redis问题1---redis满了怎么办
    jQuery火箭图标返回顶部代码
    遇到的小问题
  • 原文地址:https://www.cnblogs.com/zhaop8078/p/9745253.html
Copyright © 2011-2022 走看看