zoukankan      html  css  js  c++  java
  • 用Tensorflow和FastAPI构建图像分类API

    作者|Aniket Maurya
    编译|VK
    来源|Towards Datas Science

    这个博客的源代码可以从https://github.com/aniketmaurya/tensorflow-web-app-starter-pack获得

    让我们从一个简单的helloworld示例开始

    首先,我们导入FastAPI类并创建一个对象应用程序。这个类有一些有用的参数,比如我们可以传递swaggerui的标题和描述。

    from fastapi import FastAPI
    app = FastAPI(title='Hello world')
    

    我们定义一个函数并用@app.get. 这意味着我们的API/index支持GET方法。这里定义的函数是异步的,FastAPI通过为普通的def函数创建线程池来自动处理异步和不使用异步方法,并且它为异步函数使用异步事件循环。

    @app.get('/index')
    async def hello_world():
        return "hello world"
    

    图像识别API

    我们将创建一个API来对图像进行分类,我们将其命名为predict/image。我们将使用Tensorflow来创建图像分类模型。

    Tensorflow图像分类教程:https://aniketmaurya.ml/blog/tensorflow/deep learning/2019/05/12/image-classification-with-tf2.html

    我们创建了一个函数load_model,它将返回一个带有预训练权重的MobileNet CNN模型,即它已经被训练为对1000个不同类别的图像进行分类。

    import tensorflow as tf
    
    def load_model():
        model = tf.keras.applications.MobileNetV2(weights="imagenet")
        print("Model loaded")
        return model
        
    model = load_model()
    

    我们定义了一个predict函数,它将接受图像并返回预测。我们将图像大小调整为224x224,并将像素值规格化为[-1,1]。

    from tensorflow.keras.applications.imagenet_utils 
    import decode_predictions
    

    decode_predictions用于解码预测对象的类名。这里我们将返回前2个可能的类。

    def predict(image: Image.Image):
    
        image = np.asarray(image.resize((224, 224)))[..., :3]
        image = np.expand_dims(image, 0)
        image = image / 127.5 - 1.0
        
        result = decode_predictions(model.predict(image), 2)[0]
        
        response = []
        
        for i, res in enumerate(result):
            resp = {}
            resp["class"] = res[1]
            resp["confidence"] = f"{res[2]*100:0.2f} %"
            
            response.append(resp)
            
        return response
    

    现在我们将创建一个支持文件上传的API/predict/image。我们将过滤文件扩展名以仅支持jpg、jpeg和png格式的图像。

    我们将使用Pillow加载上传的图像。

    def read_imagefile(file) -> Image.Image:
        image = Image.open(BytesIO(file))
        return image
        
    @app.post("/predict/image")
    async def predict_api(file: UploadFile = File(...)):
        extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
        if not extension:
            return "Image must be jpg or png format!"
        image = read_imagefile(await file.read())
        prediction = predict(image)
        
        return prediction
    

    最终代码

    import uvicorn
    from fastapi import FastAPI, File, UploadFile
    
    from application.components import predict, read_imagefile
    
    app = FastAPI()
    
    @app.post("/predict/image")
    async def predict_api(file: UploadFile = File(...)):
        extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png")
        if not extension:
            return "Image must be jpg or png format!"
        image = read_imagefile(await file.read())
        prediction = predict(image)
        
        return prediction
        
    @app.post("/api/covid-symptom-check")
    def check_risk(symptom: Symptom):
        return symptom_check.get_risk_level(symptom)
        
    if __name__ == "__main__":
        uvicorn.run(app, debug=True)
    

    FastAPI文档是了解框架核心概念的最佳场所:https://fastapi.tiangolo.com/

    希望你喜欢这篇文章。

    原文链接:https://towardsdatascience.com/image-classification-api-with-tensorflow-and-fastapi-fc85dc6d39e8

    欢迎关注磐创AI博客站:
    http://panchuang.net/

    sklearn机器学习中文官方文档:
    http://sklearn123.com/

    欢迎关注磐创博客资源汇总站:
    http://docs.panchuang.net/

  • 相关阅读:
    setInterval和setTimeOut方法—— 定时刷新
    json
    开发者必备的火狐插件
    C#泛型类和集合类的方法
    jQuery几种常用方法
    SQL语句优化技术分析
    索引的优点和缺点
    Repeater使用技巧
    jQuery 表格插件
    利用WebRequest来实现模拟浏览器通过Post方式向服务器提交数据
  • 原文地址:https://www.cnblogs.com/panchuangai/p/13849852.html
Copyright © 2011-2022 走看看