zoukankan      html  css  js  c++  java
  • CNN特征点训练和识别基于python3.7

    环境:python3.7  

    疫情期间在家无聊就学了下cnn对集装箱的箱门特征进行了训练,基本200张图片就可以达到95%准确率

     代码如下:

    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPool2D, Activation, Dropout, Flatten, Dense
    from keras.optimizers import Adam
    from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
    from keras.models import load_model
    import numpy as np

    # define model
    model = Sequential()
    model.add(Conv2D(input_shape=(150, 150, 3), filters=32, kernel_size=3, padding='same', activation='relu'))
    model.add(Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'))
    model.add(MaxPool2D(pool_size=2, strides=2))

    model.add(Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))
    model.add(Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))
    model.add(MaxPool2D(pool_size=2, strides=2))

    model.add(Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'))
    model.add(Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'))
    model.add(MaxPool2D(pool_size=2, strides=2))

    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(2, activation='softmax'))

    # define optimizer
    adam = Adam(lr=1e-4)

    # define optimizer, value function, calculate accuracy
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

    train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    rescale=1/255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
    )

    test_datagen = ImageDataGenerator(
    rescale=1/255
    )

    batch_size = 32

    # create train data
    train_generator = train_datagen.flow_from_directory(
    'test/train',
    target_size=(150, 150),
    batch_size=batch_size,
    )

    # create test data
    test_generator = test_datagen.flow_from_directory(
    'test/test',
    target_size=(150, 150),
    batch_size=batch_size,
    )
    print (train_generator.class_indices)
    model.fit_generator(train_generator, epochs=30, validation_data=test_generator, steps_per_epoch=150/batch_size, validation_steps=1)
    model.save('dor_cnn.h5')

    运行结果:测试下来还是很准确的,优化下模型就可以达到可用级别

  • 相关阅读:
    spring多数据源配置
    spring+myBatis 配置多数据源,切换数据源
    Maven项目引入log4j的详细配置
    基于Https协议返回Jason字符串
    Http协议入门、响应与请求行、HttpServletRequest对象的使用、请求参数获取和编码问题
    java http post/get 服务端和客户端实现json传输
    java实现一个简单的Web服务器
    设计模式系列
    Nginx系列
    Linux系列
  • 原文地址:https://www.cnblogs.com/mlwork/p/12468547.html
Copyright © 2011-2022 走看看