zoukankan      html  css  js  c++  java
  • Pytorch训练模型常用操作

    One-hot编码

    将标签转换为one-hot编码形式

    def to_categorical(y, num_classes):
        """One-hot encode an integer label tensor.

        :param y: integer tensor of class indices (any shape); each value is
            used as a row index into a ``num_classes x num_classes`` identity.
        :param num_classes: number of classes, i.e. length of each one-hot row.
        :return: float32 tensor of shape ``y.shape + (num_classes,)`` on the
            same device as ``y``.
        """
        # Build the identity matrix directly on y's device and index it with y.
        # This avoids the GPU -> CPU -> NumPy -> GPU round trip, and keeps the
        # result on y's own device (``.cuda()`` always targets the *current*
        # CUDA device, which is wrong when y lives on e.g. cuda:1).
        eye = torch.eye(num_classes, device=y.device)
        # .long() so int32/uint8 label tensors index like the NumPy path did.
        return eye[y.long()]
    
    • 示例
    >>> y = np.array([1,2,3])
    >>> y
    array([1, 2, 3])
    >>> torch.eye(4)[y,]
    tensor([[0., 1., 0., 0.],
            [0., 0., 1., 0.],
            [0., 0., 0., 1.]])
    
    >>> y
    array([[1, 2, 2],
           [1, 2, 3]])
    >>> torch.eye(4)[y,]
    tensor([[[0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 1., 0.]],
    
            [[0., 1., 0., 0.],
             [0., 0., 1., 0.],
             [0., 0., 0., 1.]]])
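    >>> torch.eye(4)[y]
    tensor([1., 1., 0.])

    注意:最后一例去掉了索引后的逗号。在此示例所用的 PyTorch 版本中,二维数组 y 会被当作按维度拆开的索引元组 (y[0], y[1]),即 torch.eye(4)[[1,2,2],[1,2,3]],取出的是 eye[1,1]、eye[2,2]、eye[2,3] 三个元素,而不是 one-hot 矩阵;因此 to_categorical 中 `[..., ]` 的逗号不能省略。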
    

    按层类型分别初始化权重(Conv2d / Linear 使用 Xavier 正态初始化,偏置置零)

    def weights_init(m):
        """Xavier-normal init for Conv2d/Linear layers and zero their biases.

        Intended to be passed to ``module.apply(weights_init)``. Layers are
        matched by a substring of their class name ('Conv2d' or 'Linear');
        every other module is left untouched.

        :param m: a ``torch.nn.Module`` visited by ``Module.apply``.
        """
        classname = type(m).__name__
        if 'Conv2d' in classname or 'Linear' in classname:
            # nn.init functions already run under torch.no_grad(), so the
            # parameters can be passed directly instead of via ``.data``.
            torch.nn.init.xavier_normal_(m.weight)
            # Layers built with bias=False have ``bias is None``; the original
            # ``m.bias.data`` raised AttributeError on them.
            bias = getattr(m, 'bias', None)
            if bias is not None:
                torch.nn.init.constant_(bias, 0.0)
    
    # Run weights_init on every submodule of `classifier` (Module.apply visits
    # children recursively and returns the module itself, so this rebinding
    # keeps the same object).
    classifier = classifier.apply(weights_init)
    

    检查 checkpoint,决定是否接着上次的训练继续

    # Resume from the best checkpoint if one exists; otherwise train from scratch.
    checkpoint_path = str(exp_dir) + '/checkpoints/best_model.pth'
    try:
        checkpoint = torch.load(checkpoint_path)
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except FileNotFoundError:
        # The only case the "no existing model" message actually describes.
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
    except Exception as err:
        # A checkpoint exists but could not be used (corrupt file, missing key,
        # architecture mismatch in load_state_dict, ...). Keep the best-effort
        # fallback, but log the real cause instead of hiding it; a bare
        # `except:` also swallowed KeyboardInterrupt/SystemExit.
        # NOTE(review): load_state_dict may already have copied some tensors
        # before raising, so `classifier` might be partially loaded here.
        log_string('Failed to load checkpoint %s (%r), starting training from scratch...'
                   % (checkpoint_path, err))
        start_epoch = 0
    

    按训练轮数(epoch)分阶段衰减学习率和 BatchNorm 动量

    
    def bn_momentum_adjust(m, momentum):
        """Set ``momentum`` on BatchNorm1d/BatchNorm2d modules; no-op otherwise.

        Meant to be used through ``Module.apply`` with a closure that fixes
        ``momentum``. Other BatchNorm variants (e.g. BatchNorm3d) are not touched.

        :param m: module visited by ``Module.apply``.
        :param momentum: new running-statistics momentum to assign.
        """
        bn_types = (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)
        if not isinstance(m, bn_types):
            return
        m.momentum = momentum
    
    # Step decay of the learning rate: multiply the base LR by lr_decay once
    # every step_size epochs, never going below LEARNING_RATE_CLIP.
    lr = max(
        args.learning_rate * (args.lr_decay ** (epoch // args.step_size)),
        LEARNING_RATE_CLIP,
    )
    log_string('Learning rate:%f' % lr)
    for group in optimizer.param_groups:
        group['lr'] = lr

    # Same step-decay scheme for BatchNorm momentum, floored at 0.01.
    momentum = max(
        MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP)),
        0.01,
    )
    print('BN momentum updated to: %f' % momentum)
    # apply() and train() both return the module itself, so they can be chained.
    classifier = classifier.apply(lambda module: bn_momentum_adjust(module, momentum)).train()
    
    

    批量数据维度不一致

    自定义 torch.utils.data.DataLoader(dataset, collate_fn=collate_fn) 中的 collate_fn,把同一批次中点数不同的样本随机采样到相同长度后再拼接

    def my_collate_fn(batch_data):
        """
        Align a batch of variable-length samples to a common length.

        Every sample is randomly subsampled (without replacement) down to the
        number of rows of the shortest sample in the batch, then the batch is
        stacked into tensors. Samples in the output are ordered by ascending
        original length.

        :param batch_data: list of ``(data, cls, label)`` tuples from the
            Dataset; ``data`` is indexed as ``data[rows, :]`` and ``label`` as
            ``label[rows]``, so both are presumably NumPy arrays with one row per
            point -- TODO confirm against the Dataset.
        :return: tuple ``(data_tensor, cls_tensor, label_tensor)``, all float32;
            ``data_tensor`` has ``min_len`` rows per sample.
        """
        # sorted() instead of list.sort(): same ascending-length order, but the
        # caller's list is not reordered as a side effect.
        samples = sorted(batch_data, key=lambda sample: len(sample[0]))
        min_len = len(samples[0][0])

        data_list, cls_list, label_list = [], [], []
        for sample in samples:
            data, cls, label = sample[0], sample[1], sample[2]
            # The same random row subset for data and label keeps them aligned.
            choice = np.random.choice(len(data), min_len, replace=False)
            data_list.append(data[choice, :])
            cls_list.append(cls)
            label_list.append(label[choice])

        # Convert each list to a single ndarray first: torch.tensor() on a
        # Python list of ndarrays is very slow (and warns); values are unchanged.
        data_tensor = torch.tensor(np.asarray(data_list), dtype=torch.float32)
        cls_tensor = torch.tensor(np.asarray(cls_list), dtype=torch.float32)
        label_tensor = torch.tensor(np.asarray(label_list), dtype=torch.float32)
        return data_tensor, cls_tensor, label_tensor
    

    为分割标签的各类别分配不同权重(出现次数越少的类别权重越大)

    # Count how many points carry each label: bin edges are 0, 1, ..., N_Class,
    # so bin k holds the count of label k (the last bin is closed on the right).
    class_counts = np.histogram(labels, range(N_Class + 1))[0]
    labelweights = (np.zeros(N_Class) + class_counts).astype(np.float32)
    # Normalise counts to per-class frequencies.
    labelweights = labelweights / labelweights.sum()
    # weight_k = (max_freq / freq_k) ** (1/3): rarer classes get larger weights,
    # and the cube root softens the imbalance.
    # NOTE(review): a class with zero points gives freq 0 -> inf weight (NumPy
    # divide warning); check every class occurs in `labels` before using these
    # weights in the loss.
    labelweights = (labelweights.max() / labelweights) ** (1 / 3.0)
    print(labelweights)
    
    class get_loss(nn.Module):
        """Negative log-likelihood loss with optional per-class weights.

        ``F.nll_loss`` expects log-probabilities, so ``pred`` should be e.g. the
        output of ``log_softmax``.
        """

        def __init__(self):
            super().__init__()

        def forward(self, pred, target, trans_feat, weight=None):
            """Return ``F.nll_loss(pred, target, weight=weight)``.

            ``trans_feat`` is accepted for interface compatibility but unused.
            """
            return F.nll_loss(pred, target, weight=weight)
    
    
  • 相关阅读:
    剑指offer 顺时针打印矩阵
    剑指offer队列中的最大值
    固定顶部指定div不滑动
    调整圆环统计图格式
    补插一个MUI中UI组件示例地址
    统计图左右滑动
    mui集成百度ECharts的统计图表以及清空释放图表
    页面ajax自带的访问后台时,正在加载中
    js弹出div层内容(按回退键关闭div层及遮罩)
    地图经纬度定位不准
  • 原文地址:https://www.cnblogs.com/xiaxuexiaoab/p/15409692.html
Copyright © 2011-2022 走看看