zoukankan      html  css  js  c++  java
  • LeNet 分类 FashionMNIST

    import mxnet as mx
    from mxnet import autograd, gluon, init, nd
    from mxnet.gluon import loss as gloss, nn
    from mxnet.gluon import data as gdata
    import time
    import sys
    
    # LeNet-style network: two convolution + max-pooling stages followed by
    # three fully connected layers, the last one producing 10 outputs.
    net = nn.Sequential()
    net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'))
    net.add(nn.MaxPool2D(pool_size=2, strides=2))
    net.add(nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'))
    net.add(nn.MaxPool2D(pool_size=2, strides=2))
    # By default Dense reshapes an input of shape
    # (batch size, channels, height, width) into
    # (batch size, channels * height * width).
    net.add(nn.Dense(120, activation='sigmoid'))
    net.add(nn.Dense(84, activation='sigmoid'))
    net.add(nn.Dense(10))
    
    # Sanity check: feed one random input of shape (1, 1, 28, 28) through the
    # network one layer at a time and print the output shape after each layer.
    X = nd.random.uniform(shape=(1, 1, 28, 28))
    # Parameters must be initialized before the first forward pass.
    net.initialize()
    for layer in net:
        # Each layer consumes the previous layer's output.
        X = layer(X)
        print(layer.name, 'output shape:	', X.shape)
    
    # The data pipeline is assembled by hand here rather than through the
    # gb.load_data_fashion_mnist(batch_size=batch_size) helper.
    mnist_train = gdata.vision.FashionMNIST(train=True)
    mnist_test = gdata.vision.FashionMNIST(train=False)

    batch_size = 256
    transformer = gdata.vision.transforms.ToTensor()
    # No loader worker processes on Windows; four everywhere else.
    num_workers = 0 if sys.platform.startswith('win') else 4

    # Mini-batch data iterators (batches are produced on the CPU).
    train_iter = gdata.DataLoader(
        mnist_train.transform_first(transformer),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_iter = gdata.DataLoader(
        mnist_test.transform_first(transformer),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    
    def try_gpu4():
        """Return ``mx.gpu()`` when it is usable, otherwise ``mx.cpu()``.

        Usability is probed by allocating a one-element array on the GPU
        context; an ``MXNetError`` during the probe selects the CPU instead.
        """
        try:
            device = mx.gpu()
            # The allocation itself is the test; the array is discarded.
            nd.zeros((1,), ctx=device)
            return device
        except mx.base.MXNetError:
            return mx.cpu()
    
    # Device context used by the rest of the script, as chosen by try_gpu4().
    ctx = try_gpu4()
    
    def accuracy(y_hat,y):
        """Return the fraction of rows in `y_hat` whose arg-max equals `y`.

        `y` is cast to float32 before the comparison, and the mean of the
        element-wise matches is returned as a scalar via ``asscalar()``.
        """
        predicted = y_hat.argmax(axis=1)
        labels = y.astype('float32')
        matches = predicted == labels
        return matches.mean().asscalar()
    
    def evaluate_accuracy(data_iter, net, ctx):
        """Return the classification accuracy of `net` over `data_iter`.

        The result is the number of correctly classified samples divided by
        the number of samples seen, so a smaller final batch is weighted by
        its actual size rather than counting as much as a full batch.

        Args:
            data_iter: Iterable yielding ``(X, y)`` mini-batches.
            net: Callable mapping a batch ``X`` to per-class scores; the
                predicted class is the arg-max along axis 1.
            ctx: Context the batches are copied to before evaluation.

        Returns:
            Fraction of samples whose predicted class equals the label.

        Raises:
            ZeroDivisionError: If `data_iter` yields no samples.
        """
        # Keep the running count on `ctx` so no host sync is needed per batch.
        correct = nd.array([0], ctx=ctx)
        total = 0
        for X, y in data_iter:
            # If ctx is a GPU, this copies the batch onto it.
            X = X.as_in_context(ctx)
            # Labels are cast to float32 before comparing with the arg-max.
            y = y.as_in_context(ctx).astype('float32')
            correct += (net(X).argmax(axis=1) == y).sum()
            total += y.size
        return correct.asscalar() / total
    
    def train(net, train_iter, test_iter, batch_size, trainer, ctx,
              num_epochs):
        """Train `net` with softmax cross-entropy and print per-epoch metrics.

        After every epoch one line is printed with the mean training loss,
        training accuracy, test accuracy and elapsed wall-clock time. Loss
        and training accuracy are averaged over the samples actually seen,
        so a smaller final mini-batch does not skew the reported values.

        Args:
            net: Network to train; called on each batch ``X``.
            train_iter: Iterable of ``(X, y)`` training mini-batches.
            test_iter: Iterable of ``(X, y)`` mini-batches for evaluation.
            batch_size: Value passed to ``trainer.step`` on every update.
            trainer: Object whose ``step(batch_size)`` applies one update.
            ctx: Context the batches are copied to.
            num_epochs: Number of passes over `train_iter`.

        Raises:
            ZeroDivisionError: If `train_iter` yields no samples.
        """
        print('training on', ctx)
        loss = gloss.SoftmaxCrossEntropyLoss()
        for epoch in range(num_epochs):
            train_l_sum, train_correct, n_samples = 0.0, 0.0, 0
            start = time.time()
            for X, y in train_iter:
                X, y = X.as_in_context(ctx), y.as_in_context(ctx)
                with autograd.record():
                    y_hat = net(X)
                    l = loss(y_hat, y)
                l.backward()
                trainer.step(batch_size)
                # Accumulate totals, not per-batch means, so the epoch
                # averages are weighted by the true number of samples.
                train_l_sum += l.sum().asscalar()
                hits = y_hat.argmax(axis=1) == y.astype('float32')
                train_correct += hits.sum().asscalar()
                n_samples += y.size
            # The reported time includes this evaluation pass.
            test_acc = evaluate_accuracy(test_iter, net, ctx)
            print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '
                  'time %.1f sec' % (epoch + 1, train_l_sum / n_samples,
                                     train_correct / n_samples,
                                     test_acc, time.time() - start))
    
    # Hyper-parameters for the run.
    lr = 0.9
    num_epochs = 200
    # The network was already initialized earlier in the script, so force a
    # re-initialization with Xavier on the selected device.
    net.initialize(init=init.Xavier(), ctx=ctx, force_reinit=True)

    # Plain SGD over all network parameters.
    trainer = gluon.Trainer(
        params=net.collect_params(),
        optimizer='sgd',
        optimizer_params={'learning_rate': lr},
    )
    train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

  • 相关阅读:
    电脑无损换硬盘,不用重装系统驱动的小技巧
    OSPF协议原理及配置5-LSA分析
    OSPF协议原理及配置3-邻居关系的建立
    OSPF协议原理及配置2-理解邻居和邻接关系
    我在华为写代码
    嵌入式未来发展
    blog to live,do not love to blog
    浮点数转换为新类型时必须做范围检查
    分享点干货
    基础C语言知识串串香14☞增补知识
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10044055.html
Copyright © 2011-2022 走看看