import torch
import torch.utils.data as Data
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# torch.manual_seed(1)  # uncomment to make runs reproducible

# ---- Hyper-parameters ----
# Number of full passes over the training set; kept at 1 to keep the run short.
EPOCH = 1
# Number of samples per mini-batch.
batch_size = 4
# Learning rate.
LR = 0.001
# Set to False once the dataset has already been downloaded.
DOWNLOAD_MNIST = True
# Printable label for each class index: '0' through '9'.
classes = [str(digit) for digit in range(10)]
# ---- MNIST training split ----
# ToTensor() turns each PIL image / ndarray into a float tensor laid out as
# (C x H x W) with values scaled to [0.0, 1.0].
train_dataset = torchvision.datasets.MNIST(
    root='../../data/',
    train=True,                # select the training split
    transform=transforms.ToTensor(),
    download=DOWNLOAD_MNIST,   # fetch the files when they are not on disk yet
)

# ---- Mini-batch loader over the training split (reshuffled every epoch) ----
train_loader = Data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

print(type(train_dataset))
def imshow(img, *, unnormalize=False):
    """Display a channels-first image tensor with matplotlib.

    Args:
        img: Tensor laid out as (C, H, W); it is converted with ``.numpy()``
            and transposed to (H, W, C) before plotting.
        unnormalize: When True, apply ``img / 2 + 0.5`` first, which maps
            values from [-1, 1] back to [0, 1]. Only enable this for data
            that went through ``Normalize(0.5, 0.5)``. The dataset in this
            file is loaded with ``ToTensor()`` alone, so its pixels are
            already in [0, 1]; shifting them would squeeze the image into
            [0.5, 1.0] and wash it out.
    """
    if unnormalize:
        img = img / 2 + 0.5
    npimg = img.numpy()
    # matplotlib wants the channel axis last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# Pull one mini-batch from the loader to sanity-check the input pipeline.
dataiter = iter(train_loader)
# Use the builtin next(): DataLoader iterators in current PyTorch releases
# implement __next__ only and have no `.next()` method.
images, labels = next(dataiter)
print(images.shape)  # [batch_size, 1, 28, 28] for MNIST
imshow(torchvision.utils.make_grid(images))
# One right-aligned, 5-character-wide label per image in the batch; iterating
# over `labels` keeps this in step with whatever batch size the loader uses.
print(''.join('%5s' % classes[label] for label in labels))