zoukankan      html  css  js  c++  java
  • Some helper functions for PyTorch

      1 '''Some helper functions for PyTorch, including:
      2     - get_mean_and_std: calculate the mean and std value of dataset.
      3     - init_params: net parameter initialization.
      4     - progress_bar: progress bar mimic xlua.progress.
      5 '''
      6 import os
      7 import sys
      8 import time
      9 import math
     11 import torch.nn as nn
     12 import torch.nn.init as init
     15 def get_mean_and_std(dataset):
     16     '''Compute the mean and std value of dataset.'''
     17     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
     18     mean = torch.zeros(3)
     19     std = torch.zeros(3)
     20     print('==> Computing mean and std..')
     21     for inputs, targets in dataloader:
     22         for i in range(3):
     23             mean[i] += inputs[:,i,:,:].mean()
     24             std[i] += inputs[:,i,:,:].std()
     25     mean.div_(len(dataset))
     26     std.div_(len(dataset))
     27     return mean, std
     29 def init_params(net):
     30     '''Init layer parameters.'''
     31     for m in net.modules():
     32         if isinstance(m, nn.Conv2d):
     33             init.kaiming_normal(m.weight, mode='fan_out')
     34             if m.bias:
     35                 init.constant(m.bias, 0)
     36         elif isinstance(m, nn.BatchNorm2d):
     37             init.constant(m.weight, 1)
     38             init.constant(m.bias, 0)
     39         elif isinstance(m, nn.Linear):
     40             init.normal(m.weight, std=1e-3)
     41             if m.bias:
     42                 init.constant(m.bias, 0)
     45 _, term_width = os.popen('stty size', 'r').read().split()
     46 term_width = int(term_width)
     48 TOTAL_BAR_LENGTH = 65.
     49 last_time = time.time()
     50 begin_time = last_time
     51 def progress_bar(current, total, msg=None):
     52     global last_time, begin_time
     53     if current == 0:
     54         begin_time = time.time()  # Reset for new bar.
     56     cur_len = int(TOTAL_BAR_LENGTH*current/total)
     57     rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
     59     sys.stdout.write(' [')
     60     for i in range(cur_len):
     61         sys.stdout.write('=')
     62     sys.stdout.write('>')
     63     for i in range(rest_len):
     64         sys.stdout.write('.')
     65     sys.stdout.write(']')
     67     cur_time = time.time()
     68     step_time = cur_time - last_time
     69     last_time = cur_time
     70     tot_time = cur_time - begin_time
     72     L = []
     73     L.append('  Step: %s' % format_time(step_time))
     74     L.append(' | Tot: %s' % format_time(tot_time))
     75     if msg:
     76         L.append(' | ' + msg)
     78     msg = ''.join(L)
     79     sys.stdout.write(msg)
     80     for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
     81         sys.stdout.write(' ')
     83     # Go back to the center of the bar.
     84     for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
     85         sys.stdout.write('')
     86     sys.stdout.write(' %d/%d ' % (current+1, total))
     88     if current < total-1:
     89         sys.stdout.write('
     90     else:
     91         sys.stdout.write('
     92     sys.stdout.flush()
     94 def format_time(seconds):
     95     days = int(seconds / 3600/24)
     96     seconds = seconds - days*3600*24
     97     hours = int(seconds / 3600)
     98     seconds = seconds - hours*3600
     99     minutes = int(seconds / 60)
    100     seconds = seconds - minutes*60
    101     secondsf = int(seconds)
    102     seconds = seconds - secondsf
    103     millis = int(seconds*1000)
    105     f = ''
    106     i = 1
    107     if days > 0:
    108         f += str(days) + 'D'
    109         i += 1
    110     if hours > 0 and i <= 2:
    111         f += str(hours) + 'h'
    112         i += 1
    113     if minutes > 0 and i <= 2:
    114         f += str(minutes) + 'm'
    115         i += 1
    116     if secondsf > 0 and i <= 2:
    117         f += str(secondsf) + 's'
    118         i += 1
    119     if millis > 0 and i <= 2:
    120         f += str(millis) + 'ms'
    121         i += 1
    122     if f == '':
    123         f = '0ms'
    124     return f
  • 相关阅读:
    git merge远程合并
    入门级实操教程!从概念到部署,全方位了解K8S Ingress!
    Java 表达式之谜:为什么 index 增加了两次?
    Vavr Option:Java Optional 的另一个选项
    一文详解 Java 的八大基本类型!
    如何找到真正的 public 方法
  • 原文地址:https://www.cnblogs.com/jiangkejie/p/11201133.html
Copyright © 2011-2022 走看看