zoukankan      html  css  js  c++  java
  • Python深度学习笔记07--使用Keras建立卷积神经网络

     1 from keras.datasets import mnist
     2 from keras.utils import to_categorical
     3 
     4 #1. 获取数据
     5 (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
     6 
     7 #2. 处理数据
     8 train_images = train_images.reshape((60000, 28, 28, 1))
     9 train_images = train_images.astype('float32') / 255
    10 
    11 test_images = test_images.reshape((10000, 28, 28, 1))
    12 test_images = test_images.astype('float32') / 255
    13 
    14 train_labels = to_categorical(train_labels)
    15 test_labels = to_categorical(test_labels)
    16 
    17 from keras import layers
    18 from keras import models
    19 
    20 #3. 建立网络模型
    21 model = models.Sequential()
    22 model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    23 model.add(layers.MaxPooling2D((2, 2)))
    24 model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    25 model.add(layers.MaxPooling2D((2, 2)))
    26 model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    27 
    28 model.add(layers.Flatten())
    29 model.add(layers.Dense(64, activation='relu'))
    30 model.add(layers.Dense(10, activation='softmax'))
    31 
    32 # print(model.summary())
    33 
    34 '''
    35 #网络参数计算方式
    36 1.conv2d有320个参数,计算方式:(3 * 3 * 1 + 1) * 32
    37 2.conv2d_1有18496个参数,计算方式:(3 * 3 * 32 + 1) * 64
    38 3.conv2d_2有36928个参数,计算方式:(3 * 3 * 64 + 1) * 64
    39 4.dense有36928个参数,计算方式:(576 + 1) * 64
    40 5.dense_1有650个参数,计算方式:(64 + 1) * 10
    41 '''
    42 #4. 设置编译参数
    43 model.compile(optimizer='rmsprop',
    44               loss='categorical_crossentropy',
    45               metrics=['accuracy'])
    46 
    47 #5. 设置训练条件并训练             
    48 model.fit(train_images, train_labels, epochs=5, batch_size=64)
    49 
    50 #6. 评估模型
    51 test_loss, test_acc = model.evaluate(test_images, test_labels)
    52 #313/313 [==============================] - 1s 3ms/step - loss: 0.0324 - accuracy: 0.9898
    53 
    54 print(test_acc)
    55 #0.989799976348877

    5.1 卷积神经网络简介 

    5.1.1 卷积运算

    卷积神经网络有以下两个性质:

    (1)卷积神经网络学到的模式具有平移不变性。

    (2)卷积神经网络可以学到模式的空间层次结构。

    卷积的两个关键参数:

    (1)从输入中提取的图块尺寸

    (2)输出特征的深度

    Keras的API为:Conv2D(output_depth, (window_height, window_width)),即窗口尺寸以元组形式作为第二个参数传入,例如 Conv2D(32, (3, 3))。

    注意:书上这部分对卷积过程的描述不是很容易理解,建议看吴恩达的视频来学习卷积网络相关的概念。

    吴恩达卷积神经网络课程地址

    5.1.2 最大池化运算

    最大池化是从输入特征图中提取窗口,并输出每个通道的最大值。

    最大池化通常使用2 * 2的窗口和步幅2,其目的是将特征图下采样2倍。

    5.2 在小型数据集上从头开始训练一个卷积神经网络

    卷积神经网络代码如下(注意:代码假设数据集已经按 train、validation、test 划分好,并分别放在各自的 cats、dogs 子目录中;创建目录和复制图片的步骤未包含在下面的代码中):

      1 import os, shutil
      2 
      3 # The path to the directory where the original
      4 # dataset was uncompressed
      5 # original_dataset_dir = '/Users/fchollet/Downloads/kaggle_original_data'
      6 original_dataset_dir = 'E:\desktop\code\data\dogs-vs-cats\train'
      7 
      8 # The directory where we will
      9 # store our smaller dataset
     10 # base_dir = '/Users/fchollet/Downloads/cats_and_dogs_small'
     11 base_dir = 'E:\desktop\code\data\dogs-vs-cats\cats_and_dogs_small'
     12 # os.mkdir(base_dir)
     13 
     14 # Directories for our training,
     15 # validation and test splits
     16 train_dir = os.path.join(base_dir, 'train')
     17 # os.mkdir(train_dir)
     18 validation_dir = os.path.join(base_dir, 'validation')
     19 # os.mkdir(validation_dir)
     20 test_dir = os.path.join(base_dir, 'test')
     21 # os.mkdir(test_dir)
     22 
     23 # Directory with our training cat pictures
     24 train_cats_dir = os.path.join(train_dir, 'cats')
     25 # os.mkdir(train_cats_dir)
     26 
     27 # Directory with our training dog pictures
     28 train_dogs_dir = os.path.join(train_dir, 'dogs')
     29 # os.mkdir(train_dogs_dir)
     30 
     31 # Directory with our validation cat pictures
     32 validation_cats_dir = os.path.join(validation_dir, 'cats')
     33 # os.mkdir(validation_cats_dir)
     34 
     35 # Directory with our validation dog pictures
     36 validation_dogs_dir = os.path.join(validation_dir, 'dogs')
     37 # os.mkdir(validation_dogs_dir)
     38 
     39 # Directory with our validation cat pictures
     40 test_cats_dir = os.path.join(test_dir, 'cats')
     41 # os.mkdir(test_cats_dir)
     42 
     43 # Directory with our validation dog pictures
     44 test_dogs_dir = os.path.join(test_dir, 'dogs')
     45 # os.mkdir(test_dogs_dir)
     46 
     47 
     48 # #验证图片存放是否正确
     49 # print('total training cat images:', len(os.listdir(train_cats_dir)))
     50 
     51 # print('total training dog images:', len(os.listdir(train_dogs_dir)))
     52 
     53 # print('total validation cat images:', len(os.listdir(validation_cats_dir)))
     54 
     55 # print('total validation dog images:', len(os.listdir(validation_dogs_dir)))
     56 
     57 # print('total test cat images:', len(os.listdir(test_cats_dir)))
     58 
     59 # print('total test dog images:', len(os.listdir(test_dogs_dir)))
     60 
     61 
     62 from keras import layers
     63 from keras import models
     64 
     65 model = models.Sequential()
     66 model.add(layers.Conv2D(32, (3, 3), activation='relu',
     67                         input_shape=(150, 150, 3)))
     68 model.add(layers.MaxPooling2D((2, 2)))
     69 model.add(layers.Conv2D(64, (3, 3), activation='relu'))
     70 model.add(layers.MaxPooling2D((2, 2)))
     71 model.add(layers.Conv2D(128, (3, 3), activation='relu'))
     72 model.add(layers.MaxPooling2D((2, 2)))
     73 model.add(layers.Conv2D(128, (3, 3), activation='relu'))
     74 model.add(layers.MaxPooling2D((2, 2)))
     75 model.add(layers.Flatten())
     76 model.add(layers.Dense(512, activation='relu'))
     77 model.add(layers.Dense(1, activation='sigmoid'))
     78 
     79 from keras import optimizers
     80 
     81 model.compile(loss='binary_crossentropy',
     82               optimizer=optimizers.RMSprop(lr=1e-4),
     83               metrics=['acc'])
     84 
     85 
     86 from keras.preprocessing.image import ImageDataGenerator
     87 
     88 # All images will be rescaled by 1./255
     89 train_datagen = ImageDataGenerator(rescale=1./255)
     90 test_datagen = ImageDataGenerator(rescale=1./255)
     91 
     92 train_generator = train_datagen.flow_from_directory(
     93         # This is the target directory
     94         train_dir,
     95         # All images will be resized to 150x150
     96         target_size=(150, 150),
     97         batch_size=20,
     98         # Since we use binary_crossentropy loss, we need binary labels
     99         class_mode='binary')
    100 
    101 validation_generator = test_datagen.flow_from_directory(
    102         validation_dir,
    103         target_size=(150, 150),
    104         batch_size=20,
    105         class_mode='binary')        
    106 
    107 history = model.fit_generator(
    108       train_generator,
    109       steps_per_epoch=100,
    110       epochs=30,
    111       validation_data=validation_generator,
    112       validation_steps=50)       
    113 
    114 model.save('cats_and_dogs_small_1.h5')
    115 # model.load_weights('cats_and_dogs_small_1.h5')
    116 
    117 
    118 
    119 import matplotlib.pyplot as plt
    120 
    121 # #dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
    122 acc = history.history['acc']
    123 val_acc = history.history['val_acc']
    124 loss = history.history['loss']
    125 val_loss = history.history['val_loss']
    126 
    127 epochs = range(len(acc))
    128 
    129 plt.plot(epochs, acc, 'bo', label='Training acc')
    130 plt.plot(epochs, val_acc, 'b', label='Validation acc')
    131 plt.title('Training and validation accuracy')
    132 plt.legend()
    133 
    134 plt.figure()
    135 
    136 plt.plot(epochs, loss, 'bo', label='Training loss')
    137 plt.plot(epochs, val_loss, 'b', label='Validation loss')
    138 plt.title('Training and validation loss')
    139 plt.legend()
    140 
    141 plt.show()
  • 相关阅读:
    java 代码规范 sun 公司
    软引用、弱引用、虚引用
    socket
    httpURLConnection、URL、httpClient、httpPost、httpGet
    android service aidl 理解
    Python2.7-codecs
    Python2.7-textwrap
    Python2.7-StringIO和cStringIO
    Python2.7-difflib
    Python2.7-struct模块
  • 原文地址:https://www.cnblogs.com/asenyang/p/14321973.html
Copyright © 2011-2022 走看看