zoukankan      html  css  js  c++  java
  • 使用 LeNet 对 FashionMNIST 进行分类

    import mxnet as mx
    from mxnet import autograd, gluon, init, nd
    from mxnet.gluon import loss as gloss, nn
    from mxnet.gluon import data as gdata
    import time
    import sys
    
    # LeNet-style network, assembled one layer at a time.
    net = nn.Sequential()
    # Feature extractor: conv(6 channels, 5x5) -> 2x2 max-pool
    #                    -> conv(16 channels, 5x5) -> 2x2 max-pool.
    net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'))
    net.add(nn.MaxPool2D(pool_size=2, strides=2))
    net.add(nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'))
    net.add(nn.MaxPool2D(pool_size=2, strides=2))
    # Dense head. By default Dense flattens a (batch, channels, height, width)
    # input into (batch, channels * height * width) before its matrix multiply.
    net.add(nn.Dense(120, activation='sigmoid'))
    net.add(nn.Dense(84, activation='sigmoid'))
    # Final layer: 10 raw output scores (no activation).
    net.add(nn.Dense(10))
    
    # Sanity check: feed one random 1x1x28x28 tensor through the network
    # layer by layer, printing each layer's name and output shape.
    X = nd.random.uniform(shape=(1, 1, 28, 28))
    net.initialize()
    for blk in net:
        X = blk(X)
        print(blk.name, 'output shape:	', X.shape)
    
    # FashionMNIST train / test splits, via Gluon's built-in vision datasets.
    mnist_train = gdata.vision.FashionMNIST(train=True)
    mnist_test = gdata.vision.FashionMNIST(train=False)

    batch_size = 256
    transformer = gdata.vision.transforms.ToTensor()
    # No worker processes on Windows, 4 elsewhere (presumably because of
    # Windows multiprocessing limitations -- confirm before changing).
    num_workers = 0 if sys.platform.startswith('win') else 4

    # Mini-batch iterators; ToTensor is applied to each sample's image only.
    # Batches are produced on the CPU.
    train_iter = gdata.DataLoader(
        mnist_train.transform_first(transformer),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_iter = gdata.DataLoader(
        mnist_test.transform_first(transformer),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    
    def try_gpu4():
        """Pick the device to run on.

        Allocates a one-element array on ``mx.gpu()`` as a probe. If that
        raises ``mx.base.MXNetError``, falls back to the CPU.

        Returns:
            ``mx.gpu()`` when the probe succeeds, otherwise ``mx.cpu()``.
        """
        try:
            gpu = mx.gpu()
            nd.zeros((1,), ctx=gpu)
        except mx.base.MXNetError:
            return mx.cpu()
        return gpu

    ctx = try_gpu4()
    
    def accuracy(y_hat, y):
        """Return the fraction of rows whose arg-max over axis 1 matches ``y``.

        Args:
            y_hat: score array with classes along axis 1.
            y: label array, one label per row of ``y_hat``.

        Returns:
            Mean accuracy of the batch as a Python/NumPy scalar.
        """
        predicted = y_hat.argmax(axis=1)
        # Labels are cast to float32 before comparing, presumably because
        # MXNet's argmax yields float indices.
        hits = predicted == y.astype('float32')
        return hits.mean().asscalar()
    
    def evaluate_accuracy(data_iter, net, ctx):
        """Return the classification accuracy of ``net`` over all samples in ``data_iter``.

        Correct predictions are counted sample by sample and divided by the
        total number of samples. A short final batch is therefore weighted
        correctly; averaging per-batch accuracies would over-weight it.

        Args:
            data_iter: iterable of ``(X, y)`` mini-batches (e.g. a Gluon
                DataLoader). It does not need to support ``len()``.
            net: model mapping ``X`` to scores with classes along axis 1.
            ctx: MXNet context each batch is copied to before evaluation.

        Returns:
            Accuracy as a float in ``[0, 1]``.

        Raises:
            ValueError: if ``data_iter`` yields no samples.
        """
        # Accumulate on ctx so each batch does not force a device-to-host sync.
        correct = nd.array([0], ctx=ctx)
        n_samples = 0
        for X, y in data_iter:
            # If ctx is a GPU, copy the batch onto it.
            X, y = X.as_in_context(ctx), y.as_in_context(ctx)
            y_hat = net(X)
            correct += (y_hat.argmax(axis=1) == y.astype('float32')).sum()
            n_samples += y.shape[0]
        if n_samples == 0:
            raise ValueError('data_iter yielded no samples')
        return correct.asscalar() / n_samples
    
    def train(net, train_iter, test_iter, batch_size, trainer, ctx,
              num_epochs):
        """Train ``net`` with softmax cross-entropy and print metrics each epoch.

        Each epoch prints the mean training loss and training accuracy over
        all samples seen, the test accuracy from ``evaluate_accuracy``, and the
        epoch's wall-clock time (training plus evaluation).

        Args:
            net: Gluon block producing scores with classes along axis 1.
            train_iter: iterable of ``(X, y)`` training mini-batches.
            test_iter: iterable of ``(X, y)`` test mini-batches.
            batch_size: nominal mini-batch size, kept for interface
                compatibility. The optimizer step is normalised by the actual
                number of samples in each batch, so a short final batch is
                not under-weighted.
            trainer: ``gluon.Trainer`` that updates ``net``'s parameters.
            ctx: MXNet context each batch is copied to.
            num_epochs: number of passes over ``train_iter``.

        Raises:
            ValueError: if ``train_iter`` yields no samples in an epoch.
        """
        print('training on', ctx)
        loss = gloss.SoftmaxCrossEntropyLoss()
        for epoch in range(num_epochs):
            train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()
            for X, y in train_iter:
                X, y = X.as_in_context(ctx), y.as_in_context(ctx)
                with autograd.record():
                    y_hat = net(X)
                    l = loss(y_hat, y)
                l.backward()
                # backward() on the per-sample loss vector yields gradients of
                # their sum, and step(k) rescales them by 1/k. Using the real
                # batch length keeps a short last batch correctly scaled.
                n_batch = y.shape[0]
                trainer.step(n_batch)
                train_l_sum += l.sum().asscalar()
                train_acc_sum += (
                    y_hat.argmax(axis=1) == y.astype('float32')
                ).sum().asscalar()
                n += n_batch
            if n == 0:
                raise ValueError('train_iter yielded no samples')
            test_acc = evaluate_accuracy(test_iter, net, ctx)
            print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
                  'time %.1f sec' % (epoch + 1, train_l_sum / n,
                                     train_acc_sum / n, test_acc,
                                     time.time() - start))
    
    # Hyper-parameters: SGD learning rate and number of training epochs.
    lr = 0.9
    num_epochs = 200
    # Re-initialise every parameter on ctx with Xavier initialisation,
    # overriding the earlier default initialisation.
    net.initialize(init=init.Xavier(), ctx=ctx, force_reinit=True)

    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr})
    train(net, train_iter, test_iter,
          batch_size=batch_size, trainer=trainer, ctx=ctx,
          num_epochs=num_epochs)

  • 相关阅读:
    ORACLE 11g RAC-RAC DG Duplicate 搭建(生产操作文档)
    1.kafka是什么
    11.扩展知识-redis持久化
    10.Redis-服务器命令
    9.扩展知识-redis批量操作-事务(了解)
    8.扩展知识-多数据库(了解)
    7.Redis扩展知识-消息订阅与发布(了解)
    K8S上部署ES集群报错
    ORM 常用字段和参数
    celery的使用
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10044055.html
Copyright © 2011-2022 走看看