早就怀疑过MXNet的pack,但是那次验证(把图片打出来对比)出来的结论是没问题。
结果这次就吊在这上面了:
问题
- 先用im2rec.py打包图片;
- 然后继承iamge.imageIter,进行改造,rec读取和解码没有修改;
结果发现网络的输出不时就会出现异常值(其中一个迭代器输出Mask),最后定位到迭代器里面:
# 继承自 image.imageIter
# 此迭代器 输出 Mask Map !
# next() 节选
try:
while i < batch_size:
label, s = self.next_sample()
data = [imdecode(s)] # POINT B
if len(data[0].shape) == 0:
# logging.debug('Invalid image, skipping.')
continue
for aug in self.auglist:
data = [ret for src in data for ret in aug(src)]
for d in data:
assert i < batch_size, 'Batch size must be multiples of augmenter output length'
dTmp=nd.transpose(d, axes=(2, 0, 1))
if self.data_shape[0] == 1: # only keep channel 0
dTmp=dTmp[0]/255 # 此处是 BUG 相关点 POINT A
batch_data[i][:] = dTmp.reshape((1,)+dTmp.shape)
# batch_label[i][:] = label
i += 1
使用以上这段程序,检查迭代器输出(经数值反变换)来的图没问题,但打印网络的输出就发现明显异常。
如果把POINT A的程序换为:
# 替换 POINT A 处
dTmp = dTmp[0]
dTmp[dTmp>0]= 1
网络输出值没有异常,但迭代器输出图像就会出现大面积阴影。
原因
问题就出在解码上,在 POINT B处检车一下,就会有见鬼的发现:
# 放置于 POINT B 后
if config.debug.it_check:
dTmp=data[0].asnumpy()
logging.debug('num in (0,255):%d'%( ( (dTmp>0)*(dTmp<255) ).sum() ) ) # dTmp 应当是 Mask
混乱了。。。orz
Solution
快要爆发的时候,顺手把图打印出来了,发现波动的值大致在250左右,于是改了下POINT A处:
# 替换 POINT A 处
dTmp=dTmp[0]
dTmp[dTmp>250]=1
dTmp[dTmp<250]=0
视觉上,和网络输出正常化。
勉强用着吧。
Followup
这样的解码似乎不太可靠,后续要考虑替代方案。