zoukankan      html  css  js  c++  java
  • keras02

    本项目参考:

    https://www.bilibili.com/video/av31500120?t=4657

    训练代码

      1 # coding: utf-8
      2 # Learning from Mofan and Mike G
      3 # Recreated by Paprikatree
      4 # Convolution NN Train
      5 
      6 import numpy as np
      7 from keras.datasets import mnist
      8 from keras.utils import np_utils
      9 from keras.models import Sequential
     10 from keras.layers import Convolution2D, Activation, MaxPool2D, Flatten, Dense
     11 from keras.optimizers import Adam
     12 from keras.models import load_model
     13 
     14 
     15 nb_class = 10
     16 nb_epoch = 4
     17 batchsize = 128
     18 
     19 '''
     20 1st,准备参数
     21 X_train: (0,255) --> (0,1) CNN中似乎没有必要?cnn自动转了吗?
     22 设置时间函数测试一下两者对比。
     23 小技巧:X_train /= 255.0 就可不用转换成浮点了???
     24 '''
     25 # Preparing your data mnist.  MAC /.keras/datasets  linux home ./keras/datasets
     26 (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
     27 
     28 
     29 # setup data shape
     30 # (-1, 28, 28, 1) -1表示有默认个数据集,28*28是像素,1是1个通道
     31 X_train = X_train.reshape(-1, 28, 28, 1)  # tensorflow-channel last,while theano-channel first
     32 X_test = X_test.reshape(-1, 28, 28, 1)
     33 
     34 X_train = X_train/255.000
     35 X_test = X_test/255.000
     36 
     37 # One-hot 6 --> [0,0,0,0,0,1,0,0,0]
     38 Y_train = np_utils.to_categorical(Y_train, nb_class)
     39 Y_test = np_utils.to_categorical(Y_test, nb_class)
     40 
     41 '''
     42 2nd,设置模型
     43 '''
     44 
     45 # setup model
     46 model = Sequential()
     47 
     48 # 1st convolution layer # 滤波器要在28x28的图上横着走32次
     49 model.add(Convolution2D(
     50     filters=32,  # 此处把filters写成了filter,找了半天。囧
     51     kernel_size=[5, 5],  # 滤波器是5x5大小的,可以是list列表,也可以是tuple元祖
     52     padding='same',  # padding也是一个窗口模式
     53     input_shape=(28, 28, 1)  # 定义输入的数据,必须是元组
     54 ))
     55 model.add(Activation('relu'))
     56 model.add(MaxPool2D(
     57     pool_size=(2, 2),  # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
     58     strides=(2, 2),  # 相当于把图片缩小了。
     59     padding="same",
     60 ))
     61 
     62 # 2nd Conv2D layer
     63 model.add(Convolution2D(
     64     filters=64,
     65     kernel_size=(5, 5),
     66     padding='same',
     67 ))
     68 model.add(Activation('relu'))
     69 model.add(MaxPool2D(
     70     pool_size=(2, 2),  # 按照规则抓取特征,此处为在pool_size的2*2窗口下,strides = 2*2 跳两格再抓取。如 1 2 3 4 5 6...27 28 抓取1 2 ,跳过 3 4 抓取 5 6。
     71     strides=(2, 2),  # 相当于把图片缩小了。
     72     padding="same",
     73 ))  # 讨论,卷积层数和最终结果关系。
     74 
     75 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
     76 model.add(Flatten())  # 把卷积层里面的全部转换层一维数组
     77 model.add(Dense(1024))  # Dense is output
     78 model.add(Activation('relu'))
     79 
     80 # 1st Fully connected Dense,Dense 全连接层是hello world里面的内容
     81 # 把卷积层里面的全部转换层一维数组
     82 model.add(Dense(256))  # Dense is output
     83 model.add(Activation('tanh'))
     84 
     85 # 2nd Fully connected Dense
     86 model.add(Dense(10))
     87 model.add(Activation('softmax'))
     88 
     89 '''
     90 3rd 定义参数
     91 '''
     92 # Define Optimizer and setup Param
     93 adam = Adam(lr=0.0001)  # Adam实例化
     94 
     95 # compile model
     96 model.compile(
     97     optimizer=adam,  # optimizer='Adam'也是可以的,且默认lr=0.001,此处已经实例化为adam
     98     loss='categorical_crossentropy',
     99     metrics=['accuracy'],
    100 )
    101 
    102 # Run network
    103 model.fit(x=X_train,  # 更多参数可以查看fit函数,alt+鼠标左键单击fit
    104           y=Y_train,
    105           epochs=nb_epoch,
    106           batch_size=batchsize,  # p=parameter, batch_size; v=var, batch size
    107           verbose=1,  # 显示模式
    108           validation_data=(X_test, Y_test)
    109           )
    110 model.save('model_name.h5')
    111 # evaluation = model.evaluate(X_test, Y_test)  现在用model.fit(validation_data)
    112 # print(evaluation)  效果一样

    测试代码:

     1 # coding: utf-8
     2 # Learning from Mofan and Mike G
     3 # Recreated by Paprikatree
     4 # Convolution NN Predict
     5 
     6 import numpy as np
     7 from keras.models import load_model  # ??
     8 import matplotlib.pyplot as plt
     9 import matplotlib.image as processimage
    10 
    11 
    12 # load trained model
    13 model = load_model('model_name.h5')  # 已经训练好了的模型,在根目录下,默认为model_name.h5
    14 
    15 
    16 # 写一个来预测的类
    17 class MainPredictImg(object):
    18     
    19     def __init__(self):
    20         pass
    21     
    22     def pred(self, filename):
    23         pred_img = processimage.imread(filename)
    24         pred_img = np.array(pred_img)
    25         pred_img = pred_img.reshape(-1, 28, 28, 1)
    26         prediction = model.predict(pred_img)
    27         final_prediction = [result.argmax() for result in prediction][0]
    28         a = 0
    29         for i in prediction[0]:
    30             print(a)
    31             print('Percent:{:.30%}'.format(i))
    32             a = a+1
    33         return final_prediction
    34 
    35 
    36 def main():
    37     predict = MainPredictImg()
    38     res = predict.pred('4.png')
    39     print("your number is:-->", res)
    40 
    41 
    42 if __name__ == '__main__':
    43     main()
    View Code
  • 相关阅读:
    WPF 自定义NotifyPropertyChanged
    深度学习(五)正则化之L1和L2
    深度学习(三) 反向传播直观理解
    javascript中的原型和原型链(二)
    javascript中的原型和原型链(一)
    javascript中创建对象的方式及优缺点(二)
    javascript中创建对象的方式及优缺点(一)
    JS实现深拷贝的几种方法
    json.stringify()的妙用,json.stringify()与json.parse()的区别
    Javascript你必须要知道的知识点
  • 原文地址:https://www.cnblogs.com/paprikatree/p/10151591.html
Copyright © 2011-2022 走看看