    Accelerating Deep Learning by Focusing on the Biggest Losers

    The idea is simple. During training, each sample produces a loss \(\mathcal{L}(f(x_i), y_i)\), and training usually proceeds in mini-batches: the gradient of the batch loss \(\sum_i \mathcal{L}(f(x_i), y_i)\) is backpropagated and the parameters are updated. This paper argues that some samples \((x_i, y_i)\) are so redundant that the network learns to recognize them well, making the corresponding \(\mathcal{L}(f(x_i), y_i)\) relatively small. The authors therefore design a mechanism under which high-loss samples are selected with high probability while unimportant ones are skipped, reducing computation time. Experiments show that this method shortens training time while keeping accuracy unchanged.
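    To make the mechanism concrete, here is a minimal sketch of one such training step (my own illustrative PyTorch-style code, not the paper's implementation; `select_prob` stands in for the percentile rule derived below, and `criterion` is assumed to be built with `reduction='none'` so it returns per-sample losses):

    import torch

    def selective_step(net, criterion, optimizer, imgs, labels, select_prob):
        # The forward pass is always run in full; only the backward pass
        # is restricted to the probabilistically selected "hard" samples.
        out = net(imgs)
        losses = criterion(out, labels)  # per-sample losses, shape (batch,)
        probs = torch.tensor([select_prob(l.item()) for l in losses],
                             device=losses.device)
        keep = torch.rand_like(probs) < probs  # Bernoulli selection per sample
        if keep.any():
            optimizer.zero_grad()
            losses[keep].sum().backward()  # gradients only for selected samples
            optimizer.step()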

    [Figure: the paper's overview comparing the training pipelines]

    Related Work

    The authors say this line of work was first raised by Are Loss Functions All the Same?, but that paper only discusses the advantages of the hinge loss and analyzes other loss functions.

    The authors say the most closely related work is Not All Samples Are Created Equal: Deep Learning with Importance Sampling, which approaches the problem from a preprocessing angle (though it also has to compute losses) and carries somewhat more theory than this paper.

    Main Idea

    [Figure: Algorithm 1 from the paper]

    [Figure: Algorithm 2 from the paper]
    The idea of Algorithm 1 is clear; the main difficulty lies in how Algorithm 2 computes the probability. Suppose we have already computed the losses of \(n\) samples and stored them. If the next sample's loss is \(\mathcal{L}_c\), and \(k\) of the \(n\) stored losses are smaller than \(\mathcal{L}_c\), then the probability that this sample is selected is:

    \[\max \{(k/n)^{\beta},\ s\}\]

    where \(s \in [0,1]\) is set by hand, guaranteeing that every sample keeps some chance of being selected.
    We can also cap the history at a maximum length \(r\) by storing past losses in a deque (double-ended queue): once \(n = r\), storing a new loss evicts the oldest one, which keeps the bookkeeping cost bounded.
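    For reference, the rule \(\max\{(k/n)^{\beta}, s\}\) over such a bounded history can be computed exactly with a sorted copy of the deque. This is a sketch of my own (the name ExactProb is made up), unlike the histogram approximation used in calcprob.py further below:

    import bisect
    import collections

    class ExactProb:
        """Exact percentile-based selection probability over a bounded loss history."""

        def __init__(self, beta, s, r=3000):
            self.beta, self.s = beta, s
            self.history = collections.deque(maxlen=r)  # oldest loss evicted once full
            self.sorted = []                            # kept sorted for rank queries

        def __call__(self, loss):
            if self.history:
                k = bisect.bisect_left(self.sorted, loss)  # stored losses strictly below
                prob = max((k / len(self.history)) ** self.beta, self.s)
            else:
                prob = 1.0  # empty history: always select
            if len(self.history) == self.history.maxlen:
                self.sorted.remove(self.history[0])  # drop the loss about to be evicted
            self.history.append(loss)
            bisect.insort(self.sorted, loss)
            return prob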

    graph LR
        A[sample x] --> C(network f)
        C --> D[loss l]
        D -- update --> E[loss history]
        D --> F[compute probability]
        F --> G(form batch)
        G -- backprop --> C
        E --> F

    In the figure at the very beginning, the second column depicts this algorithm, and the third column builds on it by also reworking the forward pass: the stored losses are refreshed once every \(n\) epochs, and in the intervening \(n-1\) epochs the old (stale) losses are used to select samples (the selection presumably happens before a sample is even fed to the network, since otherwise no time would be saved).
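    A sketch of that stale-loss variant, with hypothetical hooks of my own: run_epoch trains on the given samples, and recompute_losses performs one full forward pass over the dataset:

    import random

    def stale_selection_training(run_epoch, recompute_losses, select_prob,
                                 indices, epochs, n):
        stale = recompute_losses(indices)  # index -> loss, computed once up front
        for epoch in range(epochs):
            if epoch % n == 0:  # refresh the stored losses only every n-th epoch
                stale = recompute_losses(indices)
            selected = [i for i in indices
                        if random.random() < select_prob(stale[i])]
            run_epoch(selected)  # samples are chosen before the forward pass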

    Among randomized algorithms there is a single-pass algorithm for picking a sample from a stream, but it selects only a single item, so it is of no use when many samples must be selected; picking many at once in a single pass seems hard to arrange.
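    The single-pass algorithm referred to here is presumably reservoir sampling; a minimal sketch for drawing one item uniformly from a stream of unknown length:

    import random

    def reservoir_one(stream):
        chosen = None
        for t, item in enumerate(stream, start=1):
            if random.random() < 1.0 / t:  # item t replaces the choice w.p. 1/t
                chosen = item
        return chosen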

    Code

    Owing to limited resources, this code has not been tested; the paper also ships very good official code.

    """
    OptInput.py
    纯粹是为了便于交互一些, 直接用argparse也可以
    """
    
    
    class Unit:
    
        def __init__(self, command, type=str,
                        default=None):
            if default is None:
                default = type()
            self.command = command
            self.type = type
            self.default = default
    
    class Opi:
        """
        >>> parser = Opi()
        >>> parser.add_opt(command="lr", type=float)
        >>> parser.add_opt(command="epochs", type=int)
        """
        def __init__(self):
            self.store = []
            self.infos = {}
    
        def add_opt(self, **kwargs):
            self.store.append(
                Unit(**kwargs)
            )
    
        def acquire(self):
            s = "Acquire args {0.command} [" 
                "type:{0.type.__name__} " 
                "default:{0.default}] : "
            for unit in self.store:
                while True:
                    inp = input(s.format(
                        unit
                    ))
                    try:
                        if inp:  # input was given: cast to the declared type
                            inp = unit.type(inp)
                        else:
                            inp = unit.default
                        self.infos.update(
                            {unit.command: inp}
                        )
                        self.__setattr__(unit.command, inp)
                        break
                    except (ValueError, TypeError):
                        print("Type {0} should be given".format(
                            unit.type.__name__
                        ))
    
    
    if __name__ == "__main__":
        parser = Opi()
        parser.add_opt(command = "x", type=int)
        parser.add_opt(command="y", type=str)
        parser.acquire()
        print(parser.infos)
        print(parser.x)
    
    '''
    calcprob.py
    Compute selection probabilities.
    '''
    
    
    
    
    import collections
    
    
    
    class Calcprob:
        def __init__(self, beta, sample_min, max_len=3000):
            assert 0. <= sample_min <= 1., "Invalid sample_min"
            assert beta > 0, "Invalid beta"
            self.beta = beta
            self.sample_min = sample_min
            self.max_len = max_len
            self.history = collections.deque(maxlen=max_len)
            self.num_slot = 1000
            self.hist = [0] * self.num_slot
            self.count = 0
    
        def update_history(self, losses):
            """
            Bounded histogram over a sliding window of recent losses.
            :param losses: iterable of positive loss values
            :return:
            """
            for loss in losses:
                loss = float(loss)  # detach from any autograd graph
                assert loss > 0
                if self.count == self.max_len:  # window full: evict the oldest loss
                    loss_old = self.history.popleft()
                    slot_old = int(loss_old * self.num_slot) % self.num_slot
                    self.hist[slot_old] -= 1
                else:
                    self.count += 1
                self.history.append(loss)  # the new loss enters the window either way
                slot = int(loss * self.num_slot) % self.num_slot
                self.hist[slot] += 1
    
        def get_probability(self, loss):
            loss = float(loss)
            assert loss > 0
            slot = int(loss * self.num_slot) % self.num_slot
            # fraction of stored losses below this loss (histogram approximation of k/n)
            prob = sum(self.hist[:slot]) / self.count
            assert isinstance(prob, float), "int division error..."
            return prob ** self.beta
    
        def calc_probability(self, losses):
            if isinstance(losses, (int, float)):
                losses = (losses,)
            losses = tuple(losses)  # materialize: a generator would be exhausted below
            self.update_history(losses)
            probs = [
                max(
                    self.get_probability(loss),
                    self.sample_min
                )
                for loss in losses
            ]
            return probs
    
        def __call__(self, losses):
            return self.calc_probability(losses)
    
    
    if __name__ == "__main__":
        # tiny smoke test: probabilities reflect each loss's rank in the history
        calc = Calcprob(beta=1., sample_min=0.1)
        print(calc([0.2, 0.5, 0.9]))  # -> [0.1, 0.333..., 0.666...]
    
    
    
    '''
    selector.py
    '''
    
    
    import calcprob
    import numpy as np
    
    
    class Selector:
    
        def __init__(self, batch_size,
                     beta, sample_min, max_len=3000):
            self.batch_size = batch_size
            self.calcprob = calcprob.Calcprob(beta,
                                              sample_min,
                                              max_len)
            self.reset()
    
        def backward(self):
            # backprop through the sum of the selected per-sample losses
            loss = sum(self.batch)
            loss.backward()
            self.reset()

        def reset(self):
            self.batch = []
            self.length = 0
    
        def select(self, losses):
            losses = list(losses)  # allow indexing below even if a generator is passed
            probs = self.calcprob(losses)
            for i, prob in enumerate(probs):
                if np.random.rand() < prob:
                    self.batch.append(losses[i])
                    self.length += 1
                    if self.length >= self.batch_size:
                        self.backward()  # full selected batch: backprop, then reset
    
        def __call__(self, losses):
            self.select(losses)
    
    
    
    '''
    main.py
    '''
    
    import torch
    import torch.nn as nn
    import torchvision
    import torchvision.transforms as transforms
    import numpy as np
    import os
    
    
    import selector
    
    
    
    
    
    class Train:
    
        def __init__(self, model, lossfunc,
                     bpsize, beta, sample_min, max_len=3000,
                     lr=0.01, momentum=0.9, weight_decay=0.0001):
            self.net = self.choose_net(model)
            self.criterion = self.choose_lossfunc(lossfunc)
            self.opti = torch.optim.SGD(self.net.parameters(),
                                        lr=lr, momentum=momentum,
                                        weight_decay=weight_decay)
            self.selector = selector.Selector(bpsize, beta,
                                              sample_min, max_len)
            self.gpu()
            self.generate_path()
            self.acc_rates = []
            self.errors = []
    
        def choose_net(self, model):
            net = getattr(
                torchvision.models,
                model,
                None
            )
            if net is None:
                raise ValueError("no such model")
            return net()
    
        def choose_lossfunc(self, lossfunc):
            lossfunc = getattr(
                nn,
                lossfunc,
                None
            )
            if lossfunc is None:
                raise ValueError("no such lossfunc")
            return lossfunc()  # instantiate the loss module rather than return the class
    
    
    
        def gpu(self):
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            if torch.cuda.device_count() > 1:
                print("Let'us use %d GPUs" % torch.cuda.device_count())
                self.net = nn.DataParallel(self.net)
            self.net = self.net.to(self.device)
    
    
    
        def generate_path(self):
            """
            Create directories and paths for saving parameters, logs, and infos.
            :return:
            """
            try:
                os.makedirs('./paras')
                os.makedirs('./logs')
                os.makedirs('./infos')
            except FileExistsError as e:
                pass
            name = self.net.__class__.__name__
            paras = os.listdir('./paras')
            logs = os.listdir('./logs')
            infos = os.listdir('./infos')
            number = max((len(paras), len(logs), len(infos)))
            self.para_path = "./paras/{0}{1}.pt".format(
                name,
                number
            )
    
            self.log_path = "./logs/{0}{1}.txt".format(
                name,
                number
            )
            self.info_path = "./infos/{0}{1}.npy".format(
                name,
                number
            )
    
    
        def log(self, strings):
            """
            Append a record to the run log.
            :param strings:
            :return:
            """
            # 'a': append to the end of the file
            with open(self.log_path, 'a', encoding='utf8') as f:
                f.write(strings)
    
        def save(self):
            """
            Save the network parameters.
            :return:
            """
            torch.save(self.net.state_dict(), self.para_path)
    
        def decrease_lr(self, multi=0.96):
            """
            Decay the learning rate by the factor multi.
            :param multi:
            :return:
            """
            self.opti.param_groups[0]['lr'] *= multi
    
    
        def train(self, trainloader, epochs=50):
            data_size = len(trainloader) * trainloader.batch_size
            part = int(trainloader.batch_size / 2)
            for epoch in range(epochs):
                running_loss = 0.
                total_loss = 0.
                acc_count = 0.
                if (epoch + 1) % 8 == 0:
                    self.decrease_lr()
                    self.log(  # log the learning-rate change
                        "learning rate change!!!\n"
                    )
                for i, data in enumerate(trainloader):
                    imgs, labels = data
                    imgs = imgs.to(self.device)
                    labels = labels.to(self.device)
                    out = self.net(imgs)
                    _, pre = torch.max(out, 1)  # predicted classes
                    acc_count += (pre == labels).sum().item()  # count correct predictions

                    losses = [  # per-sample losses; a list, so they can be reused below
                        self.criterion(out[j:j + 1], labels[j:j + 1])
                        for j in range(len(labels))
                    ]

                    self.opti.zero_grad()
                    self.selector(losses)  # sample selection; may trigger backward()
                    self.opti.step()

                    running_loss += sum(losses).item()

                    if (i + 1) % part == 0:
                        strings = "epoch {0:<3} part {1:<5} loss: {2:<.7f}\n".format(
                            epoch, i, running_loss / part
                        )
                        self.log(strings)  # log progress
                        total_loss += running_loss
                        running_loss = 0.
                self.acc_rates.append(acc_count / data_size)
                self.errors.append(total_loss / data_size)
                self.log(  # log accuracy
                    "Accuracy of the network on %d train images: %d %%\n" % (
                        data_size, acc_count / data_size * 100
                    )
                )
                self.save()  # save network parameters
            # save curves for plotting later
            np.save(self.info_path, {
                'acc_rates': np.array(self.acc_rates),
                'errors': np.array(self.errors)
            })
    
    
    
    
    if __name__ == "__main__":
    
        import OptInput
        args = OptInput.Opi()
        args.add_opt(command="model", default="resnet34")
        args.add_opt(command="lossfunc", default="CrossEntropyLoss")
        args.add_opt(command="bpsize", default=32)
        args.add_opt(command="beta", default=0.9)
        args.add_opt(command="sample_min", default=0.3)
        args.add_opt(command="max_len", default=3000)
        args.add_opt(command="lr", default=0.001)
        args.add_opt(command="momentum", default=0.9)
        args.add_opt(command="weight_decay", default=0.0001)
    
        args.acquire()
    
        root = "C:/Users/pkavs/1jupiterdata/data"
    
        trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                              download=False,
                                              transform=transforms.Compose(
                                                  [transforms.Resize(224),
                                                   transforms.ToTensor(),
                                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
                                              ))
    
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                                  shuffle=True, num_workers=8,
                                                   pin_memory=True)
    
    
    
        dog = Train(**args.infos)
        dog.train(train_loader, epochs=1000)
    
    
    
    
    
    
    Original post: https://www.cnblogs.com/MTandHJ/p/12318805.html