zoukankan      html  css  js  c++  java
  • 11.绘制网络结构

    import numpy as np
    from keras.datasets import mnist
    from keras.utils import np_utils
    from keras.models import Sequential
    from keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten
    from keras.optimizers import Adam
    from keras.utils.vis_utils import plot_model
    import matplotlib.pyplot as plt 
    # install pydot and graphviz
     1 # 载入数据
     2 (x_train,y_train),(x_test,y_test) = mnist.load_data()
     3 # (60000,28,28)->(60000,28,28,1)
     4 x_train = x_train.reshape(-1,28,28,1)/255.0
     5 x_test = x_test.reshape(-1,28,28,1)/255.0
     6 # 换one hot格式
     7 y_train = np_utils.to_categorical(y_train,num_classes=10)
     8 y_test = np_utils.to_categorical(y_test,num_classes=10)
     9 
    10 # 定义顺序模型
    11 model = Sequential()
    12 
    13 # 第一个卷积层
    14 # input_shape 输入平面
    15 # filters 卷积核/滤波器个数
    16 # kernel_size 卷积窗口大小
    17 # strides 步长
    18 # padding padding方式 same/valid
    19 # activation 激活函数
    20 model.add(Convolution2D(
    21     input_shape = (28,28,1),
    22     filters = 32,
    23     kernel_size = 5,
    24     strides = 1,
    25     padding = 'same',
    26     activation = 'relu',
    27     name = 'conv1'
    28 ))
    29 # 第一个池化层
    30 model.add(MaxPooling2D(
    31     pool_size = 2,
    32     strides = 2,
    33     padding = 'same',
    34     name = 'pool1'
    35 ))
    36 # 第二个卷积层
    37 model.add(Convolution2D(64,5,strides=1,padding='same',activation = 'relu',name='conv2'))
    38 # 第二个池化层
    39 model.add(MaxPooling2D(2,2,'same',name='pool2'))
    40 # 把第二个池化层的输出扁平化为1维
    41 model.add(Flatten())
    42 # 第一个全连接层
    43 model.add(Dense(1024,activation = 'relu'))
    44 # Dropout
    45 model.add(Dropout(0.5))
    46 # 第二个全连接层
    47 model.add(Dense(10,activation='softmax'))
    48 
    49 # # 定义优化器
    50 # adam = Adam(lr=1e-4)
    51 
    52 # # 定义优化器,loss function,训练过程中计算准确率
    53 # model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
    54 
    55 # # 训练模型
    56 # model.fit(x_train,y_train,batch_size=64,epochs=1)
    57 
    58 # # 评估模型
    59 # loss,accuracy = model.evaluate(x_test,y_test)
    60 
    61 # print('test loss',loss)
    62 # print('test accuracy',accuracy)
    plot_model(model,to_file="model.png",show_shapes=True,show_layer_names=True,rankdir='TB')
    plt.figure(figsize=(10,10))
    img = plt.imread("model.png")
    plt.imshow(img)
    plt.axis('off')
    plt.show()

  • 相关阅读:
    驼峰匹配
    常量
    bug生命周期&bug跟踪处理
    jmeter——参数化、关联、断言
    jmeter——http、jdbc、soap请求
    APP测试要点
    Android ADB 命令总结
    理解HTTP三次握手和四次握手的过程
    web、pc客户端、app测试的区别
    在RobotFramework--RIDE中把日期转化为整型进行运算
  • 原文地址:https://www.cnblogs.com/liuwenhua/p/11567052.html
Copyright © 2011-2022 走看看