zoukankan      html  css  js  c++  java
  • pytorch-torch参数使用 1.torch.cat(维度串接) 2. torch.backend.cudnn.benchmark(加速优化计算)

    1. torch.cat(data, axis) # data表示输入的数据, axis表示进行串接的维度

    t = Test()
    t.num = 50
    print(t.num)
    
    
    a = torch.tensor([[1, 1]])
    b = torch.tensor([[2, 2]])
    x = []
    x.append(a) # 维度是[1, 1, 2]
    x.append(b) # 维度是[2, 1, 2]
    
    c = torch.cat(x, 0) # 将维度进行串接
    print(c.data.numpy().shape)

    2. torch.backend.cudnn.benchmark (进行优化加速) 如果每次输入都是相同的时候,因为需要搜索计算卷积的最佳方式 ,所以在保证维度不变的情况下,可以持续使用最优的计算方法 

      if opt.preprocess != 'scale_width':  # 如果是规则输入的话,最后的输入值数量可能低于一个batch_size 
                torch.backends.cudnn.benckmark = True

    3. torch.nn.DataParallel (使用多块GPU进行网络的训练)

      if len(gpu_ids) > 0:
            assert(torch.cuda.is_available())
            net.to(gpu_ids[0])
            net = torch.nn.DataParallel(net, gpu_ids)  #gpu_id = [0, 1, 2, 3]
  • 相关阅读:
    hdu2089 不要62
    hdu4734 F(x)
    hdu3555 Bomb
    hdu3652 B-number
    hdu4352 XHXJ's LIS
    CodeForces 55D Beautiful numbers
    数位dp模板
    欧拉函数模板
    UVALive
    常用正则表达 (转)
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/12750581.html
Copyright © 2011-2022 走看看