zoukankan      html  css  js  c++  java
  • 相似图像搜索从训练到服务全过程

    最近完成了一个以图搜图的项目,项目总共用时三个多月。记录一下项目中用到机器学习的地方,以及各种踩过的坑。总的来说,项目分为一下几个部分:

     一、训练目标函数 

    1、    设定基础模型

    2、    添加新层

    3、    冻结 base 层

    4、    编译模型

    5、    训练

    6、    保存模型

    二、特征提取

    三、创建索引

    四、构建服务

    1、flask 开发 

    2、Gunicorn 异步,增加服务稳健性

    3、Supervisor 部署监控服务

    五、总结  

    一、训练目标函数

    项目是在预训练模型 vgg16 的基础上进行微调(fine_tune),并将特征的维度从原先的 2048 维降为 1024 维度。

    模型的微调又分为以下几个步骤:

    1、设定基础模型

    本次采用预训练的 VGG16基础模型,利用其 bottleneck 特征

     # 设定基础模型

    base_model = VGG16(weights='./model/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', include_top=False)

     #指定权重路径

    # include_top= False 不加载三层全连接层

    2、添加新层

    将自己要目标图片,简单分类,统计类别(在训练模型时需要指定类别)

    # 添加新层

     

    def add_new_last_layer(base_model, nb_classes):
    
        '''
        添加最后的层
        :param base_model: 预训练模型
        :param nb_classes: 分类数量
        :return: 新的 model
        '''
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        x = Dense(128, activation='relu')(x) #输出的特征维度 88
        predictions = Dense(nb_classes, activation='softmax')(x)
        model = Model(input=base_model.input, output=predictions)
        return model

    3、冻结 base 层

    以前的参数可以使用预训练好的参数,不需要重新训练,所以需要冻结,不让其改变。

     

    def freeze_base_layer(model, base_model):
    
            for layer in base_model.layers:
    
            layer.trainable = False

     

     4、编译模型

    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics= ['accuracy'])
    
    # optimizer: 优化器
    
    # loss: 损失函数,多类的对数损失需要将分类标签转换为(将标签转化为形如(nb_samples, nb_classes)的二值序列)
    
    # metrics: 列表,包含评估模型在训练和测试时的网络性能的指标准备训练数据。

    5、训练

    #数据准备
    IM_WIDTH, IM_HEIGHT = 224,224
    train_dir = './refine_img_data/train'
    val_dir = './refine_img_data/test'
    nb_classes = 5
    np_epoch = 3
    batch_size = 16
    nb_train_samples = get_nb_files(train_dir)
    nb_classes = len(glob.glob(train_dir + '/*'))
    nb_val_samples = get_nb_files(val_dir)
    
    # 根据现有数据,设置新数据生成参数
    train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
    )
    
    test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
    )
    
    # 从文件夹获取数据
    train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size,
    class_mode='categorical'
    )
    
    validation_generator = test_datagen.flow_from_directory(
    val_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size,
    class_mode='categorical'
    )
    
    # 训练
    history_t1 = model.fit_generator(
    train_generator,
    epochs=1,
    steps_per_epoch=10,
    validation_data=validation_generator,
    validation_steps=10,
    class_weight='auto'
    )

    6、保存模型

    将模型保存到指定路径一般保存为 .h5 格式

     model.save('/model/test_model.h5')

         

    二、特征提取

    加载我们训练好的模型,根据需要,取指定层的特征。

    # 可用 model.summary() 查看模型结构
    
    #根据模型提取图片特征
    
    target_size = (224,224)
    
    def my_feature(mod, path):
        img = image.load_img(path,target_size=target_size)
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input(img)
        return mod.predict(img)
    
     
    
    # 创建模型,获取指定层特征
    model_path = './model/my_model.h5'
    base_model = load_model(model_path)
    model = Model(inputs=base_model.input, outputs=base_model.get_layer('dense_1').output)
    
     
    
    # 提取特征
    img_path = './my_img/bus.jpg'
    feat = my_feature(model,img_path) # shape 为 (1,128)
    print(feat)
    print(feat.shape)
    
    #注意, 当需要提取的图片特征数量较大,比如千万以上,需要的时间是比较长的,这时我们可以采用多核与批处理来进行 (python 由于 GIL 的问题对多线程不友好)。
    def pre_processs_image(path):
        if path is not None and os.path.exists(path) and len(path) > 10:
          try:
              img = cv2.imread(path, cv2.IMREAD_COLOR)
              img = cv2.resize(img, (224, 224))
              img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
              img = img.transpose(2, 0, 1)
              return [material_id,img, flag]
          except Exception as err:
              traceback.print_exc()
              return None
        else:
        logging.error('could not find path: ' + path)
        return None
    
     
    
    #cpu 部分,调用多核处理函数,指定核数为 20
    with ProcessPoolExecutor(max_workers=20) as executor:
    feat_paras = list(executor.map(pre_processs_image,, material_batch))
    
    
    # GPU 部分采用批处理
    # TODO

     

    三、创建索引

    此处我们使用 Facebook 开源的近邻索引框架 faiss 。

     
    
    # create index
    d = 128
    nlist = 100 # 切分数量
    nprobe = 8 # 每次查找分片数量
    quantizer_img = faiss.IndexFlatL2(d) #根据欧式距离创建索引
    
     
    image_index = None
    model_index = None
    
    if image_feat_array is not None and len(img_feat_list) > 100:
      image_index = faiss.IndexIVFFlat(quantizer_img, d, nlist, faiss.METRIC_L2)
      image_index.train(image_feat_array)
      image_index.add_with_ids(image_feat_array,image_id_array)
      image_index.nprobe = nprobe
      image_index.dont_dealloc_me = quantizer_img
    
    # 保存当前索引到指定路径
    faiss.write_index(img_index,path)
    
    # 测试当前索引
    temp_feat = img_feat_list[1]
    res_2 = image_index.search(temp_feat, k=5)
    logging.info('image search result is:' + str(res_2))

     

    四、构建服务

    采用Flask 框架, gunicorn为 wsgi 容器。supervisor 管理进程。

    1、flask 开发

    参考文档 http://docs.jinkan.org/docs/flask/quickstart.html#a-minimal-application

    2、Gunicorn 异步,增加服务稳健性

    基础语法:

    Gunicorn –w process_num –b ip:port –k 'gevent' fileName:app

    # 注意:此处不选择 –k 'gevent' 则为同步运行

    同步部署:

    gunicorn -b 0.0.0.0:9090 my_service:app

    异步部署:

    gunicorn -b 0.0.0.0:9090 -k gevent my_service:app

    用了 Gunicorn 来部署应用后, 对比 flask , qps 提升了一倍。原 flask 框架中由于我的接口中 request 了其他的接口,线程在此处会阻塞,导致程序非常容易假死。改用后,稳定又了极大的提升。

     

    3、Supervisor 部署监控服务

    可参考以下文档 https://www.cnblogs.com/gjack/p/8076419.html

     

    五、总结

    项目到这个地方,基本的服务框架已经有了。许多地方只说了大体思路,但是结构是完整。文中的许多用了许多方法工具,如 gunicorn 的异步等, 但是原理却不甚了解,还需要花功夫去学习。由于上线压力大,时间紧,许多地方来不及仔细琢磨,肯定有不少纰漏,后面再查漏补缺吧。

  • 相关阅读:
    C文件读写函数介绍(转)
    字节存储排序:大端和小端的判别及转换
    vc++上的MFC的对象序列化和反序列化
    unicode下各种类型转换,CString,string,char*,int,char[]
    CString与std::string unicode下相互转化
    VS2010每次编译都重新编译整个工程的解决方案
    Windows下用C语言获取进程cpu使用率,内存使用,IO情况
    hadoop 安装
    python---pyc pyo文件详解
    C 高级编程 2 内存管理
  • 原文地址:https://www.cnblogs.com/yaolin1228/p/9557588.html
Copyright © 2011-2022 走看看