  • Softmax implementation for Fashion-MNIST classification


    #!/usr/bin/env python
    # coding: utf-8
    
    # In[1]:
    
    
    get_ipython().run_line_magic('matplotlib', 'inline')
    import gluonbook as gb
    from mxnet import autograd,nd
    
    
    # In[2]:
    
    
    batch_size = 256
    train_iter,test_iter = gb.load_data_fashion_mnist(batch_size)
    
    
    # In[3]:
    
    
    num_inputs = 784
    num_outputs = 10
    
    W = nd.random.normal(scale=0.01,shape=(num_inputs,num_outputs))
    b = nd.zeros(num_outputs)
    
    
    # In[4]:
    
    
    W.attach_grad()
    b.attach_grad()
    
    
    # The softmax operation
    
    # In[5]:
    
    
    X = nd.array([[1,2,3],[4,5,6]])
    X.sum(axis=0,keepdims=True)
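    # keepdims=True keeps the result two-dimensional, so it can later be
    # broadcast against X. Summing over axis=1 (along each row), which the
    # softmax below relies on, works the same way (extra demo, not in the
    # original notebook):
    X.sum(axis=1,keepdims=True)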
    
    
    # In[6]:
    
    
    def softmax(X):
        X_exp = X.exp()
        partition = X_exp.sum(axis = 1,keepdims = True)
        return X_exp / partition
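
    # Note: exponentiating large inputs can overflow. A numerically stable
    # variant (a sketch, not in the original notebook) subtracts the row-wise
    # maximum first; the result is unchanged because softmax is invariant to
    # adding a constant to every entry of a row:
    def stable_softmax(X):
        X = X - X.max(axis=1, keepdims=True)  # largest entry per row becomes 0
        X_exp = X.exp()
        return X_exp / X_exp.sum(axis=1, keepdims=True)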
    
    
    # For example
    
    # In[7]:
    
    
    X = nd.random.normal(shape=(2,5))
    X_prob = softmax(X)
    X_prob,X_prob.sum(axis=1)
    
    
    # Define the model
    
    # In[8]:
    
    
    def net(X):
        return softmax(nd.dot(X.reshape((-1,num_inputs)),W)+b)
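
    # reshape((-1, num_inputs)) flattens each 28x28 image into a length-784
    # row vector; the -1 lets MXNet infer the batch dimension.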
    
    
    # Define the loss function
    
    # In[9]:
    
    
    y_hat = nd.array([[0.1,0.3,0.6],[0.3,0.2,0.5]])
    y = nd.array([0,2])
    nd.pick(y_hat,y)
    
    
    # The cross-entropy loss function
    
    # In[10]:
    
    
    def cross_entropy(y_hat,y):
        return - nd.pick(y_hat,y).log()
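
    # For the demo tensors above, nd.pick(y_hat, y) selects [0.1, 0.5], so the
    # per-example losses are -log(0.1) ≈ 2.3026 and -log(0.5) ≈ 0.6931:
    cross_entropy(y_hat, y)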
    
    
    # Compute classification accuracy
    
    # In[11]:
    
    
    def accuracy(y_hat,y):
        return (y_hat.argmax(axis=1)==y.astype('float32')).mean().asscalar()
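
    # argmax returns a float32 NDArray in MXNet, so y is cast to float32
    # before the elementwise comparison; mean().asscalar() then converts the
    # average of the resulting 0/1 vector into a Python float.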
    
    
    # In[12]:
    
    
    accuracy(y_hat,y)
    
    
    # Evaluate the accuracy of net on data_iter
    
    # In[13]:
    
    
    def evaluate_accuracy(data_iter,net):
        acc = 0
        for X,y in data_iter:
            acc += accuracy(net(X),y)
        return acc / len(data_iter)
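
    # Averaging per-batch accuracies is exact only if every batch has the same
    # size, and the last batch of an epoch is often smaller. A sketch (not in
    # the original notebook) that weights each batch by its sample count:
    def evaluate_accuracy_weighted(data_iter, net):
        acc_sum, n = 0.0, 0
        for X, y in data_iter:
            acc_sum += (net(X).argmax(axis=1) == y.astype('float32')).sum().asscalar()
            n += y.size
        return acc_sum / n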
    
    
    # In[14]:
    
    
    evaluate_accuracy(test_iter,net)
    
    
    # Train the model
    
    # In[15]:
    
    
    num_epochs, lr = 5, 0.1
    
    def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
                  params=None, lr=None, trainer=None):
        for epoch in range(num_epochs):
            train_l_sum = 0
            train_acc_sum = 0
            for X, y in train_iter:     # loop over mini-batches
                with autograd.record():
                    y_hat = net(X)      # forward pass: softmax(XW + b)
                    l = loss(y_hat, y)  # per-example cross-entropy loss
                l.backward()            # back-propagate to get the gradients

                gb.sgd(params, lr, batch_size)   # update the parameters W and b
    
                train_l_sum += l.mean().asscalar()
                train_acc_sum += accuracy(y_hat, y)
            test_acc = evaluate_accuracy(test_iter, net)
            print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
                  % (epoch + 1, train_l_sum / len(train_iter),
                     train_acc_sum / len(train_iter), test_acc))
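
    # gb.sgd applies plain mini-batch SGD. A minimal sketch of that update rule
    # (assuming gluonbook's convention: l.backward() on the loss vector sums it,
    # so dividing the gradient by batch_size gives the average-loss gradient):
    def sgd_sketch(params, lr, batch_size):
        for param in params:
            param[:] = param - lr * param.grad / batch_size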
    
    train(net, train_iter, test_iter, cross_entropy, num_epochs,batch_size, [W, b], lr)
    
    
    # Prediction
    
    # In[16]:
    
    
    for X, y in test_iter:
        break
    
    true_labels = gb.get_fashion_mnist_labels(y.asnumpy())
    pred_labels = gb.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())
    titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
    
    gb.show_fashion_mnist(X[0:9], titles[0:9])

  • Original post: https://www.cnblogs.com/TreeDream/p/10002918.html