zoukankan      html  css  js  c++  java
  • 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf
    import tensorflow.contrib.slim as slim
    
    
    import rawpy
    import numpy as np
    import tensorflow as tf
    import struct
    import glob
    import os
    from PIL import Image
    import time
    
    # Identifiers of the supported capture sources (selected through `source_id`).
    __sony__ = 0
    __huawei__ = 1
    __blackberry__ = 2
    
    # Identifiers of the training stages (selected through `training_stage`).
    __stage_raw2raw__ = 0
    __stage_raw2rgb__ = 1
    __stage_overall__ = 2
    
    # Leading character of the file names of each data split; the active one
    # (`current_prefix`) is inserted into the glob patterns that collect files.
    train_prefix = '0'
    valid_prefix = '1'
    test_prefix = '2'
    
    # ============ CONFIGURATION ============
    # USE_GPU also selects the dataset directory and enables in-memory caching
    # of decoded images in the patch generators.
    USE_GPU = False
    if USE_GPU:
        # expose only one physical device to TensorFlow
        os.environ['CUDA_VISIBLE_DEVICES'] = '2'
    # change this to switch between datasets
    source_id = __sony__
    
    # switch between training stages
    training_stage = __stage_raw2rgb__
    
    # patch size should be set on running
    # (height, width) of the random crops; also the output size of demosaic()
    patch_size = (512, 512)
    #patch_size = (2840, 4248)
    
    # switch between training and validation
    current_prefix = train_prefix
    
    # model saving settings
    # training stops after `max_epoch` epochs; a checkpoint is written every
    # `save_epoch_delay` epochs (one epoch = one step per ground-truth file)
    max_epoch = 2000
    save_epoch_delay = 1
    # NOTE(review): the directory names say "raw2raw" although the default
    # stage above is raw2rgb -- confirm this is intended.
    model_dir = './result_raw2raw/'
    out_dir = './output_raw2raw/'
    log_dir = './log_raw2raw/'
    learn_rate = 1e-2
    # ============ CONFIGURATION ============
    
    
    # Per-source sensor constants.  WHITE_LEVEL / BLACK_LEVEL feed
    # black_level_correction(); HEIGHT / WIDTH equal the frame sizes that
    # raw_from_file() reshapes to.  No branch matches an unknown `source_id`,
    # in which case these names stay undefined.
    if source_id == __blackberry__:
        WHITE_LEVEL = 1023
        BLACK_LEVEL = 64
        HEIGHT = 3024
        WIDTH = 4032
    elif source_id == __sony__:
        WHITE_LEVEL = 16383
        BLACK_LEVEL = 512
        HEIGHT = 2848
        WIDTH = 4256
    elif source_id == __huawei__:
        WHITE_LEVEL = 1023
        BLACK_LEVEL = 64
        HEIGHT = 2976
        WIDTH = 3968
    
    # Dataset root: a relative path for the GPU configuration, a local drive
    # otherwise.  It must contain the 'long' and 'short' sub-directories that
    # the file discovery code globs.
    if USE_GPU:
        data_dir = '../see_in_the_dark/dataset/Sony_small/'
    else:
        data_dir = 'D:/data/Sony_small/'
    
    
    # !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
    # `fixed_size` is the resolution the half-size RGB image is resized to
    # before it enters the ISP parameter network.
    # NOTE(review): `num_of_denoise_filter` and `standard_brightness` are not
    # read anywhere in the visible part of the file.
    fixed_size = (128, 128)
    num_of_denoise_filter = 3
    standard_brightness = 0.1
    # !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
    
    
    def has_nan_in_tensor(x):
        """Tell whether `x` contains at least one NaN entry."""
        # NaN is the only value that compares unequal to itself
        nan_count = np.sum(x != x)
        return nan_count > 0
    
    
    def raw_from_file(path):
        """Load the Bayer mosaic stored in `path` as a float32 array.

        The decoder depends on the module-level `source_id`:
        sony / huawei files are read through rawpy, blackberry files are plain
        dumps of native-endian unsigned 16-bit samples.

        :param path: path of the raw file
        :return: array of shape (rows, cols, 1); sony frames lose their first
                 and last column, huawei frames their first and last row
        :raises ValueError: if `source_id` is not a known source, or if the
                            file does not hold HEIGHT x WIDTH samples
        """
        if source_id == __sony__ or source_id == __huawei__:
            # the context manager releases the decoder's native resources;
            # astype() copies the pixels so they stay valid afterwards
            with rawpy.imread(path) as data:
                raw = data.raw_image_visible.astype(np.float32)
            raw = raw.reshape(HEIGHT, WIDTH)
            h, w = raw.shape[0], raw.shape[1]
            if source_id == __sony__:
                # convert from RGBG into standard GRGB format:
                # cut the strips of left and right borders
                return np.reshape(raw[:, 1:w-1], [h, w-2, 1])
            # convert from BGRG into standard GRGB format:
            # cut the strips of top and bottom borders
            return np.reshape(raw[1:h-1, :], [h-2, w, 1])
        elif source_id == __blackberry__:
            with open(path, 'rb') as raw_file:
                data = raw_file.read()
            # same layout as struct format 'H' (native byte order), decoded
            # without building an intermediate tuple of Python ints
            raw = np.frombuffer(data, dtype=np.uint16).astype(np.float32)
            raw = raw.reshape(HEIGHT, WIDTH)
            return np.reshape(raw, [HEIGHT, WIDTH, 1])
        else:
            # an explicit raise survives `python -O`, unlike `assert False`
            raise ValueError('unknown source id [%d]' % source_id)
    
    
    def rgb_from_file(path):
        """Develop the raw file at `path` into a float32 RGB image.

        rawpy renders the image with the camera white balance, without
        automatic brightening, at 16 bits per sample; the result is divided by
        65535.  The frame is then trimmed exactly like raw_from_file() trims the
        mosaic: sony loses its first/last column, huawei its first/last row.

        :param path: path of the raw file
        :return: array of shape (rows, cols, 3)
        :raises NameError: if the current `source_id` is neither sony nor
                           huawei
        """
        if source_id != __sony__ and source_id != __huawei__:
            raise NameError('file type [%d] does not support rawpy!' % source_id)
        # the context manager releases the decoder's native resources
        with rawpy.imread(path) as raw:
            rgb = np.float32(
                raw.postprocess(
                    use_camera_wb=True,
                    half_size=False,
                    no_auto_bright=True,
                    output_bps=16
                )
            ) / 65535.0
        if source_id == __sony__:
            return rgb[:, 1:-1, :]
        return rgb[1:-1, :, :]
    
    
    def black_level_correction(bayer):
        """Subtract the black level, scale by the sensor range and clip negatives to zero."""
        scale = 1.0/(WHITE_LEVEL-BLACK_LEVEL)
        with tf.name_scope('black_level_corr'):
            shifted = bayer - BLACK_LEVEL
            return tf.nn.relu(shifted*scale)
    
    
    def bound(bayer):
        """Clamp every element of `bayer` into [0, 1]."""
        floored = tf.maximum(bayer, 0)
        return tf.minimum(floored, 1)
    
    
    def bayer_to_rgb(bayer):
        """Collapse every 2x2 mosaic cell into one RGB pixel (half resolution).

        Implemented as a fixed stride-2 convolution; `bayer` must be a
        single-channel NHWC tensor.
        """
        # one row per output channel, columns are the 2x2 cell in row-major order
        cell_weights = np.array([
            [0.0, 1.0, 0.0, 0.0],   # R
            [0.5, 0.0, 0.0, 0.5],   # (G1+G2)/2
            [0.0, 0.0, 1.0, 0.0],   # B
        ])
        # rearrange into the [height, width, in_channels, out_channels] layout
        kernel = cell_weights.T.reshape([2, 2, 1, 3])
        with tf.name_scope('bayer2rgb'):
            half_rgb = tf.nn.conv2d(
                bayer,
                kernel,
                strides=(1, 2, 2, 1),
                padding='VALID',
                name='bayer_converter'
            )
        return half_rgb
    
    
    def demosaic(rgb):
        """Bilinearly resize `rgb` to the module-level `patch_size`."""
        target_size = patch_size
        with tf.name_scope('demosaic'):
            upsampled = tf.image.resize_bilinear(rgb, target_size)
        return upsampled
    
    
    def color_correction(rgb, color_matrix):
        """Mix the three channels of `rgb` with `color_matrix` via a 1x1 convolution."""
        with tf.name_scope('color_corr'):
            # a 3x3 matrix is exactly a 1x1 convolution kernel with 3 inputs and 3 outputs
            kernel = tf.reshape(color_matrix, [1, 1, 3, 3])
            corrected = tf.nn.conv2d(
                rgb,
                kernel,
                strides=(1, 1, 1, 1),
                padding='SAME',
                name='output'
            )
        return corrected
    
    
    def min_max_normalize(rgb):
        """Rescale `rgb` by its global minimum and maximum.

        1e-8 is added to numerator and denominator, so a constant input maps
        to 1 instead of 0/0.
        """
        lowest = tf.reduce_min(rgb)
        value_range = tf.reduce_max(rgb) - lowest
        return (rgb - lowest + 1e-8)/(value_range + 1e-8)
    
    
    def gaussian_norm(rgb):
        """Standardise `rgb` to zero mean and unit standard deviation.

        Mean and deviation are taken over all elements of the tensor.

        :param rgb: tensor to normalise
        :return: tensor of the same shape
        """
        _mean = tf.reduce_mean(rgb)
        centered = rgb - _mean
        # The epsilon sits inside the square root: a constant input has zero
        # variance, which would otherwise give 0/0 = NaN in the forward pass
        # and an infinite sqrt gradient in the backward pass.  The value
        # matches the one min_max_normalize() uses.
        _std = tf.sqrt(tf.reduce_mean(tf.square(centered)) + 1e-8)
        return centered/_std
    
    
    # Not supported on SNPE, so this step is meant to run on the phone's CPU.
    # The input is normalised first so the power is never taken of a negative value.
    def gamma_correction(rgb, gamma):
        """Raise the min-max normalised `rgb` to the power `gamma`."""
        with tf.name_scope('gamma_corr'):
            normalised = min_max_normalize(rgb)
            return tf.pow(normalised, gamma)
    
    
    def lrelu(x):
        """Leaky ReLU with a slope of 0.2 on the negative side."""
        leak = x*0.2
        return tf.maximum(leak, x)
    
    
    def network_raw2raw(inputs):
        """Build the raw-to-raw network: five dilated 3x3 convolutions plus a 1x1 output layer.

        The dilated layers have 32 channels each, use lrelu and live in the
        scopes 'g_conv1' .. 'g_conv5'; the final single-channel layer
        ('g_conv_last') has no activation.
        """
        dilation_rates = (1, 2, 4, 8, 16)
        with tf.name_scope('raw2raw'):
            net = inputs
            for layer_no, dilation in enumerate(dilation_rates, start=1):
                # NOTE(review): the initializer is passed as the bare class; if
                # TF instantiates it with its defaults, every dilated kernel
                # starts at the same constant value -- confirm this is intended.
                net = slim.conv2d(
                    net, 32, [3, 3],
                    rate=dilation,
                    activation_fn=lrelu,
                    weights_initializer=tf.initializers.constant,
                    scope='g_conv%d' % layer_no
                )
            net = slim.conv2d(net, 1, [1, 1], rate=1, activation_fn=None, scope='g_conv_last')
        return net
    
    
    def show(rgb, title):
        """Display a float image with values in [0, 1] in the default viewer.

        :param rgb: image array, nominally in [0, 1]
        :param title: window title handed to PIL
        """
        # clamp before the cast: an out-of-range float wraps around when it is
        # converted to uint8 and would show up as random speckles
        pixels = np.uint8(np.clip(rgb, 0.0, 1.0) * 255)
        im = Image.fromarray(pixels)
        im.show(title)
    
    
    def save(rgb, path):
        """Write a float image with values in [0, 1] to `path`.

        :param rgb: image array, nominally in [0, 1]
        :param path: destination file; PIL picks the format from the extension
        """
        # clamp before the cast: an out-of-range float wraps around when it is
        # converted to uint8 and would show up as random speckles
        pixels = np.uint8(np.clip(rgb, 0.0, 1.0) * 255)
        im = Image.fromarray(pixels)
        im.save(path)
    
    
    def concat(ims):
        """Join the arrays in `ims` along axis 1 (side by side for HWC images)."""
        side_by_side = np.concatenate(ims, axis=1)
        return side_by_side
    
    
    def get_color_matrix_and_gamma(bayer):
        """Predict the ISP parameters (colour matrix and gamma) from an image.

        A shared two-layer convolutional stem feeds two parallel heads, each
        made of two more stride-2 convolutions and one dense layer.

        :param bayer: NHWC input tensor; the spatial size must be static because
                      it is used to compute the flattened feature length
        :return: (color_matrix, gamma) -- a [3, 3] tensor and a [1] tensor.
                 Both reshapes only succeed for a batch of one image.

        NOTE(review): none of the layers is given a name, so TensorFlow
        presumably derives the variable names from the creation order --
        reordering these calls may break existing checkpoints; confirm before
        refactoring.
        """
        with tf.name_scope('isp_param_gen'):
            # stem shared by both heads
            with tf.name_scope('common_extractor'):
                channels = tf.layers.conv2d(bayer, 3, kernel_size=3, strides=2, padding='valid')
                activations = tf.nn.tanh(channels)
                channels = tf.layers.conv2d(activations, 5, kernel_size=3, strides=2, padding='valid')
                activations = tf.nn.relu(channels)
            # head 1: nine values reshaped into the 3x3 colour matrix
            with tf.name_scope('color_matrix'):
                channels_cm = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
                activations_cm = tf.nn.tanh(channels_cm)
                channels_cm = tf.layers.conv2d(activations_cm, 5, kernel_size=3, strides=2, padding='valid')
                # flatten height * width * channels into one feature vector per image
                channels_flat_cm = tf.reshape(
                    channels_cm,
                    [-1, channels_cm.shape[1]*channels_cm.shape[2]*channels_cm.shape[3]])
                color_matrix = tf.reshape(tf.layers.dense(channels_flat_cm, 9), [3, 3])
            # head 2: a single scalar used as the gamma exponent
            with tf.name_scope('gamma'):
                channels_gamma = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
                activations_gama = tf.nn.tanh(channels_gamma)
                channels_gamma = tf.layers.conv2d(activations_gama, 5, kernel_size=3, strides=2, padding='valid')
                channels_flat_gamma = tf.reshape(
                    channels_gamma,
                    [-1, channels_gamma.shape[1] * channels_gamma.shape[2] * channels_gamma.shape[3]])
                # the exponent is kept at or above 1e-3
                gamma = tf.reshape(tf.maximum(tf.layers.dense(channels_flat_gamma, 1), 1e-3), [1])
            return color_matrix, gamma
    
    
    def build_isp_process_flow(bayer, color_matrix, gamma):
        """Chain the ISP steps: demosaic, then colour correction, then gamma correction."""
        with tf.name_scope('isp_flow'):
            full_res = demosaic(bayer)
            corrected = color_correction(full_res, color_matrix)
            return gamma_correction(corrected, gamma)
    
    
    # expects NHWC input
    def color_normalize(rgb):
        """Divide every pixel by the sum of its channels (floored at 1e-7)."""
        channel_sum = tf.reduce_sum(rgb, axis=3)
        safe_sum = tf.maximum(channel_sum, 1e-7)
        return rgb/tf.expand_dims(safe_sum, axis=-1)
    
    
    def color_loss(rgb_out, rgb_gt):
        """Mean absolute difference between the channel-normalised output and ground truth."""
        chroma_out = color_normalize(rgb_out)
        chroma_gt = color_normalize(rgb_gt)
        return tf.reduce_mean(tf.abs(chroma_out - chroma_gt))
    
    
    # load images from files
    # Ground-truth (long exposure) files of the active data split.
    gt_files = glob.glob(data_dir + '/long/' + current_prefix + '*.ARW')
    
    # Reorganize the raw files according to their training id.
    # The id is characters 1..4 of the file name.  os.path.basename() works
    # with whatever separator glob returns on the current platform, so no
    # per-OS split() is needed (the original Windows branch,
    # split('\'), was not even valid Python).
    train_ids = [os.path.basename(gt_file)[1:5] for gt_file in gt_files]
    
    # for input files, multiple ones may relate to single ground truth file
    in_files = [
        glob.glob(data_dir + '/short/' + current_prefix + train_id + '*.ARW')
        for train_id in train_ids
    ]
    
    # Caches of decoded images, filled lazily by the patch generators:
    # one slot per ground-truth file, and one slot per input file.
    gt_raws = [None] * len(train_ids)
    gt_rgbs = [None] * len(train_ids)
    in_raws = [[None] * len(files) for files in in_files]
    
    
    def get_gt_file_by_train_id(tid):
        """Look up the ground-truth file stored at index `tid` of `gt_files`."""
        gt_path = gt_files[tid]
        return gt_path
    
    
    def get_in_file_by_train_id_file_id(tid, fid):
        """Look up input file number `fid` of the ground-truth entry at index `tid`."""
        candidates = in_files[tid]
        return candidates[fid]
    
    
    def get_patch_pair_raw_raw(raw_in, raw_gt):
        """Cut the same random `patch_size` window out of two raw images.

        :param raw_in: input mosaic, indexed as [row, col, channel]
        :param raw_gt: ground-truth mosaic with the same row/column extent
        :return: (input patch, ground-truth patch), each with a leading batch
                 axis of length 1
        :raises ValueError: if the image is smaller than `patch_size`
        """
        patch_h, patch_w = patch_size[0], patch_size[1]
        h, w = raw_in.shape[0], raw_in.shape[1]
        # Offsets are forced to be even so that every patch starts on the same
        # cell of the 2x2 colour filter pattern; bayer_to_rgb() reads fixed
        # positions inside each cell and would mix up the channels otherwise.
        # The "+ 1" makes the last valid offset reachable and lets an image
        # that is exactly patch-sized through (randint's bound is exclusive).
        y = 2 * np.random.randint(0, (h - patch_h) // 2 + 1)
        x = 2 * np.random.randint(0, (w - patch_w) // 2 + 1)
        return (
            np.expand_dims(raw_in[y:y + patch_h, x:x + patch_w, :], axis=0),
            np.expand_dims(raw_gt[y:y + patch_h, x:x + patch_w, :], axis=0)
        )
    
    
    def get_patch_pair_raw_rgb(raw, rgb):
        """Cut the same random `patch_size` window out of a raw image and its RGB target.

        :param raw: input mosaic, indexed as [row, col, channel]
        :param rgb: target image with the same row/column extent
        :return: (raw patch, target patch), each with a leading batch axis of
                 length 1
        :raises ValueError: if the image is smaller than `patch_size`
        """
        patch_h, patch_w = patch_size[0], patch_size[1]
        h, w = raw.shape[0], raw.shape[1]
        # Offsets are forced to be even so that every patch starts on the same
        # cell of the 2x2 colour filter pattern; bayer_to_rgb() reads fixed
        # positions inside each cell and would mix up the channels otherwise.
        # The "+ 1" makes the last valid offset reachable and lets an image
        # that is exactly patch-sized through (randint's bound is exclusive).
        y = 2 * np.random.randint(0, (h - patch_h) // 2 + 1)
        x = 2 * np.random.randint(0, (w - patch_w) // 2 + 1)
        return (
            np.expand_dims(raw[y:y + patch_h, x:x + patch_w, :], axis=0),
            np.expand_dims(rgb[y:y + patch_h, x:x + patch_w, :], axis=0)
        )
    
    
    def get_rand_patch_from_file_raw2rgb():
        """Endlessly yield random (raw input patch, RGB ground-truth patch) pairs.

        Every pass visits the ground-truth files in a fresh random order and
        picks one of the matching input files at random.  Decoded images are
        kept in `gt_rgbs` / `in_raws` only when USE_GPU is set.

        :raises ValueError: on first use, if no ground truth has an input file
        """
        if not any(in_files):
            # without this guard the loop below would spin forever
            raise ValueError('no input files found under %s' % data_dir)
        while True:
            seq = np.random.permutation(len(train_ids))
            for ind in seq:
                if not in_files[ind]:
                    # this ground truth has no short exposure to pair with
                    continue
                # Always start from the cache slot of *this* entry: relying on
                # a local left over from the previous iteration would pair the
                # wrong images and write them back into the wrong slot.
                gt_rgb = gt_rgbs[ind]
                if gt_rgb is None:
                    # resource not found in cache, load it from disk
                    gt_rgb = rgb_from_file(get_gt_file_by_train_id(ind))
                fid = np.random.randint(0, len(in_files[ind]))
                in_raw = in_raws[ind][fid]
                if in_raw is None:
                    in_raw = raw_from_file(get_in_file_by_train_id_file_id(ind, fid))
                # cache them when using GPU on linux server since memory is sufficient
                if USE_GPU:
                    gt_rgbs[ind] = gt_rgb
                    in_raws[ind][fid] = in_raw
                yield get_patch_pair_raw_rgb(in_raw, gt_rgb)
    
    
    def get_rand_patch_from_file_raw2raw():
        """Endlessly yield random (raw input patch, raw ground-truth patch) pairs.

        Every pass visits the ground-truth files in a fresh random order and
        picks one of the matching input files at random.  Decoded images are
        kept in `gt_raws` / `in_raws` only when USE_GPU is set.

        :raises ValueError: on first use, if no ground truth has an input file
        """
        if not any(in_files):
            # without this guard the loop below would spin forever
            raise ValueError('no input files found under %s' % data_dir)
        while True:
            seq = np.random.permutation(len(train_ids))
            for ind in seq:
                if not in_files[ind]:
                    # this ground truth has no short exposure to pair with
                    continue
                # Always start from the cache slot of *this* entry: relying on
                # a local left over from the previous iteration would pair the
                # wrong images and write them back into the wrong slot.
                gt_raw = gt_raws[ind]
                if gt_raw is None:
                    # resource not found in cache, load it from disk
                    gt_raw = raw_from_file(get_gt_file_by_train_id(ind))
                fid = np.random.randint(0, len(in_files[ind]))
                in_raw = in_raws[ind][fid]
                if in_raw is None:
                    in_raw = raw_from_file(get_in_file_by_train_id_file_id(ind, fid))
                # cache them when using GPU on linux server since memory is sufficient
                if USE_GPU:
                    in_raws[ind][fid] = in_raw
                    gt_raws[ind] = gt_raw
                # the raw/rgb cropper only slices rows and columns, so it
                # serves raw/raw pairs as well
                yield get_patch_pair_raw_rgb(in_raw, gt_raw)
    
    
    # basic nodes
    # Placeholders for the input mosaic and the raw ground truth; both go
    # through black level correction before any further processing.
    t_bayer_in = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1], name='input')
    t_bayer_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])
    t_bayer_std = black_level_correction(t_bayer_in)
    t_bayer_gt_std = black_level_correction(t_bayer_gt)
    # the corrected input is amplified by a fixed factor of 300 and clipped at
    # 1.0 before it enters the raw-to-raw network
    t_bayer_boosted = network_raw2raw(tf.minimum(300*t_bayer_std, 1.0))
    
    # half-resolution RGB views of the input, the network output and the raw
    # ground truth; the input view is resized to `fixed_size` for the ISP
    # parameter network
    t_half_rgb = bayer_to_rgb(t_bayer_std)
    t_half_rgb_boosted = bayer_to_rgb(bound(t_bayer_boosted))
    t_half_rgb_gt = bayer_to_rgb(t_bayer_gt_std)
    t_half_rgb_resized = tf.image.resize_bilinear(t_half_rgb, fixed_size)
    
    # placeholder for the RGB ground truth
    t_rgb_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3])
    
    # ISP nodes
    t_color_matrix, t_gamma = get_color_matrix_and_gamma(t_half_rgb_resized)
    
    # training raw2raw alone
    # loss: mean absolute difference of the standardised half-size images
    # t_err_raw = tf.reduce_mean(tf.abs(t_half_rgb_gt - t_half_rgb_boosted))
    t_err_raw = tf.reduce_mean(tf.abs(gaussian_norm(t_half_rgb_boosted) - gaussian_norm(t_half_rgb_gt)))
    
    # training raw2rgb alone
    # stop_gradient keeps this loss from updating the raw-to-raw network
    t_half_rgb_freeze = tf.stop_gradient(t_half_rgb_boosted)
    t_rgb_freeze = build_isp_process_flow(t_half_rgb_freeze, t_color_matrix, t_gamma)
    # t_err_rgb = tf.reduce_mean(tf.abs(t_rgb_gt - t_rgb_freeze))
    # colour loss plus a penalty that pulls gamma towards 1/2.5
    t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt) + tf.abs(t_gamma[0] - 1.0/2.5)
    # t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt)
    
    # training overall model
    # same pipeline without the stop_gradient, so every part is trained
    t_rgb_final = build_isp_process_flow(t_half_rgb_boosted, t_color_matrix, t_gamma)
    # t_err_overall = tf.reduce_mean(tf.abs(t_rgb_gt - t_rgb_final))
    t_err_overall = color_loss(t_rgb_final, t_rgb_gt)
    
    
    def clean_no_grad_vars(vs, gs):
        """Drop every variable whose gradient is None.

        Returns two parallel lists: the kept variables and their gradients.
        """
        kept = [(vs[pos], grad) for pos, grad in enumerate(gs) if grad is not None]
        kept_vars = [var for var, _ in kept]
        kept_grads = [grad for _, grad in kept]
        return kept_vars, kept_grads
    
    
    def make_var_grad_pairs(vs, gs):
        """Pair every variable with its gradient as (gradient, variable) tuples."""
        pairs = []
        for pos, var in enumerate(vs):
            pairs.append((gs[pos], var))
        return pairs
    
    
    def _run_training_stage(sess, saver, logger, patches, t_minimizer, t_err, t_summary,
                            t_target, t_gamma_out=None):
        """Run the optimisation loop shared by the three training stages.

        Each step feeds one input patch into `t_bayer_in` and its ground truth
        into `t_target`, applies `t_minimizer`, logs `t_summary` and prints the
        loss.  A checkpoint is written every `save_epoch_delay` epochs, and the
        loop stops once more than `max_epoch` epochs worth of steps have run.

        :param sess: open session holding the initialised variables
        :param saver: saver used for the checkpoints
        :param logger: summary writer
        :param patches: iterator yielding (input patch, ground-truth patch)
        :param t_minimizer: training op of the stage
        :param t_err: scalar loss tensor of the stage
        :param t_summary: summary op of that loss
        :param t_target: placeholder that receives the ground-truth patch
        :param t_gamma_out: optional gamma tensor; when given, 1/gamma is
                            printed next to the loss
        """
        steps_per_epoch = len(train_ids)
        save_interval = steps_per_epoch * save_epoch_delay
        fetches = [t_minimizer, t_err, t_summary]
        if t_gamma_out is not None:
            fetches.append(t_gamma_out)
    
        counter = 0
        t_start = time.time()
        for patch_in, patch_gt in patches:
            results = sess.run(
                fetches,
                feed_dict={
                    t_bayer_in: patch_in,
                    t_target: patch_gt
                }
            )
            err, summary = results[1], results[2]
    
            logger.add_summary(summary, counter)
            epoch = int(counter / steps_per_epoch)
            if t_gamma_out is None:
                print('Epoch# %d Counter# %d  Loss= %.7f' % (epoch, counter, err))
            else:
                print('Epoch# %d Counter# %d  Loss= %.7f Gamma=%.6f'
                      % (epoch, counter, err, 1.0 / results[3][0]))
            counter += 1
    
            # `==`, not `is`: identity of int objects is an implementation detail
            if counter % 100 == 0:
                t_stop = time.time()
                print('Speed: %.6f' % ((t_stop - t_start) / 100))
                t_start = t_stop
    
            if counter > max_epoch * steps_per_epoch:
                saver.save(sess, os.path.join(model_dir, str(epoch)))
                print('Training done.')
                # the patch generators never end, so every stage has to leave
                # the loop explicitly
                break
            elif counter % save_interval == 0:
                saver.save(sess, os.path.join(model_dir, str(epoch)))
                print('Model saved.')
    
    
    def train():
        """Train the stage selected by `training_stage`.

        Builds the optimiser ops, restores the newest checkpoint found in
        `model_dir` (if any), runs the stage and writes summaries to `log_dir`.

        :raises RuntimeError: if no training file was discovered
        """
        print('Staged training begins...')
        if not train_ids:
            raise RuntimeError('no training files found under %s' % data_dir)
    
        t_opt = tf.train.GradientDescentOptimizer(learning_rate=learn_rate)
        sess = tf.Session()
        logger = None
        try:
            t_minimizer_raw2raw = t_opt.minimize(t_err_raw)
            t_minimizer_raw2rgb = t_opt.minimize(t_err_rgb)
            t_minimizer_overall = t_opt.minimize(t_err_overall)
    
            # To restore only the raw2raw sub-graph, build this saver from an
            # explicit variable list instead, e.g.
            #   slim.get_variables_to_restore(include=['g_conv1', ..., 'g_conv_last'])
            saver = tf.train.Saver(tf.global_variables())
            sess.run(tf.global_variables_initializer())
    
            # logger
            if not os.path.exists(log_dir):
                os.mkdir(log_dir)
            logger = tf.summary.FileWriter(log_dir, graph=sess.graph)
            t_sum_raw = tf.summary.scalar('raw2raw_loss', t_err_raw)
            t_sum_rgb = tf.summary.scalar('raw2rgb_loss', t_err_rgb)
            t_sum_all = tf.summary.scalar('overall_loss', t_err_overall)
    
            if not os.path.exists(model_dir):
                os.mkdir(model_dir)
            # let TensorFlow read its own 'checkpoint' index instead of parsing
            # the first line by hand; None means there is nothing to restore
            latest_checkpoint = tf.train.latest_checkpoint(model_dir)
            if latest_checkpoint is not None:
                print('Restoring model...')
                saver.restore(sess, latest_checkpoint)
    
            # bind saver to the full graph instead of a sub-graph
            saver = tf.train.Saver(tf.global_variables())
    
            if training_stage == __stage_raw2raw__:
                # first stage: raw to raw training
                print('Stage I: train to map input raw into ground truth raw')
                _run_training_stage(
                    sess, saver, logger, get_rand_patch_from_file_raw2raw(),
                    t_minimizer_raw2raw, t_err_raw, t_sum_raw, t_bayer_gt)
            elif training_stage == __stage_raw2rgb__:
                # second stage: raw to rgb training
                print('Stage II: train to map generated raw into ground truth rgb')
                _run_training_stage(
                    sess, saver, logger, get_rand_patch_from_file_raw2rgb(),
                    t_minimizer_raw2rgb, t_err_rgb, t_sum_rgb, t_rgb_gt,
                    t_gamma_out=t_gamma)
            elif training_stage == __stage_overall__:
                # third stage: overall training
                print('Stage III: train to map input raw into ground truth rgb')
                _run_training_stage(
                    sess, saver, logger, get_rand_patch_from_file_raw2rgb(),
                    t_minimizer_overall, t_err_overall, t_sum_all, t_rgb_gt)
        finally:
            # finalization, also reached when a step raises
            if logger is not None:
                logger.close()
            sess.close()
    
    
    def test_half_rgb():
        """Restore the newest checkpoint and dump 20 reconstruction comparisons.

        Each output image in `out_dir` shows the network's half-size RGB
        reconstruction on the left and the ground truth on the right.

        :raises IOError: if `model_dir` is missing or holds no checkpoint
        """
        print('Testing Half RGB reconstruction...')
    
        # the original `assert 'path not found!'` asserted a non-empty string,
        # which is always true and therefore never reported anything
        if not os.path.exists(model_dir):
            raise IOError('path not found!')
        latest_checkpoint = tf.train.latest_checkpoint(model_dir)
        if latest_checkpoint is None:
            raise IOError('no checkpoint found in %s' % model_dir)
    
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)
    
        sess = tf.Session()
        try:
            saver = tf.train.Saver(tf.global_variables())
            saver.restore(sess, latest_checkpoint)
            print('Model loaded.')
    
            patches = get_rand_patch_from_file_raw2raw()
            counter = 0
    
            for raw_in, raw_gt in patches:
                half_rgb_boosted, half_rgb_gt = sess.run(
                    [t_half_rgb_boosted, t_half_rgb_gt],
                    feed_dict={
                        t_bayer_in: raw_in,
                        t_bayer_gt: raw_gt
                    }
                )
                im_cmp = concat((half_rgb_boosted[0], half_rgb_gt[0]))
                save(im_cmp, (out_dir + '/HALF_%04d.jpg') % counter)
                counter += 1
                # the generator is endless: stop before it decodes another file
                if counter >= 20:
                    break
        finally:
            sess.close()
    
    
    # Script entry point: runs the staged training.  Swap the two calls below
    # to dump reconstruction samples from an existing checkpoint instead.
    if __name__ == '__main__':
        # test_half_rgb()
        train()

    1.先说tf.train.Saver()的坑,这个比较严重,其损失是不可挽回的!!!

    由于经常需要迁移学习,需要执行图融合的操作,于是,需要先加载一部分子图然后创建另一部分子图,训练完后保存整个模型。

    问题是:直接采用tf.train.Saver()的话,等效于saver = tf.train.Saver(tf.global_variables())

    在加载子图的时候会报错:因为在子图的checkpoint文件中找不到新创建的那部分子图的变量,因此需要显式指定要恢复的变量列表,而不是采用tf.global_variables()。

    于是将tf.global_variables()这个替换掉,方案有两种:

    1.直接利用name的prefix进行变量过滤,即对tf.global_variables()得到的变量列表中的部分变量根据其v.name进行剔除,剩下的就是需要加载的变量。

    2.采用tf.contrib.slim直接获取要加载的变量列表,然而这里出现了一个坑:

    slim.get_variables_to_restore(include=include) 中 include 是一个name list,采用正则进行名字匹配,原理是:if v.name.startswith('VAR_NAME_PREFIX'): ADD_TO_LIST(ret)

    于是当你的include list中有conv2d这个变量名称前缀时,所有的conv2d_xxx都会被自动添加到列表中,而且slim并不会对结果进行去重检查!!!于是你得到的var_list中将会出现重复的变量,导致加载模型时报错:at least two variables have the same name: conv2d_1/bias !!!

    填坑完毕!

    创建saver一定要指定要加载的变量列表,不然不知不觉的可能导致辛辛苦苦训练好的变量(参数)最终没有保存,永远的在结束训练时的内存中消亡了~~~~~

  • 相关阅读:
    统一身份认证(CAS)客户端测试获取信息代码
    常用的java工具类
    windows 批处理(bat)中执行程序后不等待直接退出(cmd中新进程执行程序)
    持续交付的八条原则,你能做到几条?(转)
    灵动标签调用栏目导航技巧
    .net网络编程(2)网络适配器
    Property Value Inheritance Tip(1)
    排序算法补充
    编码参考(Encoding)
    .net网络编程(3)Socket基础
  • 原文地址:https://www.cnblogs.com/thisisajoke/p/9916323.html
Copyright © 2011-2022 走看看