zoukankan      html  css  js  c++  java
  • 多层感知机训练 Fashion-MNIST 数据集

    MLP

    In [1]:
    %matplotlib inline
    import gluonbook as gb
    from mxnet.gluon import loss as gloss
    from mxnet import nd
    from mxnet import autograd
    
    In [2]:
    # Mini-batch size; also passed to train()/sgd() further down, where the
    # gradient is divided by it.
    batch_size = 256
    # The loader's result is unpacked into a training iterator and a test
    # iterator; both are consumed later as `for X, y in ...` loops.
    train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)
    
     

    模型参数初始化

    In [3]:
    # Layer widths: flattened 28x28 input, 256 hidden units, 10 output scores.
    num_inputs = 28 * 28
    num_out_puts = 10
    num_hiddens = 256

    # Weights are drawn from a normal distribution with scale 0.01; biases
    # start at zero. W1 is sampled before W2, so the random stream is consumed
    # in the same order on every run.
    W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hiddens))
    b1 = nd.zeros(num_hiddens)
    W2 = nd.random.normal(scale=0.01, shape=(num_hiddens, num_out_puts))
    b2 = nd.zeros(num_out_puts)

    # Collect every trainable array and attach a gradient to each one, so that
    # `param.grad` is available to the optimiser step after backward().
    params = [W1, b1, W2, b2]
    for param in params:
        param.attach_grad()
    
     

    激活函数

    In [4]:
    def relu(X):
        """Apply the ReLU activation to ``X`` element-wise.

        The result is the element-wise maximum of ``X`` and the scalar 0, so
        negative entries become 0 and non-negative entries pass through.
        """
        threshold = 0
        return nd.maximum(X, threshold)
    
    In [5]:
    # Sanity check: a small 2x3 array mixing positive and negative entries.
    X = nd.array([[1,3,-1],[2,-2,-1]])
    # Bare last expression so the notebook displays the activated array.
    relu(X)
    
    Out[5]:
    [[1. 3. 0.]
     [2. 0. 0.]]
    <NDArray 2x3 @cpu(0)>
     

    定义模型:$H = \mathrm{relu}(XW_1 + b_1)$,$O = HW_2 + b_2$

    In [6]:
    def net(X):
        """Forward pass of the one-hidden-layer perceptron.

        ``X`` is reshaped to rows of ``num_inputs`` values, pushed through the
        hidden layer (``W1``, ``b1``, ReLU) and then through the output layer
        (``W2``, ``b2``). The returned scores are not normalised here.
        """
        flat = X.reshape((-1, num_inputs))
        hidden = relu(nd.dot(flat, W1) + b1)
        scores = nd.dot(hidden, W2) + b2
        return scores
    
     

    softmax 交叉熵损失函数

    In [7]:
    # Gluon's softmax cross-entropy loss. It is called as loss(y_hat, y) on the
    # raw outputs of net(), which is why net() applies no softmax of its own.
    loss = gloss.SoftmaxCrossEntropyLoss()
    
     

    调整参数

    In [9]:
    def sgd(params, lr, batch_size):
        """Take one mini-batch SGD step on every array in ``params``, in place.

        Each array is overwritten (via slice assignment, so the same object is
        kept) with ``value - lr * grad / batch_size``, where ``grad`` is read
        from the array's ``.grad`` attribute. Returns nothing.
        """
        for weights in params:
            step = lr * weights.grad / batch_size
            weights[:] = weights - step
    
     

    判断预测是否正确(单个批量的准确率)

    In [10]:
    def accuracy(y_hat,y):
        """Return the fraction of rows in ``y_hat`` whose arg-max equals ``y``.

        ``y_hat`` holds one row of scores per sample; the index of the largest
        score in each row is the predicted class. ``y`` is cast to float32
        before the comparison, and the mean of the matches is returned as a
        scalar.
        """
        predicted = y_hat.argmax(axis=1)
        targets = y.astype('float32')
        matches = predicted == targets
        return matches.mean().asscalar()
    
     

    正确率

    In [11]:
    def evaluate_accuracy(data_iter,net):
        """Return the accuracy of ``net`` over every sample in ``data_iter``.

        Args:
            data_iter: iterable yielding ``(X, y)`` batches. It only needs to be
                iterable; ``len()`` is not required.
            net: callable mapping a batch ``X`` to one row of scores per sample.

        Returns:
            Correct predictions divided by the total number of samples, as a
            Python float.

        Raises:
            ZeroDivisionError: if ``data_iter`` yields no samples.
        """
        num_correct = 0.0
        num_samples = 0
        for X, y in data_iter:
            y = y.astype('float32')
            # Count hits rather than averaging per-batch accuracies: the last
            # batch is usually smaller, and a mean of batch means would give
            # its samples too much weight.
            num_correct += float((net(X).argmax(axis=1) == y).sum().asscalar())
            num_samples += y.size
        return num_correct / num_samples
    
     

    训练模型

    In [12]:
    def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
        """Train ``net`` for ``num_epochs`` epochs and print per-epoch metrics.

        Args:
            net: callable mapping a batch ``X`` to predictions ``y_hat``.
            train_iter: iterable of ``(X, y)`` training batches, re-iterated
                once per epoch.
            test_iter: iterable of ``(X, y)`` batches passed to
                ``evaluate_accuracy`` after each epoch.
            loss: callable ``loss(y_hat, y)`` whose result supports
                ``backward()`` and ``mean()``.
            num_epochs: number of passes over ``train_iter``.
            batch_size: forwarded to ``sgd`` or to ``trainer.step``.
            params: arrays updated by ``sgd``; used only when ``trainer`` is None.
            lr: learning rate for ``sgd``; used only when ``trainer`` is None.
            trainer: optional object with a ``step(batch_size)`` method; when
                given it replaces the hand-written ``sgd`` update.

        Returns:
            None. One summary line is printed per epoch.

        Raises:
            ZeroDivisionError: if ``train_iter`` yields no samples in an epoch.
        """
        for epoch in range(num_epochs):
            train_l_sum = 0.0
            train_acc_sum = 0.0
            num_samples = 0
            for X, y in train_iter:
                with autograd.record():
                    y_hat = net(X)
                    l = loss(y_hat, y)
                l.backward()
                if trainer is None:
                    sgd(params, lr, batch_size)
                else:
                    trainer.step(batch_size)
                # Weight each batch's mean loss and accuracy by its sample
                # count: the final batch is usually smaller, so averaging batch
                # means over len(train_iter) would skew the reported numbers.
                batch_count = y.size
                train_l_sum += float(l.mean().asscalar()) * batch_count
                train_acc_sum += float(accuracy(y_hat, y)) * batch_count
                num_samples += batch_count
            test_acc = evaluate_accuracy(test_iter, net)
            print('epoch %d, loss %.4f, train acc %.3f,test acc %.3f'
                  % (epoch + 1, train_l_sum / num_samples,
                     train_acc_sum / num_samples, test_acc))
    
    # Hyper-parameters for this run: five passes over the training data with a
    # learning rate of 0.1.
    num_epochs = 5
    lr = 0.1
    # No trainer is passed, so train() falls back to the hand-written sgd()
    # update on `params`.
    train(net, train_iter, test_iter, loss, num_epochs, batch_size,
          params=params, lr=lr)
    
     
    epoch 1, loss 1.0423, train acc 0.640,test acc 0.745
    epoch 2, loss 0.6048, train acc 0.787,test acc 0.818
    epoch 3, loss 0.5297, train acc 0.814,test acc 0.833
    epoch 4, loss 0.4827, train acc 0.831,test acc 0.842
    epoch 5, loss 0.4626, train acc 0.837,test acc 0.846
    
    In [ ]:
     
  • 相关阅读:
    Linux 文件的软连接和硬连接
    URLOS发布NFS文件加速功能,可有效提升NFS小文件读取性能
    Vue底层学习3——手撸发布订阅模式
    Vue底层学习2——手撸数据响应化
    Vue底层学习1——原理解析
    rest api测试工具frisbyjs
    git ignore 微软临时文件(~$xxx.xlsx)
    数据虚拟化-基础概念
    elasticsearch移除映射类型(mapping type)
    activemq Virtual Destinations 虚拟目的地
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10020964.html
Copyright © 2011-2022 走看看