zoukankan      html  css  js  c++  java
  • ResNet 修改

    https://github.com/tornadomeet/ResNet

    apache 开源项目

    修改如下:

    训练模块

    import argparse,logging,os
    import mxnet as mx
    from symbol_resnet import resnet
    
    # Put the root logger at INFO so module-level logging.info(...) calls (e.g.
    # logging.info(args) in the __main__ block) are emitted rather than filtered
    # out by the default WARNING threshold.
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    
    
    def multi_factor_scheduler(begin_epoch, epoch_size, step=(60, 75, 90), factor=0.1):
        """Build a step-decay learning-rate schedule, shifted for a (possibly resumed) run.

        Args:
            begin_epoch: epoch that training starts from; decay epochs at or before it
                are dropped.
            epoch_size: number of updates (batches) per epoch.
            step: epochs, counted from epoch 0, at which the learning rate is multiplied
                by ``factor``. The default is a tuple so no mutable default is shared
                between calls; any iterable of ints works (lists included).
            factor: multiplicative decay applied at each remaining step.

        Returns:
            An ``mx.lr_scheduler.MultiFactorScheduler`` whose steps are expressed in
            updates counted from ``begin_epoch``, or ``None`` when no decay epoch lies
            after ``begin_epoch``.
        """
        # Only the decay points still ahead of begin_epoch matter, re-based to begin_epoch.
        step_ = [epoch_size * (x - begin_epoch) for x in step if x - begin_epoch > 0]
        if not step_:
            return None
        return mx.lr_scheduler.MultiFactorScheduler(step=step_, factor=factor)
    
    
    def main():
        """Build the ResNet symbol, data iterators and model from ``args`` and train it.

        All settings come from the module-level ``args`` namespace that the
        ``if __name__ == "__main__"`` block fills via argparse. A checkpoint is saved
        after every epoch under ``./model/resnet-<data_type>-<depth>-<kv.rank>``.

        Raises:
            ValueError: if ``args.depth`` is not a supported depth for the dataset, or
                ``args.data_type`` is neither "cifar10" nor "imagenet".
        """
        if args.data_type == "cifar10":
            args.aug_level = 1
            args.num_classes = 10
            # depth should be one of 110, 164, 1001, ...: (depth-2) % 9 == 0 for the
            # bottleneck variant (depth >= 164), (depth-2) % 6 == 0 for the basic one.
            # Floor division keeps the per-stage unit count an int ("/" yields a float
            # on Python 3).
            if (args.depth - 2) % 9 == 0 and args.depth >= 164:
                per_unit = [(args.depth - 2) // 9]
                filter_list = [16, 64, 128, 256]
                bottle_neck = True
            elif (args.depth - 2) % 6 == 0 and args.depth < 164:
                per_unit = [(args.depth - 2) // 6]
                filter_list = [16, 16, 32, 64]
                bottle_neck = False
            else:
                raise ValueError("no experiments done on detph {}, you can do it youself".format(args.depth))
            units = per_unit * 3
            symbol = resnet(units=units, num_stage=3, filter_list=filter_list, num_class=args.num_classes,
                            data_type="cifar10", bottle_neck=bottle_neck, bn_mom=args.bn_mom,
                            workspace=args.workspace, memonger=args.memonger)
        elif args.data_type == "imagenet":
            # NOTE: hard-coded class count for this task (3 classes); this assignment
            # overrides whatever was passed via --num-classes.
            args.num_classes = 3
            # Residual units per stage for each supported depth.
            units_by_depth = {
                18: [2, 2, 2, 2],
                34: [3, 4, 6, 3],
                50: [3, 4, 6, 3],
                101: [3, 4, 23, 3],
                152: [3, 8, 36, 3],
                200: [3, 24, 36, 3],
                269: [3, 30, 48, 8],
            }
            if args.depth not in units_by_depth:
                raise ValueError("no experiments done on detph {}, you can do it youself".format(args.depth))
            units = units_by_depth[args.depth]
            # Depth >= 50 uses bottleneck units with the wider filter list.
            bottle_neck = args.depth >= 50
            symbol = resnet(units=units, num_stage=4,
                            filter_list=[64, 256, 512, 1024, 2048] if bottle_neck else [64, 64, 128, 256, 512],
                            num_class=args.num_classes, data_type="imagenet", bottle_neck=bottle_neck,
                            bn_mom=args.bn_mom, workspace=args.workspace, memonger=args.memonger)
        else:
            raise ValueError("do not support {} yet".format(args.data_type))

        is_cifar = args.data_type == "cifar10"
        # Input shape (channels, height, width); non-CIFAR images are resized to 128x128.
        data_shape = (3, 32, 32) if is_cifar else (3, 128, 128)

        kv = mx.kvstore.create(args.kv_store)
        devs = mx.cpu() if args.gpus is None else [mx.gpu(int(i)) for i in args.gpus.split(',')]
        # Updates per epoch for this worker (at least 1).
        epoch_size = max(int(args.num_examples / args.batch_size / kv.num_workers), 1)
        begin_epoch = args.model_load_epoch if args.model_load_epoch else 0
        if not os.path.exists("./model"):
            os.mkdir("./model")
        model_prefix = "model/resnet-{}-{}-{}".format(args.data_type, args.depth, kv.rank)
        checkpoint = mx.callback.do_checkpoint(model_prefix)
        arg_params = None
        aux_params = None
        if args.retrain:
            _, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.model_load_epoch)
        if args.memonger:
            import memonger
            symbol = memonger.search_plan(symbol, data=(args.batch_size,) + data_shape)

        if is_cifar:
            train_rec = os.path.join(args.data_dir, "cifar10_train.rec")
            val_rec = os.path.join(args.data_dir, "cifar10_val.rec")
        else:
            train_rec = os.path.join(args.data_dir, "train_256_q90.rec") if args.aug_level == 1 \
                else os.path.join(args.data_dir, "train_480_q90.rec")
            val_rec = os.path.join(args.data_dir, "val_256_q90.rec")
        # Scale/aspect/HSV jitter only for non-CIFAR data at aug level 2 or 3.
        color_aug = not is_cifar and args.aug_level != 1

        train = mx.io.ImageRecordIter(
            path_imgrec         = train_rec,
            label_width         = 1,
            data_name           = 'data',
            label_name          = 'softmax_label',
            data_shape          = data_shape,
            batch_size          = args.batch_size,
            pad                 = 4 if is_cifar else 0,
            fill_value          = 127,  # only used when pad is valid
            rand_crop           = True,
            max_random_scale    = 1.0,  # 480 with imagnet, 32 with cifar10
            min_random_scale    = 0.533 if color_aug else 1.0,  # 256.0/480.0
            max_aspect_ratio    = 0.25 if color_aug else 0,
            random_h            = 36 if color_aug else 0,  # 0.4*90
            random_s            = 50 if color_aug else 0,  # 0.4*127
            random_l            = 50 if color_aug else 0,  # 0.4*127
            max_rotate_angle    = 0 if args.aug_level <= 2 else 10,
            max_shear_ratio     = 0 if args.aug_level <= 2 else 0.1,
            rand_mirror         = True,
            shuffle             = True,
            num_parts           = kv.num_workers,
            part_index          = kv.rank)
        val = mx.io.ImageRecordIter(
            path_imgrec         = val_rec,
            label_width         = 1,
            data_name           = 'data',
            label_name          = 'softmax_label',
            batch_size          = args.batch_size,
            data_shape          = data_shape,
            rand_crop           = False,
            rand_mirror         = False,
            num_parts           = kv.num_workers,
            part_index          = kv.rank)
        model = mx.model.FeedForward(
            ctx                 = devs,
            symbol              = symbol,
            arg_params          = arg_params,
            aux_params          = aux_params,
            num_epoch           = 200 if is_cifar else 120,
            begin_epoch         = begin_epoch,
            learning_rate       = args.lr,
            momentum            = args.mom,
            wd                  = args.wd,
            optimizer           = 'nag',
            initializer         = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
            lr_scheduler        = multi_factor_scheduler(begin_epoch, epoch_size, step=[120, 160], factor=0.1)
                                  if is_cifar else
                                  multi_factor_scheduler(begin_epoch, epoch_size, step=[30, 60, 90], factor=0.1),
            )
        model.fit(
            X                  = train,
            eval_data          = val,
            eval_metric        = ['acc', 'ce'] if is_cifar else
                                 ['acc', 'ce', mx.metric.create('top_k_accuracy', top_k=5)],
            kvstore            = kv,
            batch_end_callback = mx.callback.Speedometer(args.batch_size, args.frequent),
            epoch_end_callback = checkpoint)
        # Optional: report top-1/top-5 accuracy on the validation set after training.
        # logging.info("top-1 and top-5 acc is {}".format(model.score(X = val,
        #               eval_metric = ['acc', mx.metric.create('top_k_accuracy', top_k = 5)])))
    
    if __name__ == "__main__":
        # Command-line options; main() reads them through the module-level `args`.
        parser = argparse.ArgumentParser(description="command for training resnet-v2")
        parser.add_argument('--gpus', type=str, default='0', help='the gpus will be used, e.g "0,1,2,3"')
        parser.add_argument('--data-dir', type=str, default='./data/imagenet/', help='the input data directory')
        parser.add_argument('--data-type', type=str, default='imagenet', help='the dataset type')
        parser.add_argument('--list-dir', type=str, default='./',
                            help='the directory which contain the training list file')
        parser.add_argument('--lr', type=float, default=0.1, help='initialization learning reate')
        parser.add_argument('--mom', type=float, default=0.9, help='momentum for sgd')
        parser.add_argument('--bn-mom', type=float, default=0.9, help='momentum for batch normlization')
        parser.add_argument('--wd', type=float, default=0.0001, help='weight decay for sgd')
        parser.add_argument('--batch-size', type=int, default=256, help='the batch size')
        parser.add_argument('--workspace', type=int, default=512, help='memory space size(MB) used in convolution, if xpu '
                            ' memory is oom, then you can try smaller vale, such as --workspace 256')
        parser.add_argument('--depth', type=int, default=50, help='the depth of resnet')
        parser.add_argument('--num-classes', type=int, default=1000, help='the class number of your task')
        # The help text is built from adjacent string literals, one line per level
        # (joined with "\n"; a raw line break inside a quoted literal is a SyntaxError).
        parser.add_argument('--aug-level', type=int, default=2, choices=[1, 2, 3],
                            help='level 1: use only random crop and random mirror\n'
                                 'level 2: add scale/aspect/hsv augmentation based on level 1\n'
                                 'level 3: add rotation/shear augmentation based on level 2')
        parser.add_argument('--num-examples', type=int, default=1281167, help='the number of training examples')
        parser.add_argument('--kv-store', type=str, default='device', help='the kvstore type')
        parser.add_argument('--model-load-epoch', type=int, default=0,
                            help='load the model on an epoch using the model-load-prefix')
        parser.add_argument('--frequent', type=int, default=50, help='frequency of logging')
        parser.add_argument('--memonger', action='store_true', default=False,
                            help='true means using memonger to save momory, https://github.com/dmlc/mxnet-memonger')
        parser.add_argument('--retrain', action='store_true', default=False, help='true means continue training')
        args = parser.parse_args()
        logging.info(args)
        main()

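    为减小计算量和显存占用,将图片全部缩放为 128*128 大小(训练与预测脚本中的输入形状均为 (3, 128, 128))。网络使用 ResNet-50,并将训练脚本 main() 中 imagenet 分支的 args.num_classes 改为需要的分类数目(这里为 3;该赋值会覆盖命令行参数 --num-classes)。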

    train acc可以在99.9%水平,val acc 稳定在80%左右

    预测模块

    import numpy as np
    import cv2
    import mxnet as mx
    import argparse
    
    def ch_dev(arg_params, aux_params, ctx):
        """Return copies of the parameter dicts with every array placed on ``ctx``.

        Each value is moved with its ``as_in_context(ctx)`` method; the input dicts
        are left untouched and two new dicts ``(args, auxs)`` are returned.
        """
        moved_args = {name: arr.as_in_context(ctx) for name, arr in arg_params.items()}
        moved_auxs = {name: arr.as_in_context(ctx) for name, arr in aux_params.items()}
        return moved_args, moved_auxs
    
    
    
    def predict(img, mod=None, labels=None, top_k=3):
        """Run one forward pass on ``img`` and print the ``top_k`` class probabilities.

        Args:
            img: a single batchified input array accepted by ``mod.forward``.
            mod: the bound module to run. Previously this came from an undefined global,
                so it must now be passed explicitly.
            labels: sequence mapping class index -> class name.
            top_k: how many of the most probable classes to print (default 3).

        Raises:
            ValueError: if ``mod`` or ``labels`` is not provided.
        """
        if mod is None or labels is None:
            raise ValueError("predict() needs the bound module (mod) and the class labels")
        from collections import namedtuple
        # Module.forward only needs an object exposing a ``data`` list.
        Batch = namedtuple('Batch', ['data'])

        # compute the predict probabilities
        mod.forward(Batch([img]))
        prob = np.squeeze(mod.get_outputs()[0].asnumpy())
        # print the top-k, most probable first
        for i in np.argsort(prob)[::-1][:top_k]:
            print('probability=%f, class=%s' % (prob[i], labels[i]))
    
    def main():
        """Load a checkpoint and print the top-3 predictions for each requested image.

        Settings come from the module-level ``args`` filled by the ``__main__`` block.
        With ``--lst``, every image named in that MXNet ``.lst`` file (last
        tab-separated field, resolved against a fixed image root) is classified;
        with an empty ``--lst`` the single image given by ``--img`` is classified.
        For each image the path is printed, followed by three
        ``probability=..., class=...`` lines. The analysis script relies on this
        4-lines-per-image layout.
        """
        from collections import namedtuple
        # Module.forward only needs an object exposing a ``data`` list.
        Batch = namedtuple('Batch', ['data'])

        # Class index -> class name, one name per line.
        with open(args.synset) as f:
            synset = [name.strip() for name in f.readlines()]

        # Load and bind the model once (batch size 1, 128x128 input) for all images.
        ctx = mx.gpu(args.gpu)
        sym, arg_params, aux_params = mx.model.load_checkpoint(args.prefix, args.epoch)
        mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
        mod.bind(for_training=False, data_shapes=[('data', (1, 3, 128, 128))], label_shapes=mod._label_shapes)
        mod.set_params(arg_params, aux_params, allow_missing=True)

        def classify(path):
            """Print ``path`` followed by the top-3 class probabilities for that image."""
            print(path)
            # convert into format (batch, channel, height, width)
            with open(path, 'rb') as fimg:
                img = mx.image.imdecode(fimg.read())
            img = mx.image.imresize(img, 128, 128)  # resize to the bound input size
            img = img.transpose((2, 0, 1))  # Channel first
            img = img.expand_dims(axis=0)  # batchify
            img = img.astype('float32')  # for gpu context

            mod.forward(Batch([img]))
            prob = np.squeeze(mod.get_outputs()[0].asnumpy())
            # print the top-3
            for i in np.argsort(prob)[::-1][0:3]:
                print('probability=%f, class=%s' % (prob[i], synset[i]))

        if args.lst:
            image_root = "/mnt/hdfs-data-4/data/jian.yin/ped_thumbnail/instances_test/"
            with open(args.lst) as lst:
                for line in lst:
                    line = line.rstrip('\r\n')
                    if not line.strip():
                        continue  # tolerate blank lines
                    # The image path is the last tab-separated field of a .lst row.
                    # Splitting avoids chopping a character when the last row has no newline.
                    classify(image_root + line.split('\t')[-1])
        else:
            classify(args.img)
    
        
    
    
    if __name__ == "__main__":
        # Command-line interface as (flag, options) pairs, registered in this order;
        # main() reads the result through the module-level `args`.
        cli_options = [
            ('--img', dict(type=str, default='test.jpg', help='input image for classification')),
            # add --lst
            ('--lst', dict(type=str, default='test.lst', help="input image's lst for classification")),
            ('--gpu', dict(type=int, default=0, help='the gpu id used for predict')),
            ('--synset', dict(type=str, default='synset.txt', help='file mapping class id to class name')),
            ('--prefix', dict(type=str, default='resnet-50', help='the prefix of the pre-trained model')),
            ('--epoch', dict(type=int, default=0, help='the epoch of the pre-trained model')),
        ]
        parser = argparse.ArgumentParser(description="use pre-trainned resnet model to classify one image")
        for flag, options in cli_options:
            parser.add_argument(flag, **options)
        args = parser.parse_args()
        main()

    添加了 --lst 可选参数,可以对 MXNet .lst 列表文件中列出的所有图片进行批量预测。

    原文的预测模块对每张图片都重新加载 checkpoint 并重新 bind,效率较低。这里改为只加载、bind 一次模型,采用 MXNet 标准的 predict 写法:https://mxnet.incubator.apache.org/tutorials/python/predict_image.html

    添加一个脚本,防止忘记一些参数的写法:

    #!/usr/bin/env bash
    # Batch prediction over an MXNet .lst file. Redirect stdout to a file
    # (e.g. append "> myPredict.txt") to keep the results for the analysis script.
    python -u predict.py --lst instances_test.lst --prefix resnet-50 --synset ped_thumbnail.txt --gpu 0

    记得运行时用输出重定向(`>`,不是管道)把结果保存到文件,例如在命令末尾加上 `> myPredict.txt`;下面的分析脚本会读取 myPredict.txt。输出示例:

    /mnt/1/385_328_428_402_6.jpg
    probability=0.994927, class=1 Cyclist
    probability=0.003335, class=2 Others
    probability=0.001739, class=0 Pedestrian
    /mnt2/439_359_481_428_0.jpg
    probability=0.994793, class=2 Others
    probability=0.002817, class=0 Pedestrian
    probability=0.002390, class=1 Cyclist
    /mnt/2/619_337_658_401_16.jpg
    probability=0.992218, class=2 Others
    probability=0.007275, class=1 Cyclist
    probability=0.000507, class=0 Pedestrian
    /mnt1/511_288_561_385_1.jpg
    probability=0.997837, class=1 Cyclist
    probability=0.001525, class=0 Pedestrian
    probability=0.000638, class=2 Others

    分析预测结果

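    下面的脚本读取 myPredict.txt,打印混淆矩阵,并按 (真实类别, 预测类别) 把图片路径分别写入 zero_zero.txt、zero_one.txt 等文件,方便查看分错的样本。注意:真实类别取自图片所在目录名的最后一个字符。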

    import itertools
    import numpy as np
    import matplotlib.pyplot as plt
    
    from sklearn import svm, datasets
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import confusion_matrix
    
    
    # myPredict.txt is the redirected stdout of predict.py. Each image takes four lines:
    # one image-path line, then three "probability=..., class=<id> <name>" lines (top-3).
    true = []  # ground-truth class id of every image
    pred = []  # top-1 predicted class id of every image

    with open('myPredict.txt') as file:
        for cnt, line in enumerate(file):
            if cnt % 4 == 0:
                # Path line: the true class id is the single character right before the
                # last '/', i.e. the last character of the image's parent directory name.
                true.append(int(line[line.rfind('/') - 1]))
            elif cnt % 4 == 1:
                # Top-1 probability line: the class id is the character right before the
                # last space ("class=1 Cyclist" -> "1"); assumes single-word class names.
                pred.append(int(line[line.rfind(' ') - 1]))

    print(true)
    print(pred)
    print(confusion_matrix(true, pred))
    
    # Second pass: group the image-path lines by (true class, predicted class) so that
    # misclassified images can be inspected. Class ids 0/1/2 are spelled out in the
    # output file names, e.g. "zero_one.txt" holds images of true class 0 predicted as 1.
    class_names = ('zero', 'one', 'two')
    buckets = {(t, p): [] for t in range(len(class_names)) for p in range(len(class_names))}

    with open('myPredict.txt') as file:
        # Every 4th line (starting at the first) is an image path.
        path_lines = [line for cnt, line in enumerate(file) if cnt % 4 == 0]

    # path_lines[k] corresponds to true[k] / pred[k]; ids outside 0-2 are ignored.
    for line, t, p in zip(path_lines, true, pred):
        if (t, p) in buckets:
            buckets[(t, p)].append(line)

    # Number of misclassified images (all off-diagonal buckets).
    print(sum(len(lines) for (t, p), lines in buckets.items() if t != p))

    # One file per (true, pred) pair, written in 0-0, 0-1, ..., 2-2 order and created
    # even when empty, each holding the matching path lines.
    for (t, p), lines in buckets.items():
        with open('{}_{}.txt'.format(class_names[t], class_names[p]), 'w') as out:
            out.writelines(lines)

    混淆矩阵如下:

  • 相关阅读:
    Sublime 官方安装方法
    Notepad2、Sublime_text带图标的右键快捷打开方式
    创业公司如何实施敏捷开发
    如果有人让你推荐编程技术书,请叫他看这个列表
    Spring cron表达式详解
    Spring定时任务的几种实现
    spring注解方式 idea报could not autowire,eclipse却没有问题
    mysql处理海量数据时的一些优化查询速度方法
    Hexo重装小结
    修改JAVA代码,需要重启Tomcat的原因
  • 原文地址:https://www.cnblogs.com/TreeDream/p/10102626.html
Copyright © 2011-2022 走看看