Logistic regression is a commonly used classification method; here we implement it with PyTorch.
This part uses the MNIST dataset, in which each image is 28*28 pixels and there are 10 classes (labels).
We now build the model step by step.
Part 1: the concepts of Batch Size, Epoch, and Iteration. (This part is based on https://zhuanlan.zhihu.com/p/157897691)
1. Batch Size
Definition: the batch size is the number of samples used in a single training step.
Why a batch size is needed: choosing the batch size properly is about finding the best balance between memory efficiency and memory capacity.
Practical notes on tuning the batch size:
For a typical dataset, if the batch size is too small, training has great difficulty converging, which leads to underfitting.
Increasing the batch size makes processing relatively faster.
Increasing the batch size requires more memory (and the number of epochs may have to increase to reach the best result).
Note that the last two points are in tension: once more epochs are needed, the total training time also increases, so the speedup is partly cancelled out.
2. Epoch
One epoch means one full pass over all samples in the training set, i.e., one forward pass + one backward pass for every training sample.
3. Iteration
An iteration is also called a training step in some networks. Concretely, one iteration = one forward pass + one backward pass. In other words, the whole process of taking some data, running it through the network to get outputs, and adjusting the network weights is called one iteration. (Incidentally, how much data is taken each time is exactly what the batch size determines.)
Iteration is a repeated feedback process; in a neural network we train through many iterations to reach the desired target or result.
The network parameters are updated once after each iteration, and the result of each iteration is used as the starting point of the next.
4. Summary of Batch Size, Epoch, and Iteration
Summary:
One epoch = one forward + backward pass over all training data, with the accompanying parameter updates.
One iteration = one forward + backward pass over batch-size training samples, followed by one parameter update.
Here is a small example. Suppose a dataset contains 200,000 samples and we define 1 epoch as 1,000 iterations; then the batch size of each iteration must be set to 200, so that 1 epoch covers exactly 200,000 training samples. A short sketch of this arithmetic follows.
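To make the arithmetic concrete, here is a minimal sketch in plain Python; the numbers are the ones from the example above, and nothing here touches the actual training code:

num_samples = 200000          # total number of training samples (from the example)
batch_size = 200              # samples consumed per iteration

iterations_per_epoch = num_samples // batch_size
print(iterations_per_epoch)   # 1000 -> 1000 iterations make up 1 epoch

# Going the other way: fix the total number of iterations we want to run
# and derive the number of epochs, as the training code below does.
n_iters = 10000
num_epochs = n_iters // iterations_per_epoch
print(num_epochs)             # 10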
Part 2: loading and visualizing the data (this part is from https://www.kaggle.com/kanncaa1/pytorch-tutorial-for-deep-learning-lovers)
# Import Libraries
import numpy as np                      # needed for np.float32 below
import pandas as pd
import matplotlib.pyplot as plt         # needed for the visualization below
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Prepare Dataset
# load data
train = pd.read_csv(r"../input/train.csv", dtype=np.float32)

# split data into features (pixels) and labels (numbers from 0 to 9)
targets_numpy = train.label.values
features_numpy = train.loc[:, train.columns != "label"].values / 255  # normalization

# train test split. Size of train data is 80% and size of test data is 20%.
features_train, features_test, targets_train, targets_test = train_test_split(
    features_numpy, targets_numpy, test_size=0.2, random_state=42)

# create feature and target tensors for the train set. As you remember, we need Variables
# to accumulate gradients, so we first create tensors and later wrap them in Variables.
featuresTrain = torch.from_numpy(features_train)
targetsTrain = torch.from_numpy(targets_train).type(torch.LongTensor)  # data type is long

# create feature and target tensors for the test set.
featuresTest = torch.from_numpy(features_test)
targetsTest = torch.from_numpy(targets_test).type(torch.LongTensor)  # data type is long

# batch_size, epoch and iteration
batch_size = 100
n_iters = 10000
num_epochs = n_iters / (len(features_train) / batch_size)
num_epochs = int(num_epochs)

# Pytorch train and test sets
train = torch.utils.data.TensorDataset(featuresTrain, targetsTrain)
test = torch.utils.data.TensorDataset(featuresTest, targetsTest)

# data loader
train_loader = DataLoader(train, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)

# visualize one of the images in the data set
plt.imshow(features_numpy[10].reshape(28, 28))
plt.axis("off")
plt.title(str(targets_numpy[10]))
plt.savefig('graph.png')
plt.show()
The result is the following sample image:
Part 3: build the model and use the cross-entropy loss function
# Create Logistic Regression Model
class LogisticRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LogisticRegressionModel, self).__init__()
        # Linear part
        self.linear = nn.Linear(input_dim, output_dim)
        # Shouldn't there be a logistic (softmax) function here?
        # In PyTorch it is folded into the loss function
        # (nn.CrossEntropyLoss applies log-softmax internally),
        # so we do not add it here; it appears in the next part.

    def forward(self, x):
        out = self.linear(x)
        return out

# Instantiate Model Class
input_dim = 28*28  # size of image px*px
output_dim = 10    # labels 0,1,2,3,4,5,6,7,8,9

# create logistic regression model
model = LogisticRegressionModel(input_dim, output_dim)

# Cross Entropy Loss
error = nn.CrossEntropyLoss()

# SGD Optimizer
learning_rate = 0.001
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
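As the comments in the model note, forward() returns raw scores (logits) without a softmax, because nn.CrossEntropyLoss applies log-softmax internally before computing the negative log-likelihood. Below is a minimal, self-contained check of this equivalence; the tensors are made up purely for illustration:

import torch
import torch.nn.functional as F

# hypothetical logits for a batch of 3 samples and 10 classes, plus their labels
logits = torch.randn(3, 10)
labels = torch.tensor([0, 4, 9])

# cross entropy on raw logits ...
loss_a = F.cross_entropy(logits, labels)
# ... equals log-softmax followed by the negative log-likelihood loss
loss_b = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(torch.allclose(loss_a, loss_b))  # True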
Part 4: train the model
# Training the Model
count = 0
loss_list = []
iteration_list = []
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Define variables
        train = Variable(images.view(-1, 28*28))
        labels = Variable(labels)

        # Clear gradients
        optimizer.zero_grad()

        # Forward propagation
        outputs = model(train)

        # Calculate softmax and cross entropy loss
        loss = error(outputs, labels)

        # Calculate gradients
        loss.backward()

        # Update parameters
        optimizer.step()

        count += 1

        # Prediction
        if count % 50 == 0:
            # Calculate Accuracy
            correct = 0
            total = 0
            # Predict test dataset
            for images, labels in test_loader:
                test = Variable(images.view(-1, 28*28))

                # Forward propagation
                outputs = model(test)

                # Get predictions from the maximum value
                predicted = torch.max(outputs.data, 1)[1]

                # Total number of labels
                total += len(labels)

                # Total correct predictions
                correct += (predicted == labels).sum()

            accuracy = 100 * correct / float(total)

            # store loss and iteration
            loss_list.append(loss.data)
            iteration_list.append(count)

        if count % 500 == 0:
            # Print Loss
            print('Iteration: {} Loss: {} Accuracy: {}%'.format(count, loss.data, accuracy))
The result:
Iteration: 500 Loss: 1.8399910926818848 Accuracy: 68%
Iteration: 1000 Loss: 1.5982391834259033 Accuracy: 75%
Iteration: 1500 Loss: 1.2930790185928345 Accuracy: 78%
Iteration: 2000 Loss: 1.1937870979309082 Accuracy: 80%
Iteration: 2500 Loss: 1.0323244333267212 Accuracy: 81%
Iteration: 3000 Loss: 0.9379988312721252 Accuracy: 82%
Iteration: 3500 Loss: 0.899523913860321 Accuracy: 82%
Iteration: 4000 Loss: 0.7464531660079956 Accuracy: 83%
Iteration: 4500 Loss: 0.9766625761985779 Accuracy: 83%
Iteration: 5000 Loss: 0.8022621870040894 Accuracy: 83%
Iteration: 5500 Loss: 0.7587511539459229 Accuracy: 84%
Iteration: 6000 Loss: 0.8655218482017517 Accuracy: 84%
Iteration: 6500 Loss: 0.6625986695289612 Accuracy: 84%
Iteration: 7000 Loss: 0.7128363251686096 Accuracy: 84%
Iteration: 7500 Loss: 0.6303086280822754 Accuracy: 85%
Iteration: 8000 Loss: 0.7414441704750061 Accuracy: 85%
Iteration: 8500 Loss: 0.5468852519989014 Accuracy: 85%
Iteration: 9000 Loss: 0.6567560434341431 Accuracy: 85%
Iteration: 9500 Loss: 0.5228758454322815 Accuracy: 85%
Part 5: visualization
# visualization
plt.plot(iteration_list, loss_list)
plt.xlabel("Number of iteration")
plt.ylabel("Loss")
plt.title("Logistic Regression: Loss vs Number of iteration")
plt.show()