zoukankan      html  css  js  c++  java
  • 使用 MNIST 数据集进行自编码

    """
    自动编码的核心就是各种全连接的组合,它是一种无监督的形式,因为他的标签是自己。
    """
    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    import torch.utils.data as Data
    import torchvision
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from matplotlib import cm
    import numpy as np
    
    # Hyperparameters
    EPOCH = 10              # passes over the training set
    BATCH_SIZE = 64
    LR = 0.005              # Adam learning rate
    # torchvision's MNIST.download() returns early when the files already exist
    # under `root`, so True is safe on every run and avoids a
    # "Dataset not found" error the first time the script is run.
    DOWNLOAD_MNIST = True
    N_TEST_IMG = 5          # number of images shown in the live reconstruction figure

    # MNIST training split; ToTensor() turns each 28x28 uint8 image into a
    # float tensor scaled to [0, 1].
    train_data = torchvision.datasets.MNIST(
        root='./mnist/',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=DOWNLOAD_MNIST,
    )

    # `.data` / `.targets` replace the deprecated `.train_data` / `.train_labels`
    # (which only emit a warning and forward here). They hold the raw,
    # untransformed uint8 images and integer labels.
    print(train_data.data.size())     # (60000, 28, 28)
    print(train_data.targets.size())  # (60000)

    # Show one example image with its label.
    plt.imshow(train_data.data[2].numpy(), cmap='gray')
    plt.title('%i' % train_data.targets[2].item())
    plt.show()

    # Split the training set into shuffled mini-batches.
    train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    
    # 搭建自编码网络框架
    class AutoEncoder(nn.Module):
        """Fully connected autoencoder that squeezes a flattened 28x28 image into 3 numbers.

        Encoder widths: 784 -> 128 -> 64 -> 12 -> 3, with tanh between layers and
        no activation on the 3-d code itself. The decoder mirrors those widths and
        ends in a sigmoid so reconstructions lie in (0, 1), the same range as the
        ToTensor-scaled training images.
        """

        def __init__(self):
            super(AutoEncoder, self).__init__()
            encoder_widths = [28 * 28, 128, 64, 12, 3]
            decoder_widths = encoder_widths[::-1]
            self.encoder = nn.Sequential(*self._dense_stack(encoder_widths))
            self.decoder = nn.Sequential(*self._dense_stack(decoder_widths), nn.Sigmoid())

        @staticmethod
        def _dense_stack(widths):
            """Return [Linear, Tanh, Linear, ..., Linear] joining consecutive widths."""
            layers = []
            for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
                if depth > 0:
                    layers.append(nn.Tanh())
                layers.append(nn.Linear(fan_in, fan_out))
            return layers

        def forward(self, x):
            """Return ``(code, reconstruction)`` for a batch of flattened images."""
            code = self.encoder(x)
            reconstruction = self.decoder(code)
            return code, reconstruction
    
    autoencoder = AutoEncoder()

    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
    # The reconstruction target is the input itself, hence MSE between output and input.
    loss_func = nn.MSELoss()

    # Figure layout: row 0 = original images, row 1 = current reconstructions.
    f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
    plt.ion()   # interactive mode so the figure refreshes during training

    # Row 0: the first N_TEST_IMG training images, flattened to 784 values and
    # scaled to [0, 1] by hand (raw `.data` is uint8 and bypasses ToTensor).
    view_data = train_data.data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor) / 255.
    for i in range(N_TEST_IMG):
        a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
        a[0][i].set_xticks(())
        a[0][i].set_yticks(())

    for epoch in range(EPOCH):
        for step, (x, _) in enumerate(train_loader):   # labels unused: the target is the input
            b_x = x.view(-1, 28*28)   # flattened input batch
            b_y = x.view(-1, 28*28)   # reconstruction target, identical to the input

            encoded, decoded = autoencoder(b_x)

            loss = loss_func(decoded, b_y)
            optimizer.zero_grad()     # clear gradients from the previous step
            loss.backward()           # backpropagate to compute gradients
            optimizer.step()          # update the network parameters

            if step % 100 == 0:
                # `loss` is a 0-dim tensor: `loss.data[0]` raises IndexError on
                # PyTorch >= 0.5, while `.item()` returns the Python float.
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.item())

                # Row 1: reconstructions; plotting needs no autograd graph.
                with torch.no_grad():
                    _, decoded_data = autoencoder(view_data)
                for i in range(N_TEST_IMG):
                    a[1][i].clear()
                    a[1][i].imshow(np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray')
                    a[1][i].set_xticks(())
                    a[1][i].set_yticks(())
                plt.draw()
                plt.pause(0.05)

    plt.ioff()
    plt.show()
    
    # Visualise the 3-d codes of the first 200 training images; each point is
    # drawn as its digit label on a background coloured by that label.
    view_data = train_data.data[:200].view(-1, 28*28).type(torch.FloatTensor) / 255.
    with torch.no_grad():
        encoded_data, _ = autoencoder(view_data)
    fig = plt.figure(2)
    # `Axes3D(fig)` stopped attaching itself to the figure in matplotlib 3.6
    # (deprecated since 3.4), which leaves the plot blank; add_subplot with
    # projection='3d' is the supported form. The top-level Axes3D import is
    # still what registers the '3d' projection on matplotlib < 3.2.
    ax = fig.add_subplot(111, projection='3d')
    X, Y, Z = encoded_data[:, 0].numpy(), encoded_data[:, 1].numpy(), encoded_data[:, 2].numpy()
    values = train_data.targets[:200].numpy()
    for x, y, z, s in zip(X, Y, Z, values):
        c = cm.rainbow(int(255*s/9))   # spread labels 0..9 across the rainbow colormap
        ax.text(x, y, z, s, backgroundcolor=c)
    # Text artists don't update the data limits, so set the axis ranges explicitly.
    ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
    plt.show()
  • 相关阅读:
    Photoshop 基础七 位图 矢量图 栅格化
    Photoshop 基础六 图层
    Warfare And Logistics UVALive
    Walk Through the Forest UVA
    Airport Express UVA
    Guess UVALive
    Play on Words UVA
    The Necklace UVA
    Food Delivery ZOJ
    Brackets Sequence POJ
  • 原文地址:https://www.cnblogs.com/czz0508/p/10347065.html
Copyright © 2011-2022 走看看