1. NiftyNet项目概述
NiftyNet项目对tensorflow进行了比较好的封装,实现了一整套的DeepLearning流程。将数据加载、模型加载,网络结构定义等进行了很好的分离,抽象封装成了各自独立的模块。虽然抽象的概念比较多,使得整个项目更为复杂,但是整体结构清晰,支持的模块多。可扩展性还没有进行试验,暂时不是很清楚。 该项目能够实现:
- 图像分割
- 图像分类
- gan
- Autoencoder
- 回归
项目支持医学图像的读取,提供的读取器有:
- nibabel 支持.nii医学文件格式
- simpleitk 支持.dcm和.mhd格式的医疗图像
- opencv 支持.jpg等常见图像读取,读取后通道顺序为BGR
- skimage 支持.jpg等常见图像读取
- pillow 支持.jpg等常见图像读取
在使用中遇到了一些问题,其训练的速度非常慢。最开始单个iter的平均训练时间估计在40秒以上,有的iter时间会有200秒。现在主要在查找性能瓶颈。
一、 项目结构
niftynet.engine.application_driver(ApplicationDriver)定义并驱动着整个Application的生命周期,将配置数据进行解析后,实例化Application并启动流程。
i. Application
Application 作为核心概念,承担整个train或inference的主要功能。所有Application继承于niftynet.application.base_application(简称为BaseApplication)。BaseApplication使用单例模式。
在Application类中,构建了Tensorflow的图结构和创建Session用于驱动计算。
BaseApplication单例模式的具体实现有一点小问题。
Application所完成的工作具体可以划分成以下4个环节
- 输入数据相关 数据加载,数据增强,数据取样等,抽象在这两个接口中在SegmentationApplication中,sampler支持:uniform, weighted, resized, balanced4种方式
initialise_dataset_loader()
initialise_sampler()
- 网络结构相关 网络结构的定义,参数的管理,自定义操作等,抽象在此接口中
initialise_network()
- 模型共享相关 完成由网络的输入到网络的输出,计算loss、gradient,创建optimizer等,抽象在此接口中
connect_data_and_network()
- 输出解码相关 inference将网络输出解码操作,抽象在此接口中
interpret_output()
ii. Config
配置文件需要必须包含的模块:
- [SYSTEM]
- [NETWORK]
- 如果action为train,那么config中需要包含[TRAINING]模块
- 如果action为inference,那么config中需要包含[INFERENCE]模块
- 额外的,根据特定的application,会需要包含指定名称的模块。如:
– [GAN]
– [SEGMENTATION]
– [REGRESSION]
– [AUTOENCODER]
- 除了以上的配置外,其他的数据会处理为input data source specifications【数据声明模块】
l 数据声明模块
Name |
解释 |
例子 |
默认值 |
csv_file |
包含输入图像文件的列表 |
csvfile=filelist.csv |
'' |
pathtosearch |
如果没有配置csv_file,则从此路径下去搜索输入图像 |
pathtosearch=~/ct_data |
NiftyNet home folder |
filename_contains |
搜索输入图像时用于匹配的关键词 |
filename_contains=foo, bar |
'' |
filenamenotcontains |
搜索输入图像时用于排除的关键词 |
filenamenotcontains=ti, s1 |
'' |
filename_removefromid |
正则表达式,用于从输入图像的文件名中,解析出id |
filename_removefromid=foo |
'' |
interp_order |
插值法 |
interp_order=1 |
3 |
pixdim |
如果指定了,输入的3D图像会重新采样到指定大小再送入网络 |
pixdim=1.2, 1.2, 1.2 |
'' |
axcodes |
如果指定了,输入的3D图像会重新设定到指定的axcodes顺序再送入网络 参考文章 |
axcodes=L, P, S |
'' |
spatialwindowsize |
3个整数,指定输入window的大小[能被8整除] |
spatialwindowsize=64, 64, 64 |
'' |
loader |
指定图像读取loader类型 |
loder=simpleitk |
None |
[interp_order] 当设定采样方法为resize时,需要这个参数对图片上采样或下采样 1表示双线性插值
0表示最近邻插值
3表示三次样条插值
l [SYSTEM]
Name |
解释 |
例子 |
默认值 |
cude_devices |
指定GPU |
cuda_devices=0,1 |
'' |
num_threads |
预处理线程的数量 |
num_threads=8 |
2 |
num_gpus |
训练时使用GPU数量 |
num_gpus=2 |
1 |
model_dir |
保存或读取模型权重和Log的位置 |
model_dir=~/niftynet/xxx |
config文件所在目录 |
datasetsplitfile |
用于将数据划分成training/validation/inferenct字集 |
datasetsplitfile=~/nifnet/xxx |
./datasetsplitfile.csv |
event_handler |
注册事件处理 |
eventhandler=modelrestorer |
modelsaver, modelrestorer, samplerthreading, applygradients, outputinterpreter, consolelogger, tensorboard_logger |
l [NETWORK]
Names |
解释 |
例子 |
默认值 |
name |
所使用的网络结构 |
name=niftynet.network.toynet.ToyNet |
‘’ |
activation_function |
设置网络中使用的激活函数 |
activation_function=prelu |
Relu |
batch_size |
批大小 |
batch_size=10 |
2 |
smaller_final_batch_mode |
当总数据量不能被batch_size整除时,最后一个batch_size的方式 |
smaller_final_batch_mode=drop smaller_final_batch_mode=pad smaller_final_batch_mode=dynamic |
pad |
decay |
正则化参数 |
decay=1e-5 |
0.0 |
reg_type |
正则化类型 |
reg_type=L1 |
L2 |
volume_padding_size |
volume_padding_size=4, 4, 4 |
0, 0, 0 |
|
volume_padding_mode |
volume_padding_mode=symmetric |
minimum |
|
window_sampling |
采样的类型 |
window_sampling=uniform 固定尺寸,相同的概率分布 window_sampling=weighted 固定尺寸,根据intensity作为概率分布 window_sampling=balanced 固定尺寸,每个label拥有相同采样概率 window_sampling=resize 缩放图像到window尺寸 |
uniform |
queue_length |
采样时使用的buffer大小 |
queue_length=10 |
5 |
keep_prob |
如果网络中使用了dropout |
keep_prob=0.2 |
1.0 |
l [TRAINING]
Name |
解释 |
例子 |
默认值 |
optimizer |
优化器类型 |
optimizer=momentum |
adam |
sample_per_volume |
每个输入图像采样的次数 |
sample_per_volume=5 |
1 |
lr |
学习率 |
lr=0.0001 |
0.1 |
loss_type |
loss计算方式 |
loss_type=CrossEntropy |
Dice |
starting_iter |
启动的iter |
starting_iter=0 |
0 |
save_every_n |
保存的间隔 |
save_every_n=50 |
500 |
tensorboard_every_n |
tensorboard记录的间隔 |
tensorboard_every_n=50 |
20 |
max_iter |
最大iter数 |
max_iter=3000 |
10000 |
max_checkpoints |
保存的最多checkpoint数 |
max_checkpoints=5 |
100 |
训练时验证
validation_every_n |
训练时进行验证的间隔 |
validation_every_n=10 |
-1 |
validation_max_iter |
验证时iter的数量 |
validation_max_iter=5 |
1 |
exclude_fraction_for_validation |
验证集的比重 |
exclude_fraction_for_validation=0.2 |
0.0 |
exclude_fraction_for_inference |
测试集的比重 |
exclude_fraction_for_inference=0.1 |
0.0 |
数据增强
rotation_angle |
旋转 |
rotation_angle=-10.0, 10.0 |
‘’ |
scaling_percentage |
缩放 |
scaling_percentage=-20.0, 20.0 |
‘’ |
random_flipping_axes |
翻转 |
random_flipping_axes=1,2 |
-1 |
l [INFERENCE]
Name |
解释 |
例子 |
默认值 |
spatial_window_size |
网络输入尺寸大小 |
spatial_window_size=64,64,64 |
‘’ |
border |
输入尺寸的边框 |
border=5,5,5 |
0,0,0 |
inference_iter |
使用指定iter保存的权重文件 |
inference_iter=1000 |
-1 |
save_seg_dir |
保存输出路径 |
save_seg_dir=output/test |
output |
output_postfix |
输出保存的后缀 |
output_postfix=_output |
_niftynet_out |
output_interp_order |
插值法 |
output_interp_order=0 |
0 |
dataset_to_infer |
使用的数据集,可选:’all’, ‘training’, ‘validation’, ‘inference’ |
dataset_to_infer=all |
‘’ |
iii. Reader & Dataset
n niftynet.io.image_reader模块
ImageReader的主要作用是,遍历一组目录,搜索并返回一个图像的列表,以及使用iterative的方式将数据加载到内存中。
ImageReader会创建一个tf.data.Dataset的对象,这样使得模块可以很方便地接入到基于tensorflow的程序中。
ImageReader的特点:
l 设计用于支持医疗图像数据的格式
l 支持多模态输入数据
l 支持tf.data.Dataset
n niftynet.contrib.dataset_sampler
sampler将 image
reader作为输入,从每张图像中采取出结果输出。
在很多的医学图像处理的情况中,由于GPU显存的限制以及训练效率等的考虑,网络结构会对图像的部分进行处理而非整张图像。
iv. Network
项目中包含了一些已经实现的网络:
- GAN:
– simulator_gan
– siple_gan
- Segmentation:
– highres3dnet, highres3dnetsmall, highres3dnetlarge
– toynet
– unet
– vnet
– dense_vnet
– deepmedic
– scalenet
– holisticnet
– unet_2d
- classification:
– resnet
– se_resnet
- autoencoder:
– vae
v. Loss
已提供支持的loss计算方式
- Segmentation
- CrossEntropy
- CrossEntropy_Dense
- Dice
- Dice_NS
- Dice_Dense
- Dice_Dense_NS
- Tversky
- GDSC
- WGDL
- SensSpec
- Gan
- CrossEntropy
- Regression
- L1Loss
- L2Loss
- RMSE
- MAE
- Huber
- Classification
- CrossEntropy
- AutoEncoder
- VariationalLowerBound
支持的优化器类型
- adam
- gradientdescent
- momentum
- nesterov
- adagrad
- rmsprop
vi. Event机制
NiftyNet项目的设计,使用了Signal和event handler模式,具体实现使用了blinker库。这样可以方便地将模型保存,tensorboard记录等操作进行配置。
目前可供注册的signal有:
- GRAPH_CREATED
- SESS_STARTED
- SESS_FINISHED
- ITER_STARTED
- ITER_FINISHED
信号处理函数注册到对应的信号后,由引擎负责调用。
vii. Layer
网络层的相关设计都封装在Layer类中,可继承layer类,实现定制化结构