zoukankan      html  css  js  c++  java
  • Fine-tune with Pretrained Models

    Gluon版本微调见这里。基于NDarray,类似于Pytorch动态图。而module版本类似于TF,基于Symbol,用的是静态graph。一般静态图用于快速调试见效果,而静态图效率高,速度快,实际中应更多使用。

    本文基于module和symbol。利用imagenet训好的模型来微调caltech-256数据集。首先是制作数据:

    训练集每类随机采样60张图,其余作为验证集。将图像resize成256,并打包成rec文件:

    wget http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar   # 下载解压
    tar -xf 256_ObjectCategories.tar
    
    mkdir -p caltech_256_train_60              # 划分数据集,训练集每类60张图
    for i in 256_ObjectCategories/*; do
        c=`basename $i`
        mkdir -p caltech_256_train_60/$c
        for j in `ls $i/*.jpg | shuf | head -n 60`; do
            mv $j caltech_256_train_60/$c/
        done
    done
    
    python3 im2rec.py --list --recursive caltech-256-60-train caltech_256_train_60/
    python3 im2rec.py --list --recursive caltech-256-60-val 256_ObjectCategories/
    python3 im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-val 256_ObjectCategories/
    python3 im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60-train caltech_256_train_60/

    这段代码实际上是新建了个文件夹,然后把所有数据剪切出去一个训练集,然后针对两份数据生成lst和rec文件。

    当然也可以解压后直接这样生成:

    tar -xf 256_ObjectCategories.tar
    python3 im2rec.py --list --recursive --train-ratio 0.6 caltech-256-60 256_ObjectCategories
    
    python3 im2rec.py --resize 256 --quality 90 --num-thread 16 caltech-256-60 256_ObjectCategories

    注意在生成的时候可能会发生段错误:

    这是因为在我的电脑上执行resize 256的时候报错,当我把线程数--num-thread改为1的时候就可以了。或者可以resize 更小例如128的时候4线程就ok。生成的文件:

     

    不想自己生成,官方也给出了这些文件的下载:

    import os, sys
    
    if sys.version_info[0] >= 3:
        from urllib.request import urlretrieve
    else:
        from urllib import urlretrieve
    
    def download(url):
        filename = url.split("/")[-1]
        if not os.path.exists(filename):
            urlretrieve(url, filename)
    download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')
    download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')

    然后可以定义data iter:

    import mxnet as mx
    
    def get_iterators(batch_size, data_shape=(3, 224, 224)):
        train = mx.io.ImageRecordIter(
            path_imgrec         = './caltech-256-60-train.rec',
            data_name           = 'data',
            label_name          = 'softmax_label',
            batch_size          = batch_size,
            data_shape          = data_shape,
            shuffle             = True,
            rand_crop           = True,
            rand_mirror         = True)
        val = mx.io.ImageRecordIter(
            path_imgrec         = './caltech-256-60-val.rec',
            data_name           = 'data',
            label_name          = 'softmax_label',
            batch_size          = batch_size,
            data_shape          = data_shape,
            rand_crop           = False,
            rand_mirror         = False)
        return (train, val)

    下载预训练的resnet18权重并载入。

    def get_model(prefix, epoch):
        download(prefix+'-symbol.json')
        download(prefix+'-%04d.params' % (epoch,))
    
    get_model('http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-18', 0)
    sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-18', 0)

    然后可以开始训练:

    首先定义一个函数替代最后的一层全连接:

    def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
        """
        symbol: the pretrained network symbol
        arg_params: the argument parameters of the pretrained model
        num_classes: the number of classes for the fine-tune datasets
        layer_name: the layer name before the last fully-connected layer
        """
        all_layers = symbol.get_internals()     # 得到所有层
        net = all_layers[layer_name+'_output']     # 注意这里的操作很反直觉,这句话意思是一直取到名字为layer_name的层
        net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')      # 新建一个分类层
        net = mx.symbol.SoftmaxOutput(data=net, name='softmax')              # 输出softmax概率
        new_args = dict({k:arg_params[k] for k in arg_params if 'fc1' not in k})     # 除了新的全连接层,载入已有的权重
        return (net, new_args)     # 返回新的网络symbol结构和参数

    symbol是和module搭档的,有了symbol,就可以新建module来喂入数据:

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    
    def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
        devs = [mx.gpu(i) for i in range(num_gpus)]
        mod = mx.mod.Module(symbol=symbol, context=devs)          # 新建一个module
        mod.fit(train, val,     # train和val的 data iter
            num_epoch=8,
            arg_params=arg_params,
            aux_params=aux_params,
            allow_missing=True,
            batch_end_callback = mx.callback.Speedometer(batch_size, 10),        # 每10个批量后打印一次训练速度和评价指标metric的值
            kvstore='device', 
            optimizer='sgd',
            optimizer_params={'learning_rate':0.01},
            initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
            eval_metric='acc')
        metric = mx.metric.Accuracy()
        return mod.score(val, metric)

    跑起来:

    num_classes = 256
    batch_per_gpu = 16
    num_gpus = 8
    
    (new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)     # 得到新的symbol和参数
    
    batch_size = batch_per_gpu * num_gpus       # 计算总的批量
    (train, val) = get_iterators(batch_size)       # 根据批量得到data iter
    mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)    # 训练
    assert mod_score > 0.77, "Low training accuracy." 
  • 相关阅读:
    操作系统 第二章 进程管理
    操作系统 第一章 概述(补充)
    第六次博客作业——团队总结
    专题(十三)watch
    专题(十二)find 查找
    JVM 排查工具介绍(二)Memory Analyzer 堆内存分析工具
    Linux 学习笔记之(二)curl命令
    centos openjdk 11 安装软件包获取方式
    软件工程课程总结
    小黄衫!又一次?
  • 原文地址:https://www.cnblogs.com/king-lps/p/13060039.html
Copyright © 2011-2022 走看看