zoukankan      html  css  js  c++  java
  • 吴裕雄 python神经网络 手写数字图片识别(5)

    import keras
    import matplotlib.pyplot as plt
    import numpy as np
    from keras.layers import Dense,Activation,Flatten,Dropout,Convolution2D,MaxPooling2D
    from keras.models import Sequential
    from keras.optimizers import RMSprop
    from keras.utils import np_utils
    from skimage import io

    # Hyper-parameters shared by the training / evaluation code further down.
    nb_classes = 10
    batch_size = 128

    #### A CNN takes image-shaped input (samples, rows, cols, channels),
    #### so the raw pixel data has to be reshaped first.
    # Raw string literal: in a plain "..." literal "\0" is an octal escape that
    # inserts a NUL character (and "\W" is an invalid escape), so the original
    # path could never be opened.
    train_X = io.imread(r"E:\WaySign\0_0_colorrgb0.ppm")
    # NOTE(review): io.imread loads a single image here -- for a colour .ppm
    # presumably shaped (rows, cols, 3) -- so this reshape only succeeds when the
    # element counts match (rows * cols * 3 == rows * 32 * 32). The model below
    # also declares (28, 28, 1) inputs; confirm the intended dataset and shape.
    train_x = np.reshape(train_X, (train_X.shape[0], 32, 32, 1))
    # NOTE(review): test_X, train_Y and test_Y are not defined anywhere in the
    # visible script, so the three lines below stay disabled and the prints that
    # follow raise NameError until test_x / train_y / test_y are created.
    # test_x=np.reshape(test_X,(test_X.shape[0],28,28,1))
    # train_y=np_utils.to_categorical(train_Y,nb_classes)
    # test_y=np_utils.to_categorical(test_Y,nb_classes)

    print(train_y.shape, ' ', test_y.shape)

    print(train_x.shape, ' ', test_x.shape)

    # Notebook-cell expression: shape of the single-channel view of the data.
    train_x[:, :, :, 0].shape

    ### Display the reshaped data: the first ten samples and their labels.
    import matplotlib.pyplot as plt
    # "%matplotlib inline" is an IPython/Jupyter magic, not Python syntax; it is
    # kept only as a comment so the file parses. Run it in a notebook cell if needed.
    # %matplotlib inline
    f, a = plt.subplots(1, 10, figsize=(10, 5))
    # One axis per sample; the loop body must be indented under the `for`.
    for i, ax in enumerate(a):
        ax.imshow(train_x[i, :, :, 0], cmap='gray')
        # NOTE(review): train_Y is never defined in the visible script (its
        # to_categorical line is commented out) -- confirm where labels come from.
        print(train_Y[i])

    #### Establish a convolutional neural network.
    # NOTE(review): input_shape is (28, 28, 1) (MNIST-sized) while train_x above
    # is reshaped to 32x32 -- one of the two has to change for fit() to work.
    model = Sequential()

    #### Convolution layer 1
    model.add(Convolution2D(filters=32, kernel_size=(3, 3), input_shape=(28, 28, 1),
                            strides=(1, 1), padding='same', activation='relu'))

    #### Pooling layer with dropout
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    model.add(Dropout(0.2))

    #### Convolution layer 2 (+ pooling and dropout)
    model.add(Convolution2D(filters=64, kernel_size=(3, 3), strides=(1, 1),
                            padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    model.add(Dropout(0.2))

    #### Convolution layer 3 (+ pooling)
    model.add(Convolution2D(filters=128, kernel_size=(3, 3), strides=(1, 1),
                            padding='same', activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    # Flatten each sample's 3-D feature maps into a single vector so the dense
    # layers below can consume it.
    model.add(Flatten())
    model.add(Dropout(0.2))

    #### Fully connected layer 1 (fc layer)
    # `units` is the Keras 2 keyword; the Keras 1 `output_dim` keyword is
    # deprecated and rejected by newer Keras releases.
    model.add(Dense(units=625, activation='relu'))
    model.add(Dropout(0.5))

    #### Fully connected layer 2 (fc layer): softmax over the nb_classes classes.
    model.add(Dense(units=nb_classes, activation='softmax'))
    model.summary()

    # categorical_crossentropy expects one-hot encoded labels
    # (e.g. produced by np_utils.to_categorical).
    model.compile(optimizer=RMSprop(lr=0.001, rho=0.9),
                  loss="categorical_crossentropy",
                  metrics=['accuracy'])
    import time
    start_time = time.time()
    # Reuse the batch_size hyper-parameter instead of repeating the literal 128,
    # so changing it in one place affects both training and evaluation.
    model.fit(train_x, train_y, epochs=30, batch_size=batch_size, verbose=1)
    end_time = time.time()
    print("running time:%.2f" % (end_time - start_time))

    # With metrics=['accuracy'], evaluate() returns [loss, accuracy].
    evaluation = model.evaluate(test_x, test_y, batch_size=batch_size, verbose=1)
    print("model loss:%.4f" % (evaluation[0]), "model accuracy:%.4f" % (evaluation[1]))

    # https://github.com/fchollet/keras/issues/431
    def get_activations(model, model_inputs, print_shape_only=True, layer_name=None):
        """Evaluate the output of the model's layers for ``model_inputs``.

        Args:
            model: a Keras model; its ``input`` and ``layers`` are used.
            model_inputs: the input array, or a list of arrays when the model
                has several inputs.
            print_shape_only: if True print only each activation's shape,
                otherwise print the full activation values.
            layer_name: if given, only layers with this name are evaluated;
                if None, every layer is evaluated.

        Returns:
            A list with one activation array per selected layer, in layer order.
        """
        import keras.backend as K
        print('----- activations -----')
        activations = []
        inp = model.input

        model_multi_inputs_cond = True
        if not isinstance(inp, list):
            # only one input! let's wrap it in a list.
            inp = [inp]
            model_multi_inputs_cond = False

        outputs = [layer.output for layer in model.layers
                   if layer.name == layer_name or layer_name is None]  # all layer outputs

        # One backend function per output; the learning-phase flag is fed as an
        # extra input after the model inputs.
        funcs = [K.function(inp + [K.learning_phase()], [out]) for out in outputs]

        # Keras learning phase: 0 = test mode, 1 = training mode. Test mode is
        # required here; in training mode the Dropout layers stay active and the
        # visualized activations would be randomly masked.
        test_phase = 0.
        if model_multi_inputs_cond:
            list_inputs = list(model_inputs) + [test_phase]
        else:
            list_inputs = [model_inputs, test_phase]

        layer_outputs = [func(list_inputs)[0] for func in funcs]
        for layer_activations in layer_outputs:
            activations.append(layer_activations)
            if print_shape_only:
                print(layer_activations.shape)
            else:
                print(layer_activations)
        return activations

    # https://github.com/philipperemy/keras-visualize-activations/blob/master/read_activations.py
    def display_activations(activation_maps):
        """Plot every activation array produced by ``get_activations``.

        4-D maps (batch, rows, cols, channels) are drawn with all channels laid
        side by side. 2-D maps (batch, units) are drawn as a single row, or -
        when there are more than 1024 units - as the largest square grid that
        fits, dropping the leftover units. Shapes produced by the MNIST model in
        this script:

            (1, 28, 28, 32)
            (1, 14, 14, 32)
            (1, 14, 14, 32)
            (1, 14, 14, 64)
            (1, 7, 7, 64)
            (1, 7, 7, 64)
            (1, 7, 7, 128)
            (1, 3, 3, 128)
            (1, 1152)
            (1, 1152)
            (1, 625)
            (1, 625)
            (1, 10)

        Args:
            activation_maps: list of numpy arrays whose first axis is the batch.

        Raises:
            AssertionError: if the batch size is not 1.
            NotImplementedError: for maps that are neither 2-D nor 4-D.
        """
        import numpy as np
        import matplotlib.pyplot as plt

        batch_size = activation_maps[0].shape[0]
        assert batch_size == 1, 'One image at a time to visualize.'
        for i, activation_map in enumerate(activation_maps):
            print('Displaying activation map {}'.format(i))
            shape = activation_map.shape
            if len(shape) == 4:
                # (rows, cols, channels) -> (channels, rows, cols), then stack the
                # channels horizontally into one (rows, cols * channels) image.
                activations = np.hstack(np.transpose(activation_map[0], (2, 0, 1)))
            elif len(shape) == 2:
                # try to make it square as much as possible. we can skip some activations.
                activations = activation_map[0]
                num_activations = len(activations)
                if num_activations > 1024:  # too hard to display it on the screen.
                    square_param = int(np.floor(np.sqrt(num_activations)))
                    activations = activations[0: square_param * square_param]
                    activations = np.reshape(activations, (square_param, square_param))
                else:
                    # imshow needs 2-D data: show the vector as a single row.
                    activations = np.expand_dims(activations, axis=0)
            else:
                # Report the actual rank; any rank other than 2 or 4 lands here.
                raise NotImplementedError(
                    'len(shape) = {} has not been implemented.'.format(len(shape)))
            _, ax = plt.subplots(figsize=(18, 12))
            ax.imshow(activations, interpolation='none', cmap='binary')
            plt.show()

    ### One image at a time to visualize: wrap test image 0 as a batch of one.
    sample = test_x[0, :, :, :][np.newaxis, :]
    activations = get_activations(model, sample)

    # (1, rows, cols, channels): the leading axis is the batch of one.
    print(sample.shape)

    display_activations(activations)

    plt.imshow(test_x[0, :, :, 0], cmap='gray')
    # Sequential.predict_classes was removed from newer Keras releases; taking
    # the arg-max of the softmax probabilities yields the same class index.
    pred_value = np.argmax(model.predict(sample, batch_size=1), axis=-1)
    print(pred_value)

  • 相关阅读:
    《CLR Via C# 第3版》笔记之(十四) 泛型高级
    《CLR Via C# 第3版》笔记之(十三) 泛型基础
    AOP学习基于Emit和Attribute的简单AOP实现
    《CLR Via C# 第3版》笔记之(十五) 接口
    C# 连接Oracle(利用ODP.net,不安装oracle客户端)
    《CLR Via C# 第3版》笔记之(十七) 线程基础
    C#直接读取磁盘文件(类似linux的Direct IO模式)
    《CLR Via C# 第3版》笔记之(十六) 字符串
    [置顶] C#中通过调用webService获取上网IP地址的区域的方法
    Android中Socket通讯类
  • 原文地址:https://www.cnblogs.com/tszr/p/10085348.html
Copyright © 2011-2022 走看看