  • Data augmentation

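    There are many ways to augment data. For images these include geometric transforms (flipping, rotation, warping, scaling, and so on) and color transforms (adding noise, blurring, color changes, erasing, filling, and so on), all of which get more use out of a limited dataset. This post covers only one of them: shifting the image in each direction (up, down, left, right) to enlarge the dataset.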

    We use the shift function from scipy (from scipy.ndimage import shift; the older import path scipy.ndimage.interpolation is deprecated in recent SciPy releases).

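       Commonly used parameters:

            input: the image data, as an ndarray;

            shift: the offset along each axis, e.g. [1,1] shifts by 1 along both the first and the second axis;

            cval: the value used to fill the positions vacated by the shift (it applies with mode='constant', which is the default).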
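
    To make the shift and cval parameters concrete, here is a minimal sketch on a tiny 3x3 array. The array img and the names moved_down and moved_left are made up for illustration and do not come from the original post.

```python
import numpy as np
from scipy.ndimage import shift

# A tiny 3x3 "image" (hypothetical data) so the effect is easy to read.
img = np.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0],
                [7.0, 8.0, 9.0]])

# shift=[1, 0]: move the content by +1 along axis 0 (rows, i.e. downward)
# and by 0 along axis 1 (columns). Vacated cells are filled with cval.
moved_down = shift(img, [1, 0], cval=0)

# shift=[0, -1]: move the content one column to the left; fill with -1.
moved_left = shift(img, [0, -1], cval=-1)

# Rounding hides tiny floating-point noise from the spline interpolation.
print(np.round(moved_down))
print(np.round(moved_left))
```

    Note that the first entry of shift applies to rows and the second to columns, which is why the functions in this post pass [dy,dx] rather than [dx,dy].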

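    from scipy.ndimage import shift
    def shift_digit(digit_array,dx,dy,new = 0):
        # reshape the flat 784-vector to 28x28, shift rows by dy and columns by dx,
        # fill the vacated pixels with `new`, then flatten back to 784
        return shift(digit_array.reshape(28,28),[dy,dx],cval = new).reshape(784)
    # plot_digit and some_digit are not defined in this snippet
    plot_digit(shift_digit(some_digit,5,1,new =100))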

       That completes a simple shift of a single image. Next, expand the entire training set.

    X_train_expanded = [X_train]
    y_train_expanded = [y_train]
    for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)):
        shifted_image = np.apply_along_axis(shift_digit,axis = 1,arr = X_train,dx = dx,dy = dy)
        X_train_expanded.append(shifted_image)
        y_train_expanded.append(y_train)
        
    X_train_expanded = np.concatenate(X_train_expanded)
    y_train_expanded = np.concatenate(y_train_expanded)
    X_train_expanded.shape,y_train_expanded.shape
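
    As a possible alternative to calling np.apply_along_axis once per row, the whole set can be shifted in a single call by treating it as a 3-D stack and using a zero offset along the image axis. This is only a sketch, not the approach used in the original post; shift_batch, X_demo and y_demo are hypothetical names, and the random data merely stands in for X_train and y_train.

```python
import numpy as np
from scipy.ndimage import shift

rng = np.random.default_rng(0)
X_demo = rng.random((6, 784))   # hypothetical stand-in for X_train
y_demo = np.arange(6)           # hypothetical stand-in for y_train

def shift_batch(X, dx, dy, new=0):
    """Shift every 28x28 image in X by (dx, dy) with one call to shift."""
    stack = X.reshape(-1, 28, 28)
    # A zero offset along axis 0 keeps each image at its own index;
    # dy moves rows and dx moves columns, as in shift_digit.
    return shift(stack, [0, dy, dx], cval=new).reshape(-1, 784)

X_parts, y_parts = [X_demo], [y_demo]
for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
    X_parts.append(shift_batch(X_demo, dx, dy))
    y_parts.append(y_demo)      # labels are unchanged by shifting

X_demo_expanded = np.concatenate(X_parts)
y_demo_expanded = np.concatenate(y_parts)
print(X_demo_expanded.shape, y_demo_expanded.shape)
```

    The result has five times as many rows as the input: the originals plus one copy per direction.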

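       Four shifted copies are added for every original image, so the training set becomes five times its original size (300,000 samples if X_train holds the 60,000 MNIST training images). With more data available, we next train, predict, and compute the accuracy.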

    knn_clf.fit(X_train_expanded,y_train_expanded)

    y_knn_expanded_pred = knn_clf.predict(X_test)
    accuracy_score(y_test,y_knn_expanded_pred)

     An alternative way to write it:

    def shift_image(image,dx,dy):
        image = image.reshape((28,28))
        shifted_image = shift(image,[dy,dx],cval = 0,mode = 'constant')
        return shifted_image.reshape([-1])
    image = X_train[1000]
    shifted_image_down = shift_image(image,0,5)
    shifted_image_left = shift_image(image,-5,0)
    
    plt.figure(figsize=(12,3))
    plt.subplot(131)
    plt.title("Original",fontsize= 14)
    plt.imshow(image.reshape(28,28),interpolation='nearest',cmap = 'Greys')
    plt.subplot(132)
    plt.title("shifted down",fontsize= 14)
    plt.imshow(shifted_image_down.reshape(28,28),interpolation='nearest',cmap = 'Greys')
    plt.subplot(133)
    plt.title("shifted left",fontsize= 14)
    plt.imshow(shifted_image_left.reshape(28,28),interpolation='nearest',cmap = 'Greys')
    plt.show()

    X_train_augmented = [image for image in X_train]
    y_train_augmented = [label for label in y_train]
    
    for dx,dy in ((1,0),(-1,0),(0,1),(0,-1)):
        for image,label in zip(X_train,y_train):
            X_train_augmented.append(shift_image(image,dx,dy))
            y_train_augmented.append(label)
    X_train_augmented = np.array(X_train_augmented)
    y_train_augmented = np.array(y_train_augmented)
    # shuffle images and labels with the same permutation
    shuffle_idx = np.random.permutation(len(X_train_augmented))
    X_train_augmented = X_train_augmented[shuffle_idx]
    y_train_augmented = y_train_augmented[shuffle_idx]
    knn_clf = KNeighborsClassifier(**grid_search.best_params_)
    knn_clf.fit(X_train_augmented,y_train_augmented)
    y_pred = knn_clf.predict(X_test)
    accuracy_score(y_test,y_pred)

     At this point the accuracy exceeds 97%.
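
    The shuffle step in the code above indexes both arrays with the same permutation, which is what keeps every image paired with its label. A minimal sketch on toy data (X_small and y_small are hypothetical, built so that each label equals the first feature of its row):

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy data: row i is [2*i, 2*i + 1] and its label is 2*i,
# so a correct pairing always satisfies X[:, 0] == y.
X_small = np.arange(10).reshape(5, 2)
y_small = np.arange(5) * 2

# One permutation, applied to both arrays.
idx = rng.permutation(len(X_small))
X_shuf, y_shuf = X_small[idx], y_small[idx]

print(X_shuf)
print(y_shuf)
```

    Shuffling the two arrays with separate permutations would break this pairing, so the index array is computed once and reused.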

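    I was stuck for a long time on the ** in knn_clf = KNeighborsClassifier(**grid_search.best_params_). The ** unpacks a dictionary into keyword arguments: each key in grid_search.best_params_ is passed as a parameter name and each value as that parameter's value.
    See also this introduction to *args and **kwargs in Python (https://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained/)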
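
    A small sketch of what ** does at a call site. The function make_classifier and the values in best_params are hypothetical; they only mimic the shape of KNeighborsClassifier(**grid_search.best_params_) without needing scikit-learn.

```python
def make_classifier(n_neighbors=5, weights="uniform"):
    """Hypothetical stand-in for a constructor that takes keyword arguments."""
    return {"n_neighbors": n_neighbors, "weights": weights}

# Hypothetical values, shaped like a grid search's best_params_ dictionary.
best_params = {"n_neighbors": 4, "weights": "distance"}

# ** unpacks the dictionary: keys become argument names, values their values.
unpacked = make_classifier(**best_params)

# The call above is equivalent to writing the keyword arguments by hand.
explicit = make_classifier(n_neighbors=4, weights="distance")

print(unpacked == explicit)
```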

    Of course, scipy.ndimage also offers other transforms that can be used for data augmentation, such as rotation and zooming (see: https://blog.csdn.net/songchunxiao1991/article/details/88531086)
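
    As a rough illustration of those other transforms, the sketch below applies scipy.ndimage.rotate and scipy.ndimage.zoom to a synthetic 28x28 image. The image, the 15-degree angle and the zoom factors are arbitrary choices for demonstration, not values from the original post.

```python
import numpy as np
from scipy.ndimage import rotate, zoom

# Synthetic 28x28 image: a bright rectangle standing in for a digit.
image = np.zeros((28, 28))
image[10:18, 12:16] = 255.0

# Rotate by 15 degrees; reshape=False keeps the output at 28x28,
# and the corners exposed by the rotation are filled with cval.
rotated = rotate(image, angle=15, reshape=False, cval=0)

# Zoom changes the array size: a factor of 2 doubles each axis,
# a factor of 0.5 halves it.
enlarged = zoom(image, 2)
reduced = zoom(image, 0.5)

print(rotated.shape, enlarged.shape, reduced.shape)
```

    Because zoom changes the image size, a zoomed image would need to be cropped or padded back to 28x28 before it could be flattened to 784 features like the shifted images.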

  • Original post: https://www.cnblogs.com/whiteBear/p/12455432.html