数据增强的方式有很多,比如对图像进行几何变换(如翻转、旋转、变形、缩放等)、颜色变换(包括噪声、模糊、颜色变换、檫除、填充等),将有限的数据,进行充分的利用。这里将介绍的仅仅是对图像数据进行任意方向的移动操作(上下左右)来扩充数据。
这里将使用scipy中的shift变换工具(from scipy.ndimage.interpolation import shift)
常用的参数:input输入图像数据为ndarray类型的,
shift参数代表表示各个维度的偏移量[1,1]表示第一个第二个维度均偏移1,
cval参数代表偏移后原来位置用什么来填充
from scipy.ndimage.interpolation import shift def shift_digit(digit_array,dx,dy,new = 0): return shift(digit_array.reshape(28,28),[dy,dx],cval = new).reshape(784) plot_digit(shift_digit(some_digit,5,1,new =100))
一个简单的数据偏移完成,接下来对整个训练集进行扩充
X_train_expanded = [X_train] y_train_expanded = [y_train] for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)): shifted_image = np.apply_along_axis(shift_digit,axis = 1,arr = X_train,dx = dx,dy = dy) X_train_expanded.append(shifted_image) y_train_expanded.append(y_train) X_train_expanded = np.concatenate(X_train_expanded) y_train_expanded = np.concatenate(y_train_expanded) X_train_expanded.shape,y_train_expanded.shape
数据增加大了30万之多,有了更多的数据,接下来进行训练、预测,计算精度
knn_clf.fit(X_train_expanded,y_train_expanded)
y_knn_expanded_pred = knn_clf.predict(X_test)
accuracy_score(y_test,y_knn_expanded_pred)
另一种表示方式:
def shift_image(image,dx,dy): image = image.reshape((28,28)) shifted_image = shift(image,[dy,dx],cval = 0,mode = 'constant') return shifted_image.reshape([-1])
image = X_train[1000] shifted_image_down = shift_image(image,0,5) shifted_image_left = shift_image(image,-5,0) plt.figure(figsize=(12,3)) plt.subplot(131) plt.title("Original",fontsize= 14) plt.imshow(image.reshape(28,28),interpolation='nearest',cmap = 'Greys') plt.subplot(132) plt.title("shifted down",fontsize= 14) plt.imshow(shifted_image_down.reshape(28,28),interpolation='nearest',cmap = 'Greys') plt.subplot(133) plt.title("shifted left",fontsize= 14) plt.imshow(shifted_image_left.reshape(28,28),interpolation='nearest',cmap = 'Greys') plt.show()
X_train_augmented = [image for image in X_train] y_train_augmented = [label for label in y_train] for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)): for image,label in zip(X_train,y_train): X_train_augmented.append(shift_image(image,dx,dy)) y_train_augmented.append(label)
X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
#打乱顺序
shuffle_idx = np.random.permutation(len(X_train_augmented)) X_train_augmented = X_train_augmented[shuffle_idx] y_train_augmented = y_train_augmented[shuffle_idx]
knn_clf = KNeighborsClassifier(**grid_search.best_params_)
knn_clf.fit(X_train_augmented,y_train_augmented)
y_pred = knn_clf.predict(X_test)
accuracy_score(y_test,y_pred)
此时准确率已达到97%以上
关于knn_clf = KNeighborsClassifier(**grid_search.best_params_)中的**犯傻了很久,**代表着该参数中包含了多个参数,在C++中也会有这种参数表示,
也可参看python中*args与**kwargs的介绍(https://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained/)
当然,scipy.ndimage.interpolation 也包含了其他的数据增强的方法,如旋转、缩放等(参考:https://blog.csdn.net/songchunxiao1991/article/details/88531086)