zoukankan      html  css  js  c++  java
  • torchnet+VGG16计算patch之间相似度

    torchnet+VGG16计算patch之间相似度

    本来打算使用VGG实现siamese CNN的,但是没想明白怎么使用torchnet对模型进行微调。。。所以只好把VGG的卷积层单独做一个数据预处理模块,后面跟一个网络,将两个VGG输出的结果输入该网络中,仅训练这个浅层网络。

    数据:使用了MOTChallenge数据库MOT16-02中的pedestrian

    代码:

    1. -- --------------------------------------------------------------------------------------- 
    2. -- 读取MOT16-02数据集的groundtruth,分成训练集和测试集 
    3. -- --------------------------------------------------------------------------------------- 
    4. require 'torch' 
    5. require 'cutorch' 
    6. torch.setdefaulttensortype('torch.FloatTensor'
    7. data_type = 'torch.CudaTensor' -- 设置数据类型,不适用GPU可以设置为torch.FloatTensor 
    8.  
    9. require 'image' 
    10. local datapath = '/home/zwzhou/programFiles/2DMOT2015/MOT16/train/MOT16-02/' 
    11. local tmp = image.load(datapath .. 'img1/000001.jpg',3,'byte'
    12. local width = tmp:size(3
    13. local height = tmp:size(2
    14. local num = 600 
    15. local imgs = torch.Tensor(num,3,height,width) 
    16.  
    17. local file,_ = io.open('imgs.t7'
    18. if not file then 
    19. for i=1,num do -- 读取视频帧 
    20. imgs[i]=image.load(datapath .. 'img1/' .. string.format('%06d.jpg',i)) 
    21. end 
    22. torch.save('imgs.t7',imgs) 
    23. else 
    24. imgs = torch.load('imgs.t7'
    25. end 
    26.  
    27. require'sys' 
    28. local gt_path = datapath .. 'gt/gt.txt' 
    29. local gt_info={} 
    30. local i=0 
    31. for line in io.lines(gt_path) do -- pedestrians的patch信息 
    32. local v=sys.split(line,','
    33. if tonumber(v[7]) ==1 and tonumber(v[9]) > 0.8 then -- 筛选有效的patch,是pedestrian且可见度>0.8 
    34. table.insert(gt_info,{tonumber(v[1]),tonumber(v[2]),tonumber(v[3]),tonumber(v[4]),tonumber(v[5]),tonumber(v[6])}) 
    35. -- 对应的是frame index,track index, x, y, w, h 
    36. end 
    37. end 
    38. -- 构建样本对,这里主要是为了正负样本个数相同,每个pedestrian选取25个相同id的patch,25个不同id的patch 
    39. local pairwise={} 
    40. for i=1,#gt_info do 
    41. local count=0 
    42. local iter=0 
    43. repeat  
    44. local j=torch.ceil(torch.rand(1)*(#gt_info))[1
    45. if gt_info[i][2] == gt_info[j][2] then  
    46. count=count+1 
    47. table.insert(pairwise,{i,j}) 
    48. end 
    49. iter=iter+1 
    50. until(count >25 or iter>100
    51. repeat  
    52. local j=torch.ceil(torch.rand(1)*#gt_info)[1
    53. if gt_info[i][2] ~= gt_info[j][2] then  
    54. count=count-1 
    55. table.insert(pairwise,{i,j}) 
    56. end 
    57. until(count <0
    58. end 
    59.  
    60. local function cast(x) return x:type(data_type) end -- 类型转换 
    61.  
    62. -- 加载pretrained VGG16 model 
    63. require 'nn' 
    64. require 'loadcaffe' 
    65. local function getPretrainedModel() 
    66. local proto = '/home/zwzhou/modelZoo/VGG_ILSVRC_16_layers_deploy.prototxt' 
    67. local caffemodel = '/home/zwzhou/modelZoo/VGG_ILSVRC_16_layers.caffemodel' 
    68. local VGG16 = loadcaffe.load(proto,caffemodel,'nn'
    69. for i = 1,3 do 
    70. VGG16.modules[#VGG16.modules]=nil 
    71. end 
    72. return VGG16 
    73. end 
    74.  
    75. -- 为了能够使用VGG,需要定义一些预处理方法 
    76. local loadSize = {3,256,256
    77. local sampleSize={3,224,224
    78.  
    79. local function adjustScale(input) -- VGG需要先将输入图片的最小边缩放到256,另一边保持纵横比 
    80. if input:size(3) < input:size(2) then 
    81. input = image.scale(input,loadSize[2],loadSize[3]*input:size(2)/input:size(3)) 
    82. else 
    83. input = image.scale(input,loadSize[2]*input:size(3)/input:size(2),loadSize[3]) 
    84. end 
    85. return input 
    86. end 
    87.  
    88. local bgr_means = {103.939,116.779,123.68} -- VGG使用的均值,注意是BGR通道,image.load()获得的是rgb 
    89. local function vggProcessing(img) 
    90. local img2 = img:clone() -- 深度拷贝 
    91. img2[{{1}}] = img[{{3}}] 
    92. img2[{{3}}] = img[{{1}}] -- rgb -> bgr 
    93. img2=img2:mul(255
    94. for i=1,3 do 
    95. img2[i]:add(-bgr_means[i]) 
    96. end 
    97. return img2 
    98. end 
    99.  
    100. local function centerCrop(input) -- 截取224*224大小 
    101. local oH = sampleSize[2
    102. local oW = sampleSize[3
    103. local iW = input:size(3
    104. local iH = input:size(2
    105. local w1 = math.ceil((iW-oW)/2
    106. local h1 = math.ceil((iH-oH)/2
    107. local out = image.crop(input,w1,h1,w1+oW,h1+oH) 
    108. return out 
    109. end 
    110.  
    111. local file,_ = io.open('vgg_info.t7'
    112. local vgg_info={} 
    113. if not file then 
    114. local VGG16_model = getPretrainedModel() 
    115. if data_type:match'torch.Cuda.*Tensor' then 
    116. require 'cudnn' 
    117. require 'cunn' 
    118. cudnn.convert(VGG16_model,cudnn):cuda() 
    119. cudnn.benchmark = true 
    120. end 
    121. cast(VGG16_model) 
    122. for i=1, #gt_info do 
    123. local idx=gt_info[i] 
    124. local img = imgs[idx[1]] 
    125. local x1 = math.max(idx[3],1
    126. local y1 = math.max(idx[4],1
    127. local x2 = math.min(idx[3]+idx[5],width) 
    128. local y2 = math.min(idx[4]+idx[6],height) 
    129. local patch = image.crop(img,x1,y1,x2,y2) 
    130. patch = adjustScale(patch) 
    131. patch = vggProcessing(patch) 
    132. patch = centerCrop(patch) 
    133. patch=cast(patch) 
    134. table.insert(vgg_info,VGG16_model:forward(patch):float()) 
    135. end 
    136. torch.save('vgg_info.t7',vgg_info) 
    137. else  
    138. vgg_info=torch.load('vgg_info.t7'
    139. end 
    140.  
    141. local function getPatchPair(tmp) -- 获得patch 对 
    142. local pp = {} 
    143. pp[1] = vgg_info[tmp[1]] 
    144. pp[2] = vgg_info[tmp[2]] 
    145. local t=torch.cat(pp[1],pp[2],1
    146. return
    147. end 
    148.  
    149. -- 定义datasetiterator 
    150. local tnt=require'torchnet' 
    151. local function getIterator(mode) 
    152. -- 创建model 
    153. local fc = nn.Sequential() 
    154. fc:add(nn.View(-1,4096*2)) 
    155. fc:add(nn.Linear(4096*2,500)) 
    156. fc:add(nn.ReLU(true)) 
    157. fc:add(nn.Normalize(2)) 
    158. fc:add(nn.Linear(500,500)) 
    159. fc:add(nn.ReLU(true)) 
    160. fc:add(nn.Linear(500,1)) 
    161.  
    162. -- print(fc:forward(torch.randn(2,4096*2))) 
    163. if data_type:match'torch.Cuda.*Tensor' then 
    164. require 'cudnn' 
    165. require 'cunn' 
    166. cudnn.convert(fc,cudnn):cuda() 
    167. cudnn.benchmark = true 
    168. end 
    169. cast(fc) 
    170.  
    171. -- 构建训练引擎,使用OptimEngine 
    172. require 'optim' 
    173. local engine = tnt.OptimEngine() 
    174. local criterion = cast(nn.MarginCriterion()) 
    175.  
    176. -- 创建一些评估值 
    177. local train_timer = torch.Timer() 
    178. local test_timer = torch.Timer() 
    179. local data_timer = torch.Timer() 
    180.  
    181. local meter = tnt.AverageValueMeter() -- 用于统计评估函数的输出 
    182. local confusion = optim.ConfusionMatrix(2) -- 2类混淆矩阵 
    183. local data_time_meter = tnt.AverageValueMeter() 
    184. -- log 
    185. local logtext=require 'torchnet.log.view.text' 
    186. log = tnt.Log{ 
    187. keys = {'train_loss','train_acc','data_loading_time','epoch','test_acc','train_time','test_time'}, 
    188. onFlush={ 
    189. logtext{keys={'train_loss','train_acc','data_loading_time','epoch','test_acc','train_time','test_time'}} 


    190.  
    191. local inputs = cast(torch.Tensor()) 
    192. local targets = cast(torch.Tensor()) 
    193.  
    194. -- 填一些hook函数,以便观察训练过程 
    195. engine.hooks.onSample = function(state) 
    196. if state.training then 
    197. data_time_meter:add(data_timer:time().real) 
    198. end 
    199. inputs:resize(state.sample.input:size()):copy(state.sample.input) 
    200. targets:resize(state.sample.target:size()):copy(state.sample.target) 
    201. state.sample.input = inputs 
    202. state.sample.target = targets 
    203. end 
    204.  
    205. engine.hooks.onForwardCriterion = function(state) 
    206. meter:add(state.criterion.output) 
    207. confusion:batchAdd(state.network.output:gt(0):add(1),state.sample.target:gt(0):add(1)) 
    208. end 
    209.  
    210. local function test() -- 用于测试 
    211. engine:test{ 
    212. network = fc, 
    213. iterator = getIterator('test'), 
    214. criterion=criterion,  

    215. confusion:updateValids() 
    216. end 
    217.  
    218. engine.hooks.onStartEpoch = function(state) 
    219. local epoch = state.epoch + 1 
    220. print('===>' .. ' online epoch # ' .. epoch .. '[batchsize = 256]'
    221. meter:reset() 
    222. confusion:zero() 
    223. train_timer:reset() 
    224. data_time_meter:reset() 
    225. end 
    226.  
    227. engine.hooks.onEndEpoch = function(state) 
    228. local train_loss = meter:value() 
    229. confusion:updateValids() 
    230. local train_acc = confusion.totalValid*100 
    231. local train_time = train_timer:time().real 
    232. meter:reset() 
    233. print(confusion) 
    234. confusion:zero() 
    235. test_timer:reset() 
    236.  
    237. local cache = state.params:clone() -- 保存现场 
    238. --state.params:copy(state.optim.ax) 
    239. test() 
    240. --state.params:copy(cache) -- 恢复现场 
    241.  
    242. log:set{ 
    243. train_loss = train_loss, 
    244. train_acc = train_acc, 
    245. data_loading_time = data_time_meter:value(), 
    246. epoch = state.epoch, 
    247. test_acc = confusion.totalValid*100
    248. train_time = train_time, 
    249. test_time = test_timer:time().real, 

    250. log:flush() 
    251. end 
    252.  
    253. engine.hooks.onUpdate = function(state) 
    254. data_timer:reset() 
    255. end 
    256.  
    257. engine:train{ 
    258. network = fc, 
    259. criterion = criterion, 
    260. iterator = getIterator('train'), 
    261. optimMethod = optim.sgd, 
    262. config = {learningRate = 0.05
    263. --weightDecay = 0.05, 
    264. momentum = 0.9
    265. --t0 = 1e+4, 
    266. --eta0 =0.1 
    267. }, 
    268. maxepoch = 30,  

    269.  
    270. -- 保存模型 
    271. local modelpath = 'SiaVGG16_model.t7' 
    272. print('Saving to ' .. modelpath) 
    273. torch.save(modelpath,fc:float():clearState()) 
    274. --]] 

    输出:

    enter description here

    1493386765674.jpg

    enter description here

    发现网络太容易过拟合,主要一方面是数据太少,另一方面是视频中就那么几个人,所以patch之间的相关性太大,对网络提供的信息太少。所以使用更多的数据测试结果应该会好许多。

    这个代码主要是为了熟悉torchnet package,感受呢,

    1. 对于数据的预处理,确实方便多了

    2. 如果使用提供的Engine,虽然训练过程简单了但是也太模块化了,比如某些层的微调,比如每层设置不同的学习率

    3. 使用Iterator时,尤其要小心

  • 相关阅读:
    【欧拉质数筛选法 模版】
    【归并排序 逆序对 模版】
    【 lca倍增模板】
    【LSGDOJ 1333】任务安排 dp
    【NOIP2013】火柴排队
    【USACO Feb 2014】Cow Decathlon
    【USACO08NOV】奶牛混合起来Mixed Up Cows
    【LSGDOJ 1351】关灯
    【USACO】干草金字塔
    【USACO】电子游戏 有条件的背包
  • 原文地址:https://www.cnblogs.com/YiXiaoZhou/p/6783448.html
Copyright © 2011-2022 走看看