zoukankan      html  css  js  c++  java
  • NiftyNet 项目了解

    1. NiftyNet项目概述

    NiftyNet项目对tensorflow进行了比较好的封装,实现了一整套的DeepLearning流程。将数据加载、模型加载,网络结构定义等进行了很好的分离,抽象封装成了各自独立的模块。虽然抽象的概念比较多,使得整个项目更为复杂,但是整体结构清晰,支持的模块多。可扩展性还没有进行试验,暂时不是很清楚。 该项目能够实现:

    1. 图像分割
    2. 图像分类
    3. gan
    4. Autoencoder
    5. 回归

    项目支持医学图像的读取,提供的读取器有:

    1. nibabel 支持.nii医学文件格式
    2. simpleitk 支持.dcm和.mhd格式的医疗图像
    3. opencv 支持.jpg等常见图像读取,读取后通道顺序为BGR
    4. skimage 支持.jpg等常见图像读取
    5. pillow 支持.jpg等常见图像读取

    在使用中遇到了一些问题,其训练的速度非常慢。最开始单个iter的平均训练时间估计在40秒以上,有的iter时间会有200秒。现在主要在查找性能瓶颈。

    一、       项目结构

    niftynet.engine.application_driver(ApplicationDriver)定义并驱动着整个Application的生命周期,将配置数据进行解析后,实例化Application并启动流程。

    i.              Application

    Application 作为核心概念,承担整个train或inference的主要功能。所有Application继承于niftynet.application.base_application(简称为BaseApplication)。BaseApplication使用单例模式。

    在Application类中,构建了Tensorflow的图结构和创建Session用于驱动计算。
    BaseApplication单例模式的具体实现有一点小问题。

    Application所完成的工作具体可以划分成以下4个环节

    1. 输入数据相关 数据加载,数据增强,数据取样等,抽象在这两个接口中在SegmentationApplication中,sampler支持:uniform, weighted, resized, balanced4种方式

          initialise_dataset_loader()
    initialise_sampler()

    1. 网络结构相关 网络结构的定义,参数的管理,自定义操作等,抽象在此接口中

          initialise_network()

    1. 模型共享相关 完成由网络的输入到网络的输出,计算loss、gradient,创建optimizer等,抽象在此接口中

          connect_data_and_network()

    1. 输出解码相关 inference将网络输出解码操作,抽象在此接口中

          interpret_output()

    ii.              Config

    配置文件需要必须包含的模块:

    • [SYSTEM]
    • [NETWORK]
    • 如果action为train,那么config中需要包含[TRAINING]模块
    • 如果action为inference,那么config中需要包含[INFERENCE]模块
    • 额外的,根据特定的application,会需要包含指定名称的模块。如:

    –             [GAN]

    –             [SEGMENTATION]

    –             [REGRESSION]

    –             [AUTOENCODER]

    • 除了以上的配置外,其他的数据会处理为input data source specifications【数据声明模块】
    l  数据声明模块

    Name

    解释

    例子

    默认值

    csv_file

    包含输入图像文件的列表

    csvfile=filelist.csv

    ''

    pathtosearch

    如果没有配置csv_file,则从此路径下去搜索输入图像

    pathtosearch=~/ct_data

    NiftyNet home folder

    filename_contains

    搜索输入图像时用于匹配的关键词

    filename_contains=foo, bar

    ''

    filenamenotcontains

    搜索输入图像时用于排除的关键词

    filenamenotcontains=ti, s1

    ''

    filename_removefromid

    正则表达式,用于从输入图像的文件名中,解析出id

    filename_removefromid=foo

    ''

    interp_order

    插值法

    interp_order=1

    3

    pixdim

    如果指定了,输入的3D图像会重新采样到指定大小再送入网络

    pixdim=1.2, 1.2, 1.2

    ''

    axcodes

    如果指定了,输入的3D图像会重新设定到指定的axcodes顺序再送入网络 参考文章

    axcodes=L, P, S

    ''

    spatialwindowsize

    3个整数,指定输入window的大小[能被8整除]

    spatialwindowsize=64, 64, 64

    ''

    loader

    指定图像读取loader类型

    loder=simpleitk

    None

    [interp_order]  当设定采样方法为resize时,需要这个参数对图片上采样或下采样 1表示双线性插值
    0表示最近邻插值
    3表示三次样条插值

    l  [SYSTEM]

    Name

    解释

    例子

    默认值

    cude_devices

    指定GPU

    cuda_devices=0,1

    ''

    num_threads

    预处理线程的数量

    num_threads=8

    2

    num_gpus

    训练时使用GPU数量

    num_gpus=2

    1

    model_dir

    保存或读取模型权重和Log的位置

    model_dir=~/niftynet/xxx

    config文件所在目录

    datasetsplitfile

    用于将数据划分成training/validation/inferenct字集

    datasetsplitfile=~/nifnet/xxx

    ./datasetsplitfile.csv

    event_handler

    注册事件处理

    eventhandler=modelrestorer

    modelsaver, modelrestorer, samplerthreading, applygradients, outputinterpreter, consolelogger, tensorboard_logger

    l  [NETWORK]

    Names

    解释

    例子

    默认值

    name

    所使用的网络结构

    name=niftynet.network.toynet.ToyNet

    ‘’

    activation_function

    设置网络中使用的激活函数

    activation_function=prelu

    Relu

    batch_size

    批大小

    batch_size=10

    2

    smaller_final_batch_mode

    当总数据量不能被batch_size整除时,最后一个batch_size的方式

    smaller_final_batch_mode=drop

    smaller_final_batch_mode=pad

    smaller_final_batch_mode=dynamic

    pad

    decay

    正则化参数

    decay=1e-5

    0.0

    reg_type

    正则化类型

    reg_type=L1

    L2

    volume_padding_size

    volume_padding_size=4, 4, 4

    0, 0, 0

    volume_padding_mode

    volume_padding_mode=symmetric

    minimum

    window_sampling

    采样的类型

    window_sampling=uniform

    固定尺寸,相同的概率分布

    window_sampling=weighted

    固定尺寸,根据intensity作为概率分布

    window_sampling=balanced

    固定尺寸,每个label拥有相同采样概率

    window_sampling=resize

    缩放图像到window尺寸

    uniform

    queue_length

    采样时使用的buffer大小

    queue_length=10

    5

    keep_prob

    如果网络中使用了dropout

    keep_prob=0.2

    1.0

    l  [TRAINING]

    Name

    解释

    例子

    默认值

    optimizer

    优化器类型

    optimizer=momentum

    adam

    sample_per_volume

    每个输入图像采样的次数

    sample_per_volume=5

    1

    lr

    学习率

    lr=0.0001

    0.1

    loss_type

    loss计算方式

    loss_type=CrossEntropy

    Dice

    starting_iter

    启动的iter

    starting_iter=0

    0

    save_every_n

    保存的间隔

    save_every_n=50

    500

    tensorboard_every_n

    tensorboard记录的间隔

    tensorboard_every_n=50

    20

    max_iter

    最大iter数

    max_iter=3000

    10000

    max_checkpoints

    保存的最多checkpoint数

    max_checkpoints=5

    100

    训练时验证

    validation_every_n

    训练时进行验证的间隔

    validation_every_n=10

    -1

    validation_max_iter

    验证时iter的数量

    validation_max_iter=5

    1

    exclude_fraction_for_validation

    验证集的比重

    exclude_fraction_for_validation=0.2

    0.0

    exclude_fraction_for_inference

    测试集的比重

    exclude_fraction_for_inference=0.1

    0.0

    数据增强

    rotation_angle

    旋转

    rotation_angle=-10.0, 10.0

    ‘’

    scaling_percentage

    缩放

    scaling_percentage=-20.0, 20.0

    ‘’

    random_flipping_axes

    翻转

    random_flipping_axes=1,2

    -1

    l  [INFERENCE]

    Name

    解释

    例子

    默认值

    spatial_window_size

    网络输入尺寸大小

    spatial_window_size=64,64,64

    ‘’

    border

    输入尺寸的边框

    border=5,5,5

    0,0,0

    inference_iter

    使用指定iter保存的权重文件

    inference_iter=1000

    -1

    save_seg_dir

    保存输出路径

    save_seg_dir=output/test

    output

    output_postfix

    输出保存的后缀

    output_postfix=_output

    _niftynet_out

    output_interp_order

    插值法

    output_interp_order=0

    0

    dataset_to_infer

    使用的数据集,可选:’all’, ‘training’, ‘validation’, ‘inference’

    dataset_to_infer=all

    ‘’

    iii.              Reader & Dataset

    n  niftynet.io.image_reader模块
    ImageReader的主要作用是,遍历一组目录,搜索并返回一个图像的列表,以及使用iterative的方式将数据加载到内存中。
    ImageReader会创建一个tf.data.Dataset的对象,这样使得模块可以很方便地接入到基于tensorflow的程序中。
    ImageReader的特点:

    l  设计用于支持医疗图像数据的格式

    l  支持多模态输入数据

    l  支持tf.data.Dataset

    n  niftynet.contrib.dataset_sampler
    sampler将 image reader作为输入,从每张图像中采取出结果输出。
    在很多的医学图像处理的情况中,由于GPU显存的限制以及训练效率等的考虑,网络结构会对图像的部分进行处理而非整张图像。

    iv.              Network

    项目中包含了一些已经实现的网络:

    1. GAN:

    –             simulator_gan

    –             siple_gan

    1. Segmentation:

    –             highres3dnet, highres3dnetsmall, highres3dnetlarge

    –             toynet

    –             unet

    –             vnet

    –             dense_vnet

    –             deepmedic

    –             scalenet

    –             holisticnet

    –             unet_2d

    1. classification:

    –             resnet

    –             se_resnet

    1. autoencoder:

    –             vae

    v.              Loss

    已提供支持的loss计算方式

    1. Segmentation
      1. CrossEntropy
      2. CrossEntropy_Dense
      3. Dice
      4. Dice_NS
      5. Dice_Dense
      6. Dice_Dense_NS
      7. Tversky
      8. GDSC
      9. WGDL
      10. SensSpec
      11. Gan
        1. CrossEntropy
        2. Regression
          1. L1Loss
          2. L2Loss
          3. RMSE
          4. MAE
          5. Huber
          6. Classification
            1. CrossEntropy
            2. AutoEncoder
              1. VariationalLowerBound

    支持的优化器类型

    1. adam
    2. gradientdescent
    3. momentum
    4. nesterov
    5. adagrad
    6. rmsprop

    vi.              Event机制

    NiftyNet项目的设计,使用了Signal和event handler模式,具体实现使用了blinker库。这样可以方便地将模型保存,tensorboard记录等操作进行配置。

    目前可供注册的signal有:

    1. GRAPH_CREATED
    2. SESS_STARTED
    3. SESS_FINISHED
    4. ITER_STARTED
    5. ITER_FINISHED

    信号处理函数注册到对应的信号后,由引擎负责调用。

    vii.              Layer

    网络层的相关设计都封装在Layer类中,可继承layer类,实现定制化结构

  • 相关阅读:
    UILabel 设置字体间的距离 和 行与行间的距离
    IB_DESIGNABLE 和 IBInspectable 的使用
    干货博客
    GitHub克隆速度太慢解决方案
    实时(RTC)时钟,系统时钟和CPU时钟
    折腾了好久的vscode配置c/c++语言环境(Windows环境下)
    c语言中的malloc函数
    记录一下关于在工具类中更新UI使用RunOnUiThread犯的极其愚蠢的错误
    记录关于Android多线程的一个坑
    Android中限制输入框最大输入长度
  • 原文地址:https://www.cnblogs.com/bicker/p/9753305.html
Copyright © 2011-2022 走看看