zoukankan      html  css  js  c++  java
  • keras使用多进程

    最近在工作中有一个需求:用训练好的模型将数据库中所有数据得出预测结果,并保存到另一张表上。数据库中的数据是一篇篇文章,我训练好的模型是对其中的四个段落分别分类,即我有四个模型,拿到文本后需要提取出这四个段落,并用对应模型分别预测这四个段落的类别,然后存入数据库中。我是用keras训练的模型,backend为tensorflow,因为数据量比较大,自然想到用多进程。在Windows上运行一点问题没有,但是在Linux服务器上运行时发现每次都停在model.predict上不动了。

    模型使用时大致如下:

    # -*- coding: utf-8 -*-
    import jieba
    import numpy as np
    import keras
    import tensorflow as tf
    from keras.preprocessing import sequence
    from keras.models import load_model
    from config import Config
    import json
    
    
    config_file = 'data/config.ini'
    model_path = Config(config_file).get_value_str('cnn', 'model_path')
    graph = tf.Graph()
    with graph.as_default():
        session = tf.Session()
        with session.as_default():
            model = load_model(model_path)
    
    graph_var = graph
    session_var = session
    
    
    def sentence_process(sentence):
        with open('data/words.json', encoding='utf-8') as f:
            words_json = json.load(f)
        words = words_json['words']
        word_to_id = words_json['word_to_id']
        max_length = words_json['max_length']
        segs = jieba.lcut(sentence)
        segs = filter(lambda x: len(x) >= 1, segs)
        segs = [x for x in segs if x]
        vector = []
        for seg in segs:
            if seg in words:
                vector.append(word_to_id[seg])
            else:
                vector.append(4999)
        return vector, max_length
    
    
    def predict(sentence):
        vector, max_length = sentence_process(sentence)
        vector_np = np.array([vector])
        x_vector = sequence.pad_sequences(vector_np, max_length)
        with graph_var.as_default():
            with session_var.as_default():
                y = model.predict_proba(x_vector)
                if y[0][1] > 0.5:
                    predict = 1
                else:
                    predict = 0
        return predict
    View Code

    多进程使用大致如下: 

    from multiprocessing import Pool
    from classifaction.classify1 import predict1
    from classifaction.classify2 import predict2
    from classifaction.classify3 import predict3
    from classifaction.classify4 import predict4
    
    
    def main():
        '''
        get texts
        '''
        pool = Pool(processes=4, maxtasksperchild=1)
        pool.map(save_to_database, texts)
        pool.close()
        pool.join()
    
    
    def save_to_database(texts):
        text1, text2, text3, text4 = texts[0], texts[1], texts[2], texts[3]
        label1 = predict1(text1)
        label2 = predict2(text2)
        label3 = predict3(text3)
        label4 = predict4(text4)
    
    
    if __name__ == '__main__':
        main()
    View Code

    问题 1

    在Linux服务器上运行时发现所有进程都停在model.predict上不动了。而在Windows下运行良好

    方法

    Google后发现很多遇到这个问题,也终于找到一个方法。可以看一下链接:

    https://github.com/keras-team/keras/issues/9964

    有一个方法是

    As of TF 1.10, the library seems to be somewhat forkable. So you will have to test what you can do.
    
    Also, something you can try is:
    multiprocessing.set_start_method('spawn', force=True) if you're on UNIX and using Python3.

    即在使用multiprocessing之前先设置一下。

    python多进程内存复制

    python对于多进程中使用的是copy on write机制,python 使用multiprocessing来创建多进程时,无论数据是否不会被更改,子进程都会复制父进程的状态(内存空间数据等)。所以如果主进程耗的资源较多时,不小心就会造成不必要的大量的内存复制,从而可能导致内存爆满的情况。

    进程的启动有spawn、fork、forkserver三种方式

    spawn:调用该方法,父进程会启动一个新的python进程,子进程只会继承运行进程对象run()方法所需的那些资源。特别地,子进程不会继承父进程中不必要的文件描述符和句柄。与使用forkforkserver相比,使用此方法启动进程相当慢。

               Available on Unix and Windows. The default on Windows.

    fork:父进程使用os.fork()来fork Python解释器。子进程在开始时实际上与父进程相同,父进程的所有资源都由子进程继承。请注意,安全创建多线程进程尚存在一定的问题。

              Available on Unix only. The default on Unix.

    forkserver:当程序启动并选择forkserverstart方法时,将启动服务器进程。从那时起,每当需要一个新进程时,父进程就会连接到服务器并请求它fork一个新进程。 fork服务器进程是单线程的,因此使用os.fork()是安全的。没有不必要的资源被继承。

             Available on Unix platforms which support passing file descriptors over Unix pipes.

    要选择以上某一种start方法,请在主模块中使用multiprocessing.set_start_method()。并且multiprocessing.set_start_method()在一个程序中仅仅能使用一次。

     由上可见,Windows默认使用spawn方法,和Unix类系统如Linux和Mac默认使用的是fork方法,这就解析了为什么在Windows上可以运行,而在Linux上不能运行的原因。

    在Linux服务器上运行时更改代码如下:

    import multiprocessing
    from multiprocessing import Pool
    from classifaction.classify1 import predict1
    from classifaction.classify2 import predict2
    from classifaction.classify3 import predict3
    from classifaction.classify4 import predict4
    
    
    def main():
        '''
        get texts
        '''
        pool = Pool(processes=4, maxtasksperchild=1)
        multiprocessing.set_start_method('spawn', force=True)
        pool.map(save_to_database, texts)
        pool.close()
        pool.join()
    
    
    def save_to_database(texts):
        text1, text2, text3, text4 = texts[0], texts[1], texts[2], texts[3]
        label1 = predict1(text1)
        label2 = predict2(text2)
        label3 = predict3(text3)
        label4 = predict4(text4)
    
    
    if __name__ == '__main__':
        main()
    View Code

    这样就可以在Unix系统使用多进程了

    问题 2

    如果电脑上配置好了GPU,以tensorflow为backend,调用tensorflow时,默认会启动GPU,这样就没法用多进程了。

    方法

    指定用CPU启动

    只需在程序首部添加以下代码即可

    import os
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
  • 相关阅读:
    谈谈系统
    快速发展的Swift是否将淘汰Objective-C?
    XCode环境变量及路径设置
    Windows server2008 搭建ASP接口访问连接oracle数据库全过程记录--备用
    Swift2.0新特性--文章过时重置
    【XCode7+iOS9】http网路连接请求、MKPinAnnotationView自定义图片和BitCode相关错误--备用
    移动App双周版本迭代策略
    ti8168平台的tiler memory
    图像处理之二维码生成-qr
    大数据之网络爬虫-一个简单的多线程爬虫
  • 原文地址:https://www.cnblogs.com/zongfa/p/12193561.html
Copyright © 2011-2022 走看看