zoukankan      html  css  js  c++  java
  • pytorch之 CNN

    
    
       ### 仅为自己练习，没有其他用途



    # CNN classifier for MNIST digits (personal practice script).
#
# Trains a two-conv-layer CNN for EPOCH epochs. Every 50 steps it reports the
# train loss and the accuracy on the first N_TEST test images. If scikit-learn
# is installed, it also plots a 2-D t-SNE embedding of the flattened conv
# features for the first PLOT_ONLY test images.

# standard library
import os

# third-party library
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from matplotlib import cm

# Optional dependency: only the layer visualization needs scikit-learn.
try:
    from sklearn.manifold import TSNE
    HAS_SK = True
except ImportError:  # a missing package disables the plot; other errors still surface
    HAS_SK = False
    print('Please install sklearn for layer visualization')

# torch.manual_seed(1)    # reproducible

# Hyper Parameters
EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001              # learning rate
DOWNLOAD_MNIST = False
N_TEST = 2000           # number of test images used for evaluation (subset to speed up testing)
PLOT_ONLY = 500         # number of test images embedded by t-SNE

# Mnist digits dataset
if not os.path.exists('./mnist/') or not os.listdir('./mnist/'):
    # no mnist dir, or mnist is an empty dir
    DOWNLOAD_MNIST = True

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,                                     # this is training data
    transform=torchvision.transforms.ToTensor(),    # PIL.Image / numpy.ndarray -> FloatTensor (C x H x W) in [0.0, 1.0]
    download=DOWNLOAD_MNIST,
)

# # plot one example
# print(train_data.data.size())     # (60000, 28, 28)
# print(train_data.targets.size())  # (60000)
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.title('%i' % train_data.targets[0])
# plt.show()

# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

# Evaluate on the first N_TEST test images only, to speed up testing.
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels` aliases.
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False, download=DOWNLOAD_MNIST)
# shape from (N_TEST, 28, 28) to (N_TEST, 1, 28, 28), value in range(0,1)
test_x = test_data.data[:N_TEST].unsqueeze(dim=1).float() / 255.
test_y = test_data.targets[:N_TEST]


class CNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a linear 10-class classifier.

    Expects input of shape (batch, 1, 28, 28), per the shape comments below.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # input channels (single-channel image)
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # keep width/height: padding=(kernel_size-1)/2 when stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        """Return ``(logits, flat_features)``.

        The flattened conv features are returned as well so that the caller
        can visualize them (t-SNE plot below).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x


cnn = CNN()
print(cnn)  # net architecture

optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)   # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()                       # the target label is not one-hotted


def _make_tsne():
    """Build the 2-D t-SNE reducer across scikit-learn versions.

    scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and 1.7 removed
    ``n_iter``; older releases reject ``max_iter`` with a TypeError.
    """
    try:
        return TSNE(perplexity=30, n_components=2, init='pca', max_iter=5000)
    except TypeError:  # scikit-learn < 1.5 only knows n_iter
        return TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)


# following function (plot_with_labels) is for visualization, can be ignored if not interested
def plot_with_labels(lowDWeights, labels):
    """Draw each 2-D point as its label text, colored by label, and refresh the figure.

    lowDWeights: array of shape (n, 2), the embedded points.
    labels: n class labels; assumed to be digits 0..9 (MNIST targets here).
    """
    plt.cla()
    X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
    for x, y, s in zip(X, Y, labels):
        c = cm.rainbow(int(255 * s / 9))    # spread labels 0..9 over the colormap's 0..255 range
        plt.text(x, y, s, backgroundcolor=c, fontsize=9)
    plt.xlim(X.min(), X.max())
    plt.ylim(Y.min(), Y.max())
    plt.title('Visualize last layer')
    plt.show()
    plt.pause(0.01)                         # give the interactive backend time to redraw


# The reducer's settings never change, so build it once instead of every 50 steps.
tsne = _make_tsne() if HAS_SK else None

plt.ion()
# training and testing
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):   # b_x already scaled to [0, 1] by ToTensor

        output = cnn(b_x)[0]            # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients

        if step % 50 == 0:
            # Evaluation needs no autograd graph; no_grad avoids building one
            # over all N_TEST images.
            with torch.no_grad():
                test_output, last_layer = cnn(test_x)
            pred_y = test_output.argmax(dim=1)
            accuracy = (pred_y == test_y).sum().item() / test_y.size(0)
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)
            if tsne is not None:
                # Visualization of trained flatten layer (T-SNE)
                low_dim_embs = tsne.fit_transform(last_layer[:PLOT_ONLY].numpy())
                labels = test_y[:PLOT_ONLY].numpy()
                plot_with_labels(low_dim_embs, labels)
plt.ioff()

# print 10 predictions from test data
with torch.no_grad():
    test_output, _ = cnn(test_x[:10])
pred_y = test_output.argmax(dim=1).numpy()
print(pred_y, 'prediction number')
print(test_y[:10].numpy(), 'real number')
  • 相关阅读:
    (转载)SAPI 包含sphelper.h编译错误解决方案
    C++11标准的智能指针、野指针、内存泄露的理解(日后还会补充,先浅谈自己的理解)
    504. Base 7(LeetCode)
    242. Valid Anagram(LeetCode)
    169. Majority Element(LeetCode)
    100. Same Tree(LeetCode)
    171. Excel Sheet Column Number(LeetCode)
    168. Excel Sheet Column Title(LeetCode)
    122. Best Time to Buy and Sell Stock II(LeetCode)
    404. Sum of Left Leaves(LeetCode)
  • 原文地址:https://www.cnblogs.com/dhName/p/11759085.html
Copyright © 2011-2022 走看看