zoukankan      html  css  js  c++  java
  • Gluon 实现多层感知机(MLP)分类 Fashion-MNIST

    from mxnet import gluon,init
    from mxnet.gluon import loss as gloss, nn
    from mxnet.gluon import data as gdata
    from mxnet import nd,autograd
    import gluonbook as gb
    
    import sys
    
    # Load the Fashion-MNIST training and test splits.
    mnist_train = gdata.vision.FashionMNIST(train=True)
    mnist_test = gdata.vision.FashionMNIST(train=False)

    batch_size = 256
    transformer = gdata.vision.transforms.ToTensor()
    # Use 0 loader workers on Windows and 4 on every other platform.
    num_workers = 0 if sys.platform.startswith('win') else 4

    # Mini-batch data iterators; only the training data is shuffled.
    train_iter = gdata.DataLoader(
        mnist_train.transform_first(transformer),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_iter = gdata.DataLoader(
        mnist_test.transform_first(transformer),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    
    # Define the network: one hidden layer of 256 ReLU units followed by
    # a 10-unit output layer.
    net = nn.Sequential()
    net.add(nn.Dense(256, activation='relu'))
    net.add(nn.Dense(10))
    net.initialize(init.Normal(sigma=0.01))

    # Loss function and the SGD trainer that updates the network parameters.
    loss = gloss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(
        net.collect_params(),
        'sgd',
        {'learning_rate': 0.5},
    )
    
    
    def accuracy(y_hat, y):
        """Return the fraction of rows whose arg-max along axis 1 equals the label.

        Args:
            y_hat: Array of per-class scores, one row per sample.
            y: Array of labels; cast to float32 before the comparison.

        Returns:
            The mean of the element-wise match indicator, as a scalar.
        """
        predicted = y_hat.argmax(axis=1)
        labels = y.astype('float32')
        matches = predicted == labels
        return matches.mean().asscalar()
    
    def evaluate_accuracy(data_iter, net):
        """Return the accuracy of ``net`` over every sample yielded by ``data_iter``.

        Each batch's accuracy is weighted by the number of samples in that
        batch, so a smaller final batch does not skew the result the way a
        plain average of per-batch accuracies would.

        Args:
            data_iter: Iterable yielding ``(X, y)`` batches.
            net: Callable mapping a feature batch ``X`` to per-class scores.

        Returns:
            Correctly classified samples divided by total samples.

        Raises:
            ZeroDivisionError: If ``data_iter`` yields no samples.
        """
        correct = 0.0
        seen = 0
        for X, y in data_iter:
            batch_len = y.shape[0]
            # accuracy() is a per-batch mean; scale it back to a sample count.
            correct += accuracy(net(X), y) * batch_len
            seen += batch_len
        return correct / seen
    
    # Number of full passes over the training data; passed to train() below.
    num_epochs = 5
    
    def train(net,train_iter,test_iter,loss,num_epochs,batch_size,params=None,lr=None,trainer=None):
        """Train ``net`` and print the training loss and test accuracy per epoch.

        Args:
            net: Callable model mapping a feature batch to predictions.
            train_iter: Iterable yielding ``(X, y)`` training batches.
            test_iter: Iterable yielding ``(X, y)`` batches for evaluation.
            loss: Callable ``loss(y_hat, y)`` returning the batch loss.
            num_epochs: Number of passes over ``train_iter``.
            batch_size: Batch size forwarded to the optimizer step.
            params: Parameters for ``gb.sgd``; used only when ``trainer`` is None.
            lr: Learning rate for ``gb.sgd``; used only when ``trainer`` is None.
            trainer: Optional object with a ``step(batch_size)`` method; when
                given it is used instead of ``gb.sgd``.

        Raises:
            ZeroDivisionError: If ``train_iter`` yields no samples.
        """
        for epoch in range(num_epochs):
            train_l_sum = 0.0
            sample_count = 0
            for X, y in train_iter:
                with autograd.record():
                    y_hat = net(X)
                    l = loss(y_hat, y)
                l.backward()

                if trainer is None:
                    gb.sgd(params, lr, batch_size)
                else:
                    trainer.step(batch_size)

                # Weight the batch-mean loss by the batch length so that a
                # short final batch does not distort the epoch average.
                batch_len = y.shape[0]
                train_l_sum += l.mean().asscalar() * batch_len
                sample_count += batch_len

            test_acc = evaluate_accuracy(test_iter, net)
            print('epoch %d,loss %.4f,test acc %.3f'%(epoch+1, train_l_sum / sample_count, test_acc))
    
    # Run training with the Gluon trainer; params and lr keep their None
    # defaults because they are only read when no trainer is supplied.
    train(
        net,
        train_iter,
        test_iter,
        loss,
        num_epochs,
        batch_size,
        trainer=trainer,
    )

  • 相关阅读:
    大数据培训:分享大数据行业就业趋势
    大数据培训:Zookeeper集群管理与选举
    【编码】UnicodeEncodeError: 'gbk' codec can't encode character '\xa0' in position XXX
    MVC 登录后重定向回最初请求的 URL FormsAuthentication.RedirectFromLoginPage
    EasyUI 下载与引用
    EntityFrameWork Parameter '@columnType' must be defined.
    Hello World
    protobuf windows java 环境搭建
    android XML转义字符
    shiro Remember 1.2.4反序列化漏洞
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10033557.html
Copyright © 2011-2022 走看看