zoukankan      html  css  js  c++  java
  • 机器学习-神经网络

    一、

    import matplotlib.pyplot as plt
    from matplotlib.colors import ListedColormap
    from sklearn.neural_network import MLPClassifier
    from sklearn.datasets import load_wine
    from sklearn.model_selection import train_test_split
    import numpy as np
    import matplotlib.pyplot as plt

    line = np.linspace(-5,5,200)
    wine = load_wine()
    X = wine.data[:,:2]
    y = wine.target
    X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=0)

    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00ff00', '#0000FF'])
    x_min, x_max = X_train[:, 0].min() -1, X_train[:, 0].max() + 1
    y_min, y_max = X_train[:, 1].min() -1, X_train[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, .02),
    np.arange(y_min, y_max, .02))

    mlp_20=MLPClassifier(solver='lbfgs', hidden_layer_sizes=[10])
    mlp_20.fit(X_train, y_train)
    Z1 = mlp_20.predict(np.c_[xx.ravel(), yy.ravel()])

    Z1 = Z1.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z1, cmap=cmap_light)
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', s=60)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.show()

    二、使用MNIST数据集训练MLP神经网络

    from sklearn.datasets import fetch_mldata    #导入MNIST数据集获取工具
    from sklearn.model_selection import train_test_split  #数据集随机切分
    from sklearn.neural_network import MLPClassifier    #导入神经网络包

    mnist = fetch_mldata("MNIST original", data_home='MNIST_data/')    #使用工具获取MNIST数据集
    X = mnist.data/255.    #建立数据集并把特征向量的值全部除以255,这样全部数值就会在0和1之间
    y = mnist.target      #建立测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size= 5000, test_size=1000,random_state=62)   #选取5000个训练数据集,1000个作为测试数据集,随机数设置62
    mlp_hw = MLPClassifier(solver='lbfgs',hidden_layer_sizes=[100,100],activation='relu', alpha = 1e-5,random_state=62)  #设置100个隐藏层,使用lbfgs的优化器
    mlp_hw.fit(X_train,y_train)    #训练数据
    print('测试数据集得分: {:.2f}'.format(mlp_hw.score(X_test,y_test)*100))

  • 相关阅读:
    【原】一张图片优化5K的带宽成本
    让手机站点像原生应用的四大途径
    iScroll4下表单元素聚焦及键盘的异常问题
    蜕变·WebRebuild 2013 前端年度交流会邀请
    【原】js实现复制到剪贴板功能,兼容所有浏览器
    【原】css实现两端对齐的3种方法
    【原】常见CSS3属性对ios&android&winphone的支持
    一枚前端开发-页面重构方向的招聘信息
    【原】分享超实用工具给大家
    【原】webapp开发中兼容Android4.0以下版本的css hack
  • 原文地址:https://www.cnblogs.com/zhaop8078/p/9745253.html
Copyright © 2011-2022 走看看