'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
    - format_time: compact human-readable duration string.
'''
import os
import shutil
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of dataset.

    Samples are loaded one at a time. For each channel, the per-image mean and
    the per-image (unbiased) std are summed and then divided by the number of
    images. The std is therefore the *average of per-image stds*, which is
    an approximation of the dataset-wide std. It is kept that way so
    previously computed normalization constants stay reproducible.

    Args:
        dataset: dataset yielding ``(input, target)`` pairs. The batched input
            is indexed as (batch, channel, H, W), so each input is assumed to
            be a (C, H, W) float tensor.
        num_workers: number of DataLoader worker processes (default 2).

    Returns:
        (mean, std): two 1-D float tensors with one entry per channel.

    Raises:
        ValueError: if the dataset yields no samples.
    '''
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=num_workers)
    mean = None
    std = None
    count = 0
    print('==> Computing mean and std..')
    for inputs, _ in dataloader:
        if mean is None:
            # Size the accumulators from the data instead of assuming 3 (RGB).
            n_channels = inputs.size(1)
            mean = torch.zeros(n_channels)
            std = torch.zeros(n_channels)
        for i in range(mean.size(0)):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
        count += 1
    if count == 0:
        raise ValueError('get_mean_and_std: dataset is empty')
    # With batch_size=1, count equals the number of images (len(dataset)).
    mean.div_(count)
    std.div_(count)
    return mean, std


def init_params(net):
    '''Initialize layer parameters of ``net`` in place.

    - Conv2d: Kaiming-normal weights (fan_out mode), zero bias.
    - BatchNorm2d: weight 1, bias 0 (skipped for affine=False layers).
    - Linear: normal weights with std 1e-3, zero bias.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` raises for multi-element tensors; test presence.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal width in columns. The previous `stty size` probe crashed at
# import time whenever stdout was not a TTY, and it never closed its pipe.
term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    '''Draw or refresh a one-line progress bar on stdout (mimics xlua.progress).

    Args:
        current: zero-based index of the current step; 0 restarts the
            total-time clock.
        total: number of steps in this bar (must be non-zero).
        msg: optional extra text appended after the timing info.

    Every step except the last ends with a carriage return, so the next
    call overwrites the line. The last step (``current == total - 1``)
    ends with a newline.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    # One write for the whole bar; a negative repeat count yields ''.
    sys.stdout.write(' [' + '=' * cur_len + '>' + '.' * rest_len + ']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    parts = [' Step: %s' % format_time(step_time),
             ' | Tot: %s' % format_time(tot_time)]
    if msg:
        parts.append(' | ' + msg)

    msg = ''.join(parts)
    sys.stdout.write(msg)
    # Pad to the right edge of the terminal (no padding if it is too narrow).
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Go back to the center of the bar with backspaces, then print the counter.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # '\r' lets the next call redraw this line; the final step terminates it.
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()


def format_time(seconds):
    '''Format a duration in seconds as a compact string such as '1h1m'.

    Units are days (D), hours (h), minutes (m), seconds (s) and
    milliseconds (ms). Zero-valued units are skipped, and only the two
    largest non-zero units are shown. A duration below 1 ms gives '0ms'.
    '''
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
             (secondsf, 's'), (millis, 'ms'))
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    return ''.join(shown) or '0ms'