zoukankan      html  css  js  c++  java
  • 在Cuda上部署量化模型

    在Cuda上部署量化模型

    介绍TVM自动量化。自动量化是TVM中的一种量化方式。将在ImageNet上导入一个GluonCV预先训练的模型到Relay,量化Relay模型,然后执行推理。

    import tvm

    from tvm import te

    from tvm import relay

    import mxnet as mx

    from tvm.contrib.download import download_testdata

    from mxnet import gluon

    import logging

    import os

     

    batch_size = 1

    model_name = "resnet18_v1"

    target = "cuda"

    dev = tvm.device(target)

     

    准备数据集

     

    将演示如何准备用于量化的校准数据集。首先下载ImageNet的验证集,对数据集进行预处理。

    calibration_rec = download_testdata(

        "http://data.mxnet.io.s3-website-us-west-1.amazonaws.com/data/val_256_q90.rec",

        "val_256_q90.rec",

    )

     

    def get_val_data(num_workers=4):

        mean_rgb = [123.68, 116.779, 103.939]

        std_rgb = [58.393, 57.12, 57.375]

     

        def batch_fn(batch):

            return batch.data[0].asnumpy(), batch.label[0].asnumpy()

     

        img_size = 299 if model_name == "inceptionv3" else 224

        val_data = mx.io.ImageRecordIter(

            path_imgrec=calibration_rec,

            preprocess_threads=num_workers,

            shuffle=False,

            batch_size=batch_size,

            resize=256,

            data_shape=(3, img_size, img_size),

            mean_r=mean_rgb[0],

            mean_g=mean_rgb[1],

            mean_b=mean_rgb[2],

            std_r=std_rgb[0],

            std_g=std_rgb[1],

            std_b=std_rgb[2],

        )

        return val_data, batch_fn

    校准数据集应该是一个iterable对象。在Python中将校准数据集定义为生成器对象。仅使用少量样本进行校准。

    calibration_samples = 10

     

    def calibrate_dataset():

        val_data, batch_fn = get_val_data()

        val_data.reset()

        for i, batch in enumerate(val_data):

            if i * batch_size >= calibration_samples:

                break

            data, _ = batch_fn(batch)

            yield {"data": data}

    导入模型

     

    使用Relay MxNet前端,从Gluon模型zoo,导入模型。

    def get_model():

        gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)

        img_size = 299 if model_name == "inceptionv3" else 224

        data_shape = (batch_size, 3, img_size, img_size)

        mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})

        return mod, params

    量化模型

    在量化中,需要找到每个层的权重和中间特征映射张量的比例。

    对于权重,将根据权重值直接计算比例。支持两种模式:power2和max。这两种模式都首先在权重张量内找到最大值。在power2模式下,最大值向下舍入为2的幂。如果权重和中间特征映射的比例都是二的幂,可以利用移位进行乘法,计算效率更高。在最大模式下,最大值用作scale。如果没有舍入,“最大”模式可能具有更好的精度。当scale不是二的幂时,将使用定点乘法。

    对于中间特征映射,可以通过数据感知量化,找到scale。数据感知量化,将校准数据集作为输入参数。通过最小化激活前与量化后分布间的KL散度,计算标度。可以使用预定义的全局scale。可以节省校准时间。但准确度可能会受到影响。

    def quantize(mod, params, data_aware):

        if data_aware:

            with relay.quantize.qconfig(calibrate_mode="kl_divergence", weight_scale="max"):

                mod = relay.quantize.quantize(mod, params, dataset=calibrate_dataset())

        else:

            with relay.quantize.qconfig(calibrate_mode="global_scale", global_scale=8.0):

                mod = relay.quantize.quantize(mod, params)

        return mod

    运行推理

    创建一个Relay VM,构建和执行模型。

    def run_inference(mod):

        model = relay.create_executor("vm", mod, dev, target).evaluate()

        val_data, batch_fn = get_val_data()

        for i, batch in enumerate(val_data):

            data, label = batch_fn(batch)

            prediction = model(data)

            if i > 10:  # only run inference on a few samples in this tutorial

                break

     

    def main():

        mod, params = get_model()

        mod = quantize(mod, params, data_aware=True)

        run_inference(mod)

     

    if __name__ == "__main__":

        main()

    输出:

    /workspace/python/tvm/relay/build_module.py:333: DeprecationWarning: Please use input parameter mod (tvm.IRModule) instead of deprecated parameter mod (tvm.relay.function.Function)

      DeprecationWarning,

     

     

    参考链接:

    http://tvm.apache.org/docs/how_to/deploy_models/deploy_quantized.html

     

    人工智能芯片与自动驾驶
  • 相关阅读:
    mybatis 插件的原理-责任链和动态代理的体现
    优雅的对象转换解决方案-MapStruct使用进阶(二)
    将博客搬至CSDN
    python headers missing
    Gvim:unable to load python
    gvim keil 快捷跳转至出现错误(警告)行
    stm32 堆溢出
    keil在线烧录突然提示 No target connected #
    cygwin vim can't write .viminfo
    切换用户后,/etc/profile的配置不起效
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/15497050.html
Copyright © 2011-2022 走看看