Machine Learning Notes (6): Multi-class Logistic Regression with gluon

    The previous post showed how to add a hidden layer entirely by hand; this time we use gluon to make the code much more concise. The code is adapted from: https://zh.gluon.ai/chapter_supervised-learning/mlp-gluon.html
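
    For contrast, here is roughly what the hand-written hidden layer from the previous post looks like; this is a minimal sketch, and the parameter names and initialization scale are illustrative rather than that post's exact code:

    from mxnet import ndarray as nd

    num_inputs, num_hidden, num_outputs = 784, 256, 10

    # allocate every parameter by hand and ask MXNet to track its gradient
    W1 = nd.random.normal(scale=0.01, shape=(num_inputs, num_hidden))
    b1 = nd.zeros(num_hidden)
    W2 = nd.random.normal(scale=0.01, shape=(num_hidden, num_outputs))
    b2 = nd.zeros(num_outputs)
    for param in [W1, b1, W2, b2]:
        param.attach_grad()

    def relu(X):
        return nd.maximum(X, 0)

    def net(X):
        X = X.reshape((-1, num_inputs))   # flatten the 28x28 images
        h = relu(nd.dot(X, W1) + b1)      # hidden layer
        return nd.dot(h, W2) + b2         # output logits

    With gluon, all of the above collapses into two Dense layers inside a Sequential block. The full gluon version: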

    from mxnet import gluon
    from mxnet import ndarray as nd
    import matplotlib.pyplot as plt
    import mxnet as mx
    from mxnet import autograd
      
    def transform(data, label):
        return data.astype('float32')/255, label.astype('float32')
      
    mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
    mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)
      
    def show_images(images):
        n = images.shape[0]
        _, figs = plt.subplots(1, n, figsize=(15, 15))
        for i in range(n):
            figs[i].imshow(images[i].reshape((28, 28)).asnumpy())
            figs[i].axes.get_xaxis().set_visible(False)
            figs[i].axes.get_yaxis().set_visible(False)
        plt.show()
    
    def get_text_labels(label):
        text_labels = [
            't-shirt', 'trouser', 'pullover', 'dress', 'coat',
            'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'
        ]
        return [text_labels[int(i)] for i in label]
      
    data, label = mnist_train[0:10]
      
    print('example shape: ', data.shape, 'label:', label)
    show_images(data)
    print(get_text_labels(label))
      
    batch_size = 256
    train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)
    test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)
      
    # define the model: Flatten -> Dense(256, relu) -> Dense(10)
    net = gluon.nn.Sequential()
    with net.name_scope():
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(256, activation="relu"))
        net.add(gluon.nn.Dense(10))
    net.initialize()
      
    # SoftmaxCrossEntropyLoss fuses the softmax with the cross-entropy, which is more numerically stable
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    
    # define the trainer: plain SGD with learning rate 0.5
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5})
     
    def accuracy(output, label):
        return nd.mean(output.argmax(axis=1) == label).asscalar()
      
    def _get_batch(batch):
        if isinstance(batch, mx.io.DataBatch):
            data = batch.data[0]
            label = batch.label[0]
        else:
            data, label = batch
        return data, label
      
    def evaluate_accuracy(data_iterator, net):
        acc = 0.
        if isinstance(data_iterator, mx.io.MXDataIter):
            data_iterator.reset()
        for i, batch in enumerate(data_iterator):
            data, label = _get_batch(batch)
            output = net(data)
            acc += accuracy(output, label)
        return acc / (i+1)
      
    for epoch in range(5):
        train_loss = 0.
        train_acc = 0.
        for data, label in train_data:
            with autograd.record():
                output = net(data)
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            trainer.step(batch_size)  # take one optimization step; batch_size rescales the summed gradients
    
            train_loss += nd.mean(loss).asscalar()
            train_acc += accuracy(output, label)
    
        test_acc = evaluate_accuracy(test_data, net)
        print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (
            epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))
    
    data, label = mnist_test[0:10]
    show_images(data)
    print('true labels')
    print(get_text_labels(label))
      
    predicted_labels = net(data).argmax(axis=1)
    print('predicted labels')
    print(get_text_labels(predicted_labels.asnumpy()))
    

    The parts that changed are marked with comments. The results are exactly the same as in the previous post, so the screenshots are not repeated here.
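
    One line worth a closer look is trainer.step(batch_size): loss.backward() sums the gradients over the whole batch, and passing batch_size to step() rescales them by 1/batch_size before the update. For plain SGD with no momentum or weight decay, the step is roughly equivalent to the hand-written sketch below (variable names are illustrative):

    learning_rate = 0.5
    for param in net.collect_params().values():
        weight = param.data()
        # summed batch gradient divided by batch_size, then a vanilla SGD step
        weight[:] = weight - learning_rate * param.grad() / batch_size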

Original post: https://www.cnblogs.com/yjmyzz/p/8128122.html