zoukankan      html  css  js  c++  java
  • 多层感知机训练 Fashion-MNIST 数据集

    MLP

    In [1]:
    %matplotlib inline
    import gluonbook as gb
    from mxnet.gluon import loss as gloss
    from mxnet import nd
    from mxnet import autograd
    
    In [2]:
    # Mini-batch size; reused later as the divisor in the sgd() update.
    batch_size = 256
    # gluonbook helper; presumably yields (features, labels) batches of Fashion-MNIST
    # for the train and test splits — confirm against the gluonbook version in use.
    train_iter, test_iter = gb.load_data_fashion_mnist(batch_size)
    
     

    模型参数初始化

    In [3]:
    # Layer widths: 28*28 inputs per sample, 256 hidden units, 10 outputs.
    num_inputs = 28*28
    num_out_puts = 10
    num_hiddens = 256
    # Weights are drawn from a zero-mean Gaussian with std 0.01; biases start at zero.
    W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hiddens))
    b1 = nd.zeros(num_hiddens)
    W2 = nd.random.normal(scale=0.01, shape=(num_hiddens, num_out_puts))
    b2 = nd.zeros(num_out_puts)
    # Every parameter needs a gradient buffer before autograd can fill in .grad.
    params = [W1, b1, W2, b2]
    for param in params:
        param.attach_grad()
    
     

    激活函数

    In [4]:
    def relu(X):
        """Rectified linear unit: element-wise maximum of X and 0."""
        floor = 0
        return nd.maximum(X, floor)
    
    In [5]:
    # Sanity check: negative entries should come back as 0, positives unchanged.
    X = nd.array([[1, 3, -1],
                  [2, -2, -1]])
    relu(X)
    
    Out[5]:
    [[1. 3. 0.]
     [2. 0. 0.]]
    <NDArray 2x3 @cpu(0)>
     

    定义模型:$H = \mathrm{relu}(XW_1 + b_1)$,$O = HW_2 + b_2$

    In [6]:
    def net(X):
        """Forward pass of the two-layer perceptron.

        X is reshaped to rows of ``num_inputs`` features, passed through the
        ReLU hidden layer (W1, b1) and then the linear output layer (W2, b2).
        The raw output-layer values are returned; no softmax is applied here.
        """
        flat = X.reshape((-1, num_inputs))
        hidden = relu(nd.dot(flat, W1) + b1)
        scores = nd.dot(hidden, W2) + b2
        return scores
    
     

    softmax 交叉熵损失函数

    In [7]:
    # Loss object from mxnet.gluon.loss; the training loop calls it as loss(y_hat, y)
    # on the raw outputs of net().
    loss = gloss.SoftmaxCrossEntropyLoss()
    
     

    调整参数

    In [9]:
    def sgd(params, lr, batch_size):
        """Apply one mini-batch SGD step to every parameter, in place.

        Each parameter is overwritten with ``param - lr * param.grad / batch_size``.

        params     -- iterable of arrays that expose a ``.grad`` attribute
        lr         -- learning rate
        batch_size -- divisor applied to the gradient
        """
        for param in params:
            step = lr * param.grad / batch_size
            # Slice assignment writes into the existing array instead of rebinding the name.
            param[:] = param - step
    
     

    计算一个批量的预测准确率(预测类别与标签一致的样本比例)

    In [10]:
    def accuracy(y_hat,y):
        """Fraction of rows in ``y_hat`` whose arg-max along axis 1 equals the label in ``y``.

        Labels are cast to float32 before the comparison; the result is the
        mean of the element-wise matches, returned as a scalar.
        """
        predicted = y_hat.argmax(axis=1)
        labels = y.astype('float32')
        matches = predicted == labels
        return matches.mean().asscalar()
    
     

    正确率

    In [11]:
    def evaluate_accuracy(data_iter,net):
        """Return the fraction of correctly classified samples across all of ``data_iter``.

        data_iter -- iterable of (X, y) batches; it is consumed once and ``len()`` is not needed
        net       -- callable mapping X to per-class scores, one row per sample

        Correct predictions and samples are counted over every batch, so a
        shorter final batch is weighted by its real size. Averaging the
        per-batch accuracies would give that batch the same weight as a full one.

        Raises ZeroDivisionError if ``data_iter`` yields no samples.
        """
        correct, n = 0.0, 0
        for X,y in data_iter:
            hits = net(X).argmax(axis=1) == y.astype('float32')
            correct += hits.sum().asscalar()
            n += y.size
        return correct / n
    
     

    训练模型

    In [12]:
    def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
        """Train ``net`` for ``num_epochs`` passes over ``train_iter``, printing metrics per epoch.

        net        -- callable mapping a batch X to per-class scores
        train_iter -- iterable of (X, y) training batches, re-iterated every epoch
        test_iter  -- iterable of (X, y) batches passed to evaluate_accuracy() after each epoch
        loss       -- callable ``loss(y_hat, y)`` whose result supports ``backward()``
        num_epochs -- number of passes over ``train_iter``
        batch_size -- value handed to ``sgd`` / ``trainer.step`` as the gradient divisor
        params, lr -- parameters and learning rate for the hand-written ``sgd``;
                      required when ``trainer`` is None
        trainer    -- optional object with a ``step(batch_size)`` method; used instead
                      of ``sgd`` when given

        Reported train loss and train accuracy are per-sample averages: each
        batch contributes in proportion to its number of samples, so a shorter
        final batch does not skew the epoch figures.

        Raises ValueError if neither a trainer nor both ``params`` and ``lr`` are supplied.
        Returns None.
        """
        # Fail before any computation, not with an obscure TypeError inside sgd().
        if trainer is None and (params is None or lr is None):
            raise ValueError('params and lr are required when trainer is None')
        for epoch in range(num_epochs):
            train_l_sum, train_correct, n = 0.0, 0.0, 0
            for X,y in train_iter:
                with autograd.record():
                    y_hat = net(X)
                    l = loss(y_hat,y)
                l.backward()
                if trainer is None:
                    sgd(params, lr , batch_size)
                else:
                    trainer.step(batch_size)
                batch_n = y.size
                # Weight the batch-mean loss by the batch's sample count.
                train_l_sum += l.mean().asscalar() * batch_n
                train_correct += (y_hat.argmax(axis=1) == y.astype('float32')).sum().asscalar()
                n += batch_n
            test_acc = evaluate_accuracy(test_iter,net)
            print('epoch %d, loss %.4f, train acc %.3f,test acc %.3f'
                  %(epoch+1, train_l_sum / n,
                   train_correct / n,test_acc))
    
    # Five passes over the training data with learning rate 0.1.
    num_epochs , lr = 5, 0.1
    # No trainer argument is passed, so train() takes its sgd(params, lr, batch_size) branch.
    train(net, train_iter,test_iter,loss,num_epochs,batch_size,params,lr)        
    
     
    epoch 1, loss 1.0423, train acc 0.640,test acc 0.745
    epoch 2, loss 0.6048, train acc 0.787,test acc 0.818
    epoch 3, loss 0.5297, train acc 0.814,test acc 0.833
    epoch 4, loss 0.4827, train acc 0.831,test acc 0.842
    epoch 5, loss 0.4626, train acc 0.837,test acc 0.846
    
    In [ ]:
     
  • 相关阅读:
    NLP
    Log Collect
    android 客户端 和 新浪微博如何打通的
    学术论文写作的 paper、code 资源
    学术论文写作的 paper、code 资源
    高观点下的高等数学(数学分析、线性代数)
    高观点下的高等数学(数学分析、线性代数)
    弦论 —— 宇宙的琴弦
    弦论 —— 宇宙的琴弦
    流体力学
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10020964.html
Copyright © 2011-2022 走看看