zoukankan      html  css  js  c++  java
  • 机器学习进阶-案例实战-停车场车位识别-keras预测停车位是否有车

    import numpy
    import os
    
    from keras import applications
    from keras.preprocessing.image import ImageDataGenerator
    from keras import optimizers
    from keras.models import Sequential, Model
    from keras.layers import Dropout, Flatten, Dense, GlobalAveragePooling2D
    from keras import backend as k
    from keras.callbacks import ModelCheckpoint, LearningRateScheduler, TensorBoard, EarlyStopping
    from keras.models import Sequential
    from keras.layers.normalization import BatchNormalization
    from keras.layers.convolutional import Conv2D
    from keras.layers.convolutional import MaxPooling2D
    from keras.initializers import TruncatedNormal
    from keras.layers.core import Activation
    from keras.layers.core import Flatten
    from keras.layers.core import Dropout
    from keras.layers.core import Dense
    
    
    def _count_files(root):
        """Return the total number of files directly inside each sub-folder of *root*.

        Each sub-folder of *root* is expected to hold the images of one class.
        Entries of *root* that are not directories (e.g. a stray ``.DS_Store``)
        are skipped: ``next(os.walk(path))`` yields nothing for a plain file and
        would raise ``StopIteration``.
        """
        total = 0
        for sub_folder in os.listdir(root):
            sub_path = os.path.join(root, sub_folder)
            if not os.path.isdir(sub_path):
                continue
            # Only the top level of each class folder is counted, as before.
            _, _, files = next(os.walk(sub_path))
            total += len(files)
        return total


    # Kept for compatibility; not referenced elsewhere in this script.
    cwd = os.getcwd()

    # Number of training / validation files, summed over all class sub-folders.
    files_train = _count_files('train_data/train')
    files_validation = _count_files('train_data/test')

    print(files_train,files_validation)
    
    # --- Training configuration -------------------------------------------
    # Images are resized to this square resolution by the data generators.
    img_width = img_height = 48
    # Class-per-sub-folder directories for training and validation.
    train_data_dir, validation_data_dir = "train_data/train", "train_data/test"
    # Sample counts gathered by the directory scan above.
    nb_train_samples, nb_validation_samples = files_train, files_validation
    # Optimisation settings; two output classes.
    batch_size, epochs, num_classes = 32, 15, 2
    
    # VGG16 convolutional base with ImageNet weights. include_top=False drops
    # the original fully-connected classifier so a custom head can be attached
    # and a small img_width x img_height RGB input can be used.
    model = applications.VGG16(weights='imagenet', include_top=False, input_shape = (img_width, img_height, 3))

    # Freeze the first 10 layers (the lower part of the convolutional base);
    # the remaining layers and the new head are fine-tuned.
    for layer in model.layers[:10]:
        layer.trainable = False

    # New classification head: flatten the conv features and output one
    # softmax probability per class.
    x = model.output
    x = Flatten()(x)
    predictions = Dense(num_classes, activation="softmax")(x)

    # Keras 2 functional API takes `inputs=` / `outputs=`. The Keras 1
    # keywords `input=` / `output=` only work through a deprecation shim and
    # are rejected by newer Keras / tf.keras.
    model_final = Model(inputs = model.input, outputs = predictions)

    # Compile after changing `trainable` so the freeze takes effect.
    # NOTE(review): `lr=` is the Keras <= 2.2 spelling; Keras >= 2.3 / tf.keras
    # call it `learning_rate=` — confirm against the installed version.
    model_final.compile(loss = "categorical_crossentropy",
                        optimizer = optimizers.SGD(lr=0.0001, momentum=0.9),
                        metrics=["accuracy"])
    
    
    # Training images: scale pixels to [0, 1] and apply light random
    # augmentation (horizontal flips, zoom, shifts, small rotations).
    train_datagen = ImageDataGenerator(
        rescale = 1./255,
        horizontal_flip = True,
        fill_mode = "nearest",
        zoom_range = 0.1,
        width_shift_range = 0.1,
        height_shift_range=0.1,
        rotation_range=5)

    # Validation images: only the same rescaling, no random augmentation.
    # Augmenting here would make each validation pass see different, distorted
    # images, so val_acc (used for checkpointing / early stopping) would be
    # noisy and not comparable between epochs.
    test_datagen = ImageDataGenerator(rescale = 1./255)

    # One sub-folder per class; labels are one-hot encoded ("categorical").
    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size = (img_height, img_width),
        batch_size = batch_size,
        class_mode = "categorical")

    # batch_size is passed explicitly so validation batching follows the
    # configured batch_size instead of relying on the library default.
    validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        target_size = (img_height, img_width),
        batch_size = batch_size,
        class_mode = "categorical")
    
    # Save the model whenever validation accuracy improves, and stop training
    # if it has not improved for 10 consecutive epochs.
    # NOTE(review): 'val_acc' is the Keras <= 2.2 log name for
    # metrics=["accuracy"]; Keras >= 2.3 / tf.keras log 'val_accuracy' —
    # confirm against the installed version.
    checkpoint = ModelCheckpoint("car1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False, mode='auto', period=1)
    early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')

    # Keras 2 generator training counts batches ("steps"), not samples. The
    # Keras 1 keywords used before were translated as
    #   samples_per_epoch -> floor(samples / batch_size)   (last partial batch dropped)
    #   nb_val_samples    -> validation_steps unchanged     (one *batch* per sample)
    # so each validation pass drew ~batch_size times too much data.
    # Compute both step counts explicitly with ceiling division, using the
    # number of images each generator actually found.
    steps_per_epoch = (train_generator.samples + train_generator.batch_size - 1) // train_generator.batch_size
    validation_steps = (validation_generator.samples + validation_generator.batch_size - 1) // validation_generator.batch_size

    history_object = model_final.fit_generator(
        train_generator,
        steps_per_epoch = steps_per_epoch,
        epochs = epochs,
        validation_data = validation_generator,
        validation_steps = validation_steps,
        callbacks = [checkpoint, early])
  • 相关阅读:
    webpack4.0在项目中的安装配置
    Java调用开源GDAL解析dxf成shp,再调用开源GeoTools解析shp文件
    VUE-CLI3.0组件封装打包使用
    鼠标光标在input框,单击回车键后防止页面刷新的问题
    MapBox GL加载天地图以及加载导航控件
    web前端监控视频的展示
    css外部字体库文件的引用
    IIS上部署的程序,PLSQL能连上数据库,系统登录报错
    部署在IIS上的程序,可以找到文件夹,能看到文件却报404
    继承
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/10435794.html
Copyright © 2011-2022 走看看