zoukankan      html  css  js  c++  java
  • Keras 笔记

    1. 从 meta 模型恢复graph,   修改node  并保存

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    # create a session
    sess = tf.Session()
    
    src = sys.argv[1]
    dst = sys.argv[2]
    
    # import best model
    saver = tf.train.import_meta_graph('model.ckpt.meta') # graph
    saver.restore(sess, 'model.ckpt') # variables
    
    # get graph definition
    gd = sess.graph.as_graph_def()
    
    # fix batch norm nodes
    for node in gd.node:
      if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index in xrange(len(node.input)):
          if 'moving_' in node.input[index]:
            node.input[index] = node.input[index] + '/read'
      elif node.op == 'AssignSub':
        node.op = 'Sub'
        if 'use_locking' in node.attr: del node.attr['use_locking']
    
    # generate protobuf
    converted_graph_def = graph_util.convert_variables_to_constants(sess, gd, ["logits_set"])
    tf.train.write_graph(converted_graph_def, '/path/to/save/', 'model.pb', as_text=False)

    2. keras  model   转  graph_def

    def loadModel(path_name):
        graph = tf.get_default_graph()
        graph_def = graph.as_graph_def()
        graph_def.ParseFromString(tf.gfile.FastGFile(path_name, 'rb').read())
        tf.import_graph_def(graph_def, name='graph')
        return graph_def

    3.  从 pb模型恢复graph_def   并保存encoder

    import tensorflow as tf
    import sys
    
    name = sys.argv[1]
    path = sys.argv[2]
    
    model = name
    graph = tf.get_default_graph()
    graph_def = graph.as_graph_def()
    graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
    tf.import_graph_def(graph_def, name='graph')
    summaryWriter = tf.summary.FileWriter(path, graph)

    4. keras   outnodes 

        sess = K.get_session()
        from tensorflow.python.framework import graph_util,graph_io
        init_graph = sess.graph.as_graph_def()
        main_graph = graph_util.convert_variables_to_constants(sess,init_graph,out_nodes)
          graph_io.write_graph(main_graph,output_dir,name = model_name,as_text = False)

    5. transform 用法

    Transforms are:
    add_default_attributes
    backport_concatv2
    backport_tensor_array_v3
    flatten_atrous_conv
    fold_batch_norms
    fold_constants
    fold_old_batch_norms
    freeze_requantization_ranges
    fuse_pad_and_conv
    fuse_remote_graph
    fuse_resize_and_conv
    fuse_resize_pad_and_conv
    insert_logging
    merge_duplicate_nodes
    obfuscate_names
    place_remote_graph_arguments
    quantize_nodes
    quantize_weights
    remove_attribute
    remove_control_dependencies
    remove_device
    remove_nodes
    rename_attribute
    rename_op
    rewrite_quantized_stripped_model_for_hexagon
    round_weights
    set_device
    sort_by_execution_order
    sparsify_gather
    strip_unused_nodes

    1. remove_node : 该参数表示删除节点,后面的参数表示删除的节点类型,注意该操作有可能删除一些必须节点

    2. fold_constans: 查找模型中始终为常量的表达式,并用常量替换他们。

    3.fold_batch_norms: 训练过程中使用批量标准化时可以优化在Conv2D或者MatMul之后引入的Mul。需要在fold_cnstans之后使用。(fold_old_batch_norms和他的功能一样,主要是为了兼容老版本)

    4. quantize_weights:将float型数据改为8位计算方式(默认对小于1024的张量不会使用),该方法是压缩模型的主要手段。

    5. strip_unused_nodes:除去输入和输出之间不使用的节点,对于解决移动端内核溢出存在很大的作用。

    6. merge_duplicate_nodes: 合并一些重复的节点

    7: sort_by_execution_order: 对节点进行排序,保证给定点的节点输入始终在该节点之前。

    6. 数据扩充 ImageDataGenerator
     
    test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    )
     

    featurewise_center:布尔值,使输入数据集去中心化(均值为0), 按feature执行。
    samplewise_center:布尔值,使输入数据的每个样本均值为0。
    featurewise_std_normalization:布尔值,将输入除以数据集的标准差以完成标准化, 按feature执行。
    samplewise_std_normalization:布尔值,将输入的每个样本除以其自身的标准差。
    zca_whitening:布尔值,对输入数据施加ZCA白化。
    rotation_range:整数,数据提升时图片随机转动的角度。随机选择图片的角度,是一个0~180的度数,取值为0~180。
    width_shift_range:浮点数,图片宽度的某个比例,数据提升时图片随机水平偏移的幅度。
    height_shift_range:浮点数,图片高度的某个比例,数据提升时图片随机竖直偏移的幅度。 
    height_shift_range和width_shift_range是用来指定水平和竖直方向随机移动的程度,这是两个0~1之间的比例。
    shear_range:浮点数,剪切强度(逆时针方向的剪切变换角度)。是用来进行剪切变换的程度。
    zoom_range:浮点数或形如[lower,upper]的列表,随机缩放的幅度,若为浮点数,则相当于[lower,upper] = [1 - zoom_range, 1+zoom_range]。用来进行随机的放大。
    channel_shift_range:浮点数,随机通道偏移的幅度。
    fill_mode:‘constant’,‘nearest’,‘reflect’或‘wrap’之一,当进行变换时超出边界的点将根据本参数给定的方法进行处理
    cval:浮点数或整数,当fill_mode=constant时,指定要向超出边界的点填充的值。
    horizontal_flip:布尔值,进行随机水平翻转。随机的对图片进行水平翻转,这个参数适用于水平翻转不影响图片语义的时候。
    vertical_flip:布尔值,进行随机竖直翻转。

    rescale: 值将在执行其他处理前乘到整个图像上,我们的图像在RGB通道都是0~255的整数,这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。
    preprocessing_function: 将被应用于每个输入的函数。该函数将在任何其他修改之前运行。该函数接受一个参数,为一张图片(秩为3的numpy array),并且输出一个具有相同shape的numpy array
    data_format:字符串,“channel_first”或“channel_last”之一,代表图像的通道维的位置。该参数是Keras 1.x中的image_dim_ordering,“channel_last”对应原本的“tf”,“channel_first”对应原本的“th”。以128x128的RGB图像为例,“channel_first”应将数据组织为(3,128,128),而“channel_last”应将数据组织为(128,128,3)。该参数的默认值是~/.keras/keras.json中设置的值,若从未设置过,则为“channel_last”。

    brightness_range: Tuple or list of two floats. Range for picking
    a brightness shift value from.
  • 相关阅读:
    Delegates in C#
    Continues Integration
    单例模式(Singleton Pattern)
    敏捷开发中编写高质量Java代码
    How to debug your application (http protocol) using Fiddler
    Java EE核心框架实战(1)
    03 Mybatis:01.Mybatis课程介绍及环境搭建&&02.Mybatis入门案例
    04 Spring:01.Spring框架简介&&02.程序间耦合&&03.Spring的 IOC 和 DI&&08.面向切面编程 AOP&&10.Spring中事务控制
    黑马IDEA版javaweb_22MySQL
    第04项目:淘淘商城(SpringMVC+Spring+Mybatis) 的学习实践总结【第二天】
  • 原文地址:https://www.cnblogs.com/luoyinjie/p/10636210.html
Copyright © 2011-2022 走看看