  • Image classification with pretrained models in Keras

    Keras ships with pretrained models for several networks, which makes them very convenient to use.

    For installation and usage, mainly follow the official tutorials: https://keras.io/zh/applications/ and https://keras-cn.readthedocs.io/en/latest/other/application/

    The official site gives the following example of ImageNet classification with ResNet50:

    from keras.applications.resnet50 import ResNet50
    from keras.preprocessing import image
    from keras.applications.resnet50 import preprocess_input, decode_predictions
    import numpy as np
    
    model = ResNet50(weights='imagenet')
    
    img_path = 'elephant.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    
    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print('Predicted:', decode_predictions(preds, top=3)[0])
    # Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
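
As a rough illustration of what the `top=3` argument selects, the sketch below picks the highest-probability classes from one row of a prediction array using only NumPy. It is not Keras's implementation of `decode_predictions`; the helper `top_k`, the five-entry `labels` list, and the probabilities are hypothetical stand-ins (a real ImageNet model returns an array of shape (1, 1000)).

```python
import numpy as np

def top_k(probs, labels, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    # argsort is ascending, so reverse it and keep the first k indices
    order = np.argsort(probs)[::-1][:k]
    return [(labels[i], float(probs[i])) for i in order]

# Hypothetical 5-class output for a batch containing one image
labels = ['cat', 'dog', 'tusker', 'Indian_elephant', 'African_elephant']
preds = np.array([[0.01, 0.02, 0.12, 0.80, 0.05]])

# preds[0] is the probability vector for the first (and only) sample
print(top_k(preds[0], labels, k=3))
```

Indexing with `[0]` mirrors the `decode_predictions(preds, top=3)[0]` call above: the model always works on batches, so the result for a single image is the first entry.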

    The other networks can be used by following the ResNet50 example above.

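    First, VGG19, used here to extract the 4096-dimensional output of the `fc2` layer instead of class predictions: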

    # coding: utf-8
    from keras.applications.vgg19 import VGG19
    from keras.preprocessing import image
    from keras.applications.vgg19 import preprocess_input
    from keras.models import Model
    import numpy as np
    base_model = VGG19(weights='imagenet', include_top=True)
    model = Model(inputs=base_model.input, outputs=base_model.get_layer('fc2').output)
    img_path = '../mdataset/img_test/p2.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    fc2 = model.predict(x)
    print(fc2.shape)  #(1, 4096)

    Next, MobileNet:

    # coding: utf-8
    from keras.applications.mobilenet import MobileNet
    from keras.preprocessing import image
    from keras.applications.mobilenet import preprocess_input,decode_predictions
    from keras.models import Model
    import numpy as np
    import time
    
    model = MobileNet(weights='imagenet', include_top=True,classes=1000)
    
    start = time.time()
    
    img_path = '../mdataset/img_test/dog.jpg'
    img = image.load_img(img_path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    
    preds = model.predict(x)
    # decode the results into a list of tuples (class, description, probability)
    # (one such list for each sample in the batch)
    print('Predicted:', decode_predictions(preds, top=15)[0])
    end = time.time()
    
    # elapsed time in seconds (model loading is not included)
    print('time:')
    print(str(end - start))

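    The timing above does not include the time needed to load the model; the timed portion takes a little under 1 second. If model loading is counted as well, the total is roughly 3 seconds.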
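
To report model loading and prediction separately, each step can be wrapped in a small timer. The following is a minimal sketch using only the standard library; the helper `timed` and the stand-in functions `load_model` and `run_predict` are hypothetical and simply sleep to simulate work. In real use they would be replaced by something like `lambda: MobileNet(weights='imagenet')` and `lambda m: m.predict(x)`.

```python
import time

def timed(fn, *args, **kwargs):
    """Run fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

# Hypothetical stand-ins for the real steps
def load_model():
    time.sleep(0.05)   # simulates the slow, one-time model loading
    return "model"

def run_predict(model):
    time.sleep(0.01)   # simulates a single prediction
    return [0.1, 0.9]

model, load_seconds = timed(load_model)
preds, predict_seconds = timed(run_predict, model)

print('load time: %.3f s' % load_seconds)
print('predict time: %.3f s' % predict_seconds)
print('total time: %.3f s' % (load_seconds + predict_seconds))
```

Timing the two steps separately shows how much of the total is the one-time loading cost, which matters when the same model is reused for many images.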

  • Original article: https://www.cnblogs.com/vactor/p/9813108.html