  • Keras model quantization

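    At its core, model quantization converts a model's parameters, according to some rule, from 32-bit or 64-bit floating point to 16-bit floating point or 8-bit fixed-point (integer) values. Here I implemented 16-bit and 8-bit quantization of a Keras weights file with h5py and NumPy. I did not consider whether the quantization scheme is scientifically sound. This is only an experiment to understand what quantization fundamentally is.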
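    Before touching any H5 file, the two conversions can be tried on a small in-memory array. The 16-bit case is a plain cast. The 8-bit case is a min-max mapping, q = round((w - min) / (max - min) * 255), whose inverse w ≈ min + q * (max - min) / 255 is off by at most half a step. This is a minimal sketch, assuming the same min-max rule the script below uses. The helper names `to_float16`, `to_uint8` and `from_uint8` are hypothetical and do not come from the original code.

```python
import numpy as np


def to_float16(w):
    # 16-bit: a plain cast, each value is rounded to the nearest float16
    return w.astype(np.float16)


def to_uint8(w):
    # 8-bit: map [min, max] linearly onto the integers 0..255
    w_min, w_max = float(w.min()), float(w.max())
    # Guard against a constant array, where max - min would be zero
    scale = (w_max - w_min) / 255 if w_max > w_min else 1.0
    q = np.round((w - w_min) / scale).astype(np.uint8)
    return q, w_min, scale


def from_uint8(q, w_min, scale):
    # Inverse mapping back to floating point; only possible if min and scale were kept
    return q.astype(np.float32) * scale + w_min


w = np.array([-1.0, -0.25, 0.0, 0.5, 1.0], dtype=np.float32)
q, w_min, scale = to_uint8(w)
print(to_float16(w))
print(q)
print(from_uint8(q, w_min, scale))
```

    The uint8 values by themselves say nothing about the original range. The minimum and the step size have to be stored next to them, or the weights cannot be recovered.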

    Quantization
    
    
    # coding: utf-8
    """
    project:  TF2learning
    file:     quantization
    author:   qilibin
    created:  2021/3/17 9:18 (PyCharm)

    Read an H5 file that contains only the weights of a Keras model, walk it
    layer by layer, quantize every weight array to 16 or 8 bits, and save the
    quantized values to a new H5 file.
    """
    import h5py
    import numpy as np


    def quantization16bit(old_model_path, new_model_path, bit_num):
        '''
        :param old_model_path: path of the unquantized model (weights only, no network structure)
        :param new_model_path: path where the quantized model is written
        :param bit_num: number of quantization bits, 16 or 8
        :return: None
        '''
        if bit_num not in (16, 8):
            raise ValueError("bit_num must be 16 or 8")

        with h5py.File(old_model_path, 'r') as f, h5py.File(new_model_path, 'w') as f2:
            # Copy file-level attributes (e.g. layer_names) so the layout stays recognizable
            for key, value in f.attrs.items():
                f2.attrs[key] = value

            for layer in f.keys():
                # layer: name of the layer
                print(layer)
                g1 = f2.create_group(layer)
                # Keep the original per-layer attributes (e.g. weight_names)
                for key, value in f[layer].attrs.items():
                    g1.attrs[key] = value

                # Some layers have no parameters, e.g. ReLU
                if len(f[layer].keys()) == 0:
                    continue

                # Assumes the weights sit in a subgroup named after the layer: f[layer][layer][weight]
                g2 = g1.create_group(layer)
                for weight in f[layer][layer].keys():
                    print("weight name is: " + weight)
                    # [()] reads scalar datasets as well as arrays
                    oldparam = f[layer][layer][weight][()]
                    print('-------------------- old --------------------')
                    print(oldparam)

                    if bit_num == 16:
                        newparam = np.float16(oldparam)
                        g2.create_dataset(weight, data=newparam, dtype=np.float16)
                    else:
                        min_val = float(np.min(oldparam))
                        max_val = float(np.max(oldparam))
                        # Guard against a constant array (max == min), which would divide by zero
                        scale = (max_val - min_val) / 255 if max_val > min_val else 1.0
                        newparam = np.uint8(np.round((oldparam - min_val) / scale))
                        d = g2.create_dataset(weight, data=newparam, dtype=np.uint8)
                        # Store min and step size; without them the uint8 values cannot be mapped back
                        d.attrs['min'] = min_val
                        d.attrs['scale'] = scale
                    print('-------------------- new --------------------')
                    # print(newparam)
                    # Writing f[layer][layer][weight][:] = newparam into the original file does not work:
                    # the file is opened read-only, and the existing float32 dataset would keep its dtype anyway.


    old_model_path = './model_0_.h5'
    new_model_path = './new_model.h5'
    quantization16bit(old_model_path, new_model_path, 8)
    # print(f['batch_normalization']['batch_normalization']['gamma:0'][:])
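    To map 8-bit weights back to floating point, the minimum and the step size of each array must be kept next to the uint8 data. The following self-contained sketch stores them as HDF5 attributes named `min` and `scale` and round-trips one weight array through a temporary file. Those attribute names and the helpers `save_uint8` and `load_dequantized` were chosen here for illustration and are not a Keras convention. The `dense/dense/kernel:0` layout only mimics the layer/layer/weight nesting the script above assumes.

```python
import os
import tempfile

import h5py
import numpy as np


def save_uint8(group, name, w):
    """Quantize w to uint8 with min-max mapping; keep min and step size as attributes."""
    w_min, w_max = float(np.min(w)), float(np.max(w))
    scale = (w_max - w_min) / 255 if w_max > w_min else 1.0
    q = np.round((w - w_min) / scale).astype(np.uint8)
    d = group.create_dataset(name, data=q, dtype=np.uint8)
    d.attrs['min'] = w_min
    d.attrs['scale'] = scale
    return d


def load_dequantized(dataset):
    """Read a uint8 dataset and map it back to floating point using its attributes."""
    q = dataset[()]
    return q.astype(np.float32) * dataset.attrs['scale'] + dataset.attrs['min']


# Round-trip one hypothetical kernel through a temporary H5 file
kernel = np.random.default_rng(0).normal(size=(4, 3)).astype(np.float32)
path = os.path.join(tempfile.mkdtemp(), 'demo.h5')
with h5py.File(path, 'w') as f:
    save_uint8(f.create_group('dense').create_group('dense'), 'kernel:0', kernel)
with h5py.File(path, 'r') as f:
    restored = load_dequantized(f['dense']['dense']['kernel:0'])
print(np.max(np.abs(restored - kernel)))
```

    Keras itself will not apply this inverse mapping when loading weights. The dequantization step has to run before the values are handed back to a model.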

    Checking the quantized file

    # coding: utf-8
    """
    project:  TF2learning
    file:     readNewMoDel
    author:   qilibin
    created:  2021/3/17 13:27 (PyCharm)

    Print the quantized model so that the values of each weight can be inspected.
    """
    import h5py

    modelpath = './new_model.h5'
    # modelpath = './model_0_.h5'

    with h5py.File(modelpath, 'r') as f:
        for layer in f.keys():
            # layer: name of the layer
            print("layer name is: " + layer)

            # Some layers have no parameters, e.g. ReLU
            if len(f[layer].keys()) > 0:
                for weight in f[layer][layer].keys():
                    print("weight name is: " + weight)
                    param = f[layer][layer][weight][()]
                    print(param)

        # print(f['batch_normalization']['batch_normalization']['gamma:0'][:])
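    The script above only prints raw values. A sketch like the following could compare the original and the quantized file dataset by dataset, to confirm that the stored data actually shrank. For float16 it also reports the largest rounding error. It assumes both files contain the same internal dataset paths. The helper names `collect_datasets` and `compare` are hypothetical. The demo builds two tiny temporary files instead of reading `./model_0_.h5` and `./new_model.h5`.

```python
import os
import tempfile

import h5py
import numpy as np


def collect_datasets(path):
    """Return {dataset path: (dtype, bytes of raw data, values)} for every dataset in an H5 file."""
    found = {}

    def visit(name, obj):
        # visititems also reports groups; only datasets are collected
        if isinstance(obj, h5py.Dataset):
            found[name] = (obj.dtype, obj.size * obj.dtype.itemsize, obj[()])

    with h5py.File(path, 'r') as f:
        f.visititems(visit)
    return found


def compare(old_path, new_path):
    old, new = collect_datasets(old_path), collect_datasets(new_path)
    for name in sorted(old):
        o_dtype, o_bytes, o_val = old[name]
        n_dtype, n_bytes, n_val = new[name]
        line = f"{name}: {o_dtype} {o_bytes} B -> {n_dtype} {n_bytes} B"
        if n_dtype == np.float16:
            # float16 values are still real numbers, so the error can be measured directly
            err = np.max(np.abs(n_val.astype(np.float32) - o_val))
            line += f", max abs error {err:.2e}"
        print(line)
    return old, new


# Demo with two tiny hypothetical files sharing the layer/layer/weight layout
tmp = tempfile.mkdtemp()
old_path, new_path = os.path.join(tmp, 'old.h5'), os.path.join(tmp, 'new.h5')
w = np.linspace(-1, 1, 12, dtype=np.float32).reshape(4, 3)
with h5py.File(old_path, 'w') as f:
    f.create_group('dense').create_group('dense').create_dataset('kernel:0', data=w)
with h5py.File(new_path, 'w') as f:
    f.create_group('dense').create_group('dense').create_dataset('kernel:0', data=w.astype(np.float16))
old, new = compare(old_path, new_path)
```

    For uint8 datasets the error is not computed, because the raw integers only become comparable after the min-max mapping is inverted.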
  • Original article: https://www.cnblogs.com/cnugis/p/14550041.html