zoukankan      html  css  js  c++  java
  • 机器学习算法(5):卷积神经网络原理及其keras实现

    1.原理

    CNN的资料特别多,这里不再赘述,仅收集相关的资料供大家参考:

    a.Deep learning:五十一(CNN的反向求导及练习)

    b.Deep Learning

    2.实现

    我们使用keras实现CNN,Keras的使用文档请参考

    a.Keras中文文档

    b.Keras英文文档

    参考keras官方的例子,我们使用keras,对数据集mnist训练一个cnn模型,实现的代码如下:

    '''Trains a simple convnet on the MNIST dataset.
    
    Gets to 99.25% test accuracy after 12 epochs
    (there is still a lot of margin for parameter tuning).
    16 seconds per epoch on a GRID K520 GPU.
    '''
    
    from __future__ import print_function
    import keras
    from keras.datasets import mnist
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Flatten
    from keras.layers import Conv2D, MaxPooling2D
    from keras import backend as K
    
    batch_size = 128
    num_classes = 10
    epochs = 12
    
    # input image dimensions
    img_rows, img_cols = 28, 28
    
    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    if K.image_data_format() == 'channels_first':
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)
    
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')
    
    # convert class vectors to binary class matrices
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)
    
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3),
                     activation='relu',
                     input_shape=input_shape))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(num_classes, activation='softmax'))
    
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.Adadelta(),
                  metrics=['accuracy'])
    
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test))
    score = model.evaluate(x_test, y_test, verbose=0)
    print('Test loss:', score[0])
    print('Test accuracy:', score[1])
  • 相关阅读:
    Tomcat单独部署,控制台乱码解决方法
    mysql授权访问数据库
    Arrays.binarySearch采坑记录及用法
    使用Spring Ehcache二级缓存优化查询性能
    Redis批量删除缓存数据
    Java并发包之Semaphore用法
    Java并发包之CountDownLatch用法
    如何用Xshell导出文件到桌面本地
    Semaphore信号量原理
    老应用链接替换到新链接
  • 原文地址:https://www.cnblogs.com/cv-pr/p/7141857.html
Copyright © 2011-2022 走看看