zoukankan      html  css  js  c++  java
  • python代码打包为http服务端接口 (aiohttp)

    目录

    一、需求
    二、服务器端
    三、客户端


    一、需求

    Python端代码接受一个待处理的模型文件路径,对模型进行预测并得到相应结果。由于无法直接将Python代码转换为C++,这里采用aiohttp库将Python代码打包成HTTP服务端接口,支持客户端通过传入文件路径、传入文件的base64编码以及上传文件三种方式进行请求,便于其他语言调用。另外编写了Python客户端代码进行测试,也可以利用Postman工具进行测试。aiohttp是一个异步HTTP库,具体介绍可以参照官网,里面介绍得很详细。

    二、服务器端

    这里只对"application/json"和"multipart/form-data"两种类型进行处理。

    from aiohttp import web
    from inference_class import InferenceClass
    import json
    import asyncio
    import os
    import base64
    
    
    def image_from_base64(base64_utf8):
        """Decode a Base64 text string into the raw bytes it encodes.

        Despite the name, the payload is arbitrary file content; the server
        writes the decoded bytes of the ``obj`` field to a file.
        """
        return base64.decodebytes(bytes(base64_utf8, 'utf-8'))
    
    
    class MeshWebServer(object):
        """Expose ``InferenceClass`` as an aiohttp POST endpoint at ``/mesh/recognize``.

        Accepted request bodies:

        * ``application/json`` with ``file_path`` (a path readable by the
          server), or ``obj`` (base64-encoded file content) plus an optional
          ``filename``.
        * ``multipart/form-data`` whose first part is the uploaded file.

        The response body is a JSON-encoded *string* which itself contains a
        JSON object with ``text``, ``returnCode`` and ``filename`` (plus
        ``returnMsg`` on failure).
        """

        def __init__(self, max_request=1, cache_dir=None):
            """Create the aiohttp application and the inference engine.

            Args:
                max_request: Initial value of ``self._concurrency``.
                    NOTE(review): that semaphore is not used by any handler here.
                cache_dir: Directory for uploaded files; defaults to ``cache``
                    under the current working directory.
            """
            self._app = web.Application()
            self._engine = InferenceClass()
            self._concurrency = asyncio.BoundedSemaphore(max_request)
            # Serializes "reset cache -> write upload -> inference" so one request
            # cannot delete the file another request is about to process.
            self._lock = asyncio.Lock()
            if cache_dir is not None:
                self._cache_dir = cache_dir
            else:
                self._cache_dir = os.path.join(os.getcwd(), 'cache')

        def run(self, port):
            """Register the recognize route and serve on *port* (blocks until shutdown)."""
            self._app.add_routes([
                web.post('/mesh/recognize', self.__on_recognize)
            ])
            web.run_app(self._app, port=port)

        @staticmethod
        def _safe_filename(name, default="temp.obj"):
            """Reduce a client-supplied name to a bare file name.

            Any directory part (separated by ``/`` or a backslash) is dropped so
            the name cannot escape the cache directory. *default* is returned
            when nothing usable remains.
            """
            if not name:
                return default
            base = os.path.basename(str(name).replace("\\", "/"))
            if base in ("", ".", ".."):
                return default
            return base

        def _reset_cache_dir(self):
            """Create the cache directory if needed and delete the regular files in it."""
            os.makedirs(self._cache_dir, exist_ok=True)
            for entry in os.listdir(self._cache_dir):
                path = os.path.join(self._cache_dir, entry)
                if os.path.isfile(path):
                    os.remove(path)

        def _predict(self, file, filename):
            """Run inference on *file* and build the success response dict."""
            predict_class = self._engine.inference(file)
            # NOTE(review): the result layout (label at [0][1][-1]) is defined by
            # InferenceClass, which is not shown here.
            return dict(text=predict_class[0][1][-1], returnCode="Successed!", filename=filename)

        async def _recognize_json(self, request):
            """Handle an ``application/json`` request and return the response dict."""
            filename = ""
            try:
                data = await request.json()
                obj_data = image_from_base64(data['obj']) if 'obj' in data else ""
                filename = self._safe_filename(data['filename'] if 'filename' in data else None)
                file_path = ""
                if "file_path" in data:
                    # NOTE(review): any client can make the server read any path it
                    # can access; restrict this if the port is reachable remotely.
                    file_path = data["file_path"]
                    filename = self._safe_filename(file_path)

                async with self._lock:
                    if os.path.isfile(file_path):
                        file = file_path
                    elif obj_data != "":
                        self._reset_cache_dir()
                        file = os.path.join(self._cache_dir, filename)
                        with open(file, "wb") as f:
                            f.write(obj_data)
                    else:
                        file = ""
                    return self._predict(file, filename)
            except Exception as e:
                return dict(text='', returnCode="Failed", filename=filename, returnMsg=repr(e))

        async def _recognize_multipart(self, request):
            """Handle a ``multipart/form-data`` upload and return the response dict."""
            filename = ""
            try:
                print("headers: ", request.headers)
                reader = await request.multipart()
                field = await reader.next()
                if field is None:
                    raise ValueError("multipart body contains no parts")
                filename = self._safe_filename(field.filename)

                # Held across the streamed upload too: reading chunks awaits, and
                # another request must not clear the cache in the meantime.
                async with self._lock:
                    self._reset_cache_dir()
                    file = os.path.join(self._cache_dir, filename)
                    with open(file, 'wb') as f:
                        while True:
                            chunk = await field.read_chunk()  # 8192 bytes by default
                            if not chunk:
                                break
                            f.write(chunk)
                    return self._predict(file, filename)
            except Exception as e:
                print(e)
                return dict(text='', returnCode="Failed", filename=filename, returnMsg=repr(e))

        async def __on_recognize(self, request):
            """Dispatch POST /mesh/recognize on Content-Type and return the JSON response."""
            content_type = request.content_type
            if content_type == "application/json":
                respond = await self._recognize_json(request)
            elif content_type == "multipart/form-data":
                respond = await self._recognize_multipart(request)
            else:
                respond = dict(text="Unknown content type, just support application/json and multipart/form-data",
                               returnCode="Failed!", filename="")
            print("---** predict is {} **---".format(respond["returnCode"]))
            # json_response() JSON-encodes its argument again, so the body is a JSON
            # string that contains the object. This is kept for existing clients.
            return web.json_response(json.dumps(respond))
    

    三、客户端

    3.1 "application/json"

    可以通过json方式传递文件路径(服务端与客户端在同一台机器时)或者文件内容的base64编码

    import aiohttp
    import asyncio
    import base64
    import os
    import json
    import time
    
    
    def base64_from_filename(filename):
        """Return the standard Base64 encoding of the file's bytes as a ``str``."""
        with open(filename, "rb") as handle:
            raw = handle.read()
        return base64.b64encode(raw).decode('utf-8')
    
    
    async def do_recognize(web_ip, web_port, file_path, obj_data=None):
        """POST one recognition request as JSON and print the server's answer.

        When the server is local (``127.0.0.1``), only the path is sent. Otherwise
        *obj_data* (the base64 file content) is sent. Errors are printed rather
        than raised, so a batch of files keeps going.
        """
        try:
            codename = os.path.basename(file_path)
            if web_ip == "127.0.0.1":
                request = dict(file_path=file_path, filename=codename)
            else:
                request = dict(obj=obj_data, filename=codename)
            print('filename: {}'.format(file_path))
            async with aiohttp.ClientSession() as session:
                async with session.post(url='http://{}:{}/mesh/recognize'.format(web_ip, web_port),
                                        data=json.dumps(request),
                                        headers={'Content-Type': 'application/json; charset=utf-8'}) as resp:
                    result = json.loads(await resp.text())
                    # The server double-encodes: the body is a JSON string that holds
                    # the JSON object. Decode a second time when that is the case.
                    if isinstance(result, str):
                        result = json.loads(result)
                    data = result.get("text")
                    status = result.get("returnCode")
                    print("predict {}, res: {}".format(status, data))
        except Exception as e:
            print("do_recognize Error {}: {!r}".format(file_path, e))
    
    
    def run_test(web_ip, web_port, filename, loop):
        """Run one recognition request for *filename* to completion on *loop*.

        For a local server (``127.0.0.1``) only the path is sent. Otherwise the
        file is base64-encoded first. Failures are printed, not raised, so a
        batch run continues. Ctrl-C still interrupts, because only ``Exception``
        is caught.
        """
        try:
            if web_ip == "127.0.0.1":
                loop.run_until_complete(do_recognize(web_ip, web_port, filename))
            else:
                b64_str = base64_from_filename(filename)
                loop.run_until_complete(do_recognize(web_ip, web_port, filename, b64_str))
        except Exception as e:
            print("run_test Error", repr(e))
    
    
    if __name__ == '__main__':
        # Send every regular file in file_dir to the server, one request at a time.
        file_dir = "E:/code/test_models/"
        filenames = os.listdir(file_dir)
        files = [os.path.join(file_dir, name) for name in filenames]
        files = [path for path in files if os.path.isfile(path)]  # skip sub-directories
        ip = "192.168.107.118"  # "127.0.0.1" 192.168.107.118
        port = 8000
        start = time.time()
        # new_event_loop(): get_event_loop() without a running loop is deprecated.
        event_loop = asyncio.new_event_loop()
        try:
            for filename in files:
                run_test(ip, port, filename, event_loop)
        except Exception as e:
            print("__main__ Error", repr(e))
        finally:
            event_loop.close()
        end = time.time()
        print("run time is : {}s".format(end - start))
    
    
    
    3.2 "multipart/form-data"

    直接上传文件

    import aiohttp
    from aiohttp import formdata
    import asyncio
    import os
    import time
    import json
    
    
    async def do_recognize(web_ip, web_port, file_path):
        """Upload *file_path* as multipart/form-data and print the server's answer.

        Errors are printed rather than raised, so a batch of files keeps going.
        """
        try:
            codename = os.path.basename(file_path)
            print('filename: {}'.format(file_path))
            # The file object must stay open while aiohttp streams it. Closing it
            # with "with" afterwards fixes the original file-handle leak.
            with open(file_path, 'rb') as obj_file:
                file_data = formdata.FormData()
                file_data.add_field('file', obj_file, filename=codename)
                # aiohttp generates the multipart Content-Type (with boundary) itself.
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                            url='http://{}:{}/mesh/recognize'.format(web_ip, web_port),
                            data=file_data,
                    ) as resp:
                        result = json.loads(await resp.text())
                        # The server double-encodes: the body is a JSON string that
                        # holds the JSON object. Decode again when that is the case.
                        if isinstance(result, str):
                            result = json.loads(result)
                        data = result.get("text")
                        status = result.get("returnCode")
                        print("predict {}, res: {}".format(status, data))
        except Exception as e:
            print("do_recognize Error {}: {!r}".format(file_path, e))
    
    
    def run_test(web_ip, web_port, filename, loop):
        """Upload *filename* and wait for the answer on *loop*.

        Failures are printed, not raised, so a batch run continues. Ctrl-C still
        interrupts, because only ``Exception`` is caught.
        """
        try:
            loop.run_until_complete(do_recognize(web_ip, web_port, filename))
        except Exception as e:
            print("run_test Error", repr(e))
    
    
    if __name__ == '__main__':
        # Upload every regular file in file_dir to the server, one request at a time.
        file_dir = "E:/code/test_models/"
        filenames = os.listdir(file_dir)
        files = [os.path.join(file_dir, name) for name in filenames]
        files = [path for path in files if os.path.isfile(path)]  # skip sub-directories
        ip = "192.168.107.118"  # "127.0.0.1"
        port = 8000
        start = time.time()
        # new_event_loop(): get_event_loop() without a running loop is deprecated.
        event_loop = asyncio.new_event_loop()
        try:
            for filename in files:
                run_test(ip, port, filename, event_loop)
        except Exception as e:
            print("__main__ Error", repr(e))
        finally:
            event_loop.close()
        end = time.time()
        print("run time is : {}s".format(end - start))
    
    
    

    参考链接

    https://docs.aiohttp.org/en/stable/
    https://blog.csdn.net/weixin_39643613/article/details/109171090
    https://docs.aiohttp.org/en/stable/web.html

  • 相关阅读:
    归并排序
    希尔排序
    字符串操作
    引用
    直接插入排序
    变量赋值
    C#中关于公共类的使用
    关于SQL中Between语句查询日期的问题
    用户控件 与 重写控件 的区别
    什么是命名空间,为什么要使用命名空间?
  • 原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/14522823.html
Copyright © 2011-2022 走看看