zoukankan      html  css  js  c++  java
  • DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network(理解)(github代码)

    (代码托管在我的Github上,如果有帮助记得点星星嗨!)

    0 - 背景

    0.0 - 概要

    生存预测模型探索的是患者的各个属性/特征与治疗效果之间的关系。之前的生存预测模型,像linear Cox proportional hazards model需要有专业的医学知识作为专业背景来构建特征工程,而另外的一些nonlinear survival methods,像neural networks/survival forests则没有在有效的推荐系统中得到实践证明。文中提出一种Cox proportional hazards deep neural network的生存模型DeepSurv,并且在模拟数据集及临床数据集上进行实验,证明了模型具有可比的或者最好的性能,此外还将该DeepSurv应用到治疗效果推荐系统上提供个性化推荐。

    0.1 - 相关概念

    生存数据:一般生存数据由三部分组成:患者的基线数据$x$,死亡事件时间$T$,事件指标$E$。如果死亡事件发生了,则$T$代表的就是基线数据$x$与死亡事件发生之间的时间间隔,此时$E=1$。如果死亡事件没有发生,则$T$表示基线数据$x$与患者最后一次采集数据的时间间隔,此时$E=0$,这部分数据称为右删失(right-censored)。

    生存函数(Survival Function):生存函数可以定义为$S(t)=Pr(T>t)$,其表示的是个体在时刻$t$生存的概率,其可以通过下式进行估计,

    $$hat{S}(t)=frac{number of patients surviving longer than t}{total number of patients}.$$

    密度函数(Density Function):密度函数可以定义为$f(t)=lim_{delta ightarrow 0}frac{Prleft(tleq T< t+delta | Tgeq t ight )}{delta}$,其表示已经处在生存时间$T$的短暂时刻发生事件的概率,其估计方法为,

    $$hat{f}(t)=frac{number of patients dying in the interval beginning at time t}{(total number of patients) imes(interval width)}.$$

    风险函数(Hazard Function):风险函数用来衡量当前个体在时刻$t$之前没有发生任何事件的情况下,时刻$t$发生事件的概率,其可以定义为$lambda(t)=lim_{delta ightarrow 0}frac{Prleft(tleq T< t+delta ight)}{delta}$,其估计方法为,

    $$hat{lambda}(t)=frac{number of patients dying in the interval beginning at time t}{(number of patients surviving at t) imes(interval width)}.$$

    Cox比例风险回归模型:Cox比例风险回归模型是一种常用的方法,用于在给定基线数据$x$的情况下对个体的生存风险进行建模。该模型由两部分组成:只与时间相关的基线风险函数$lambda_0(t)$和只与患者数据$x$相关的函数$h(x)$。该模型表示为$lambda(t|x)=lambda_0(t)cdot e^{h(x)}.$

    C-index:这是生存预测的一个评价指标,英文全称为concordance index,因为对于存在删失的生存数据,一些标准的评估方法,例如均方误差等,是不合适的。其计算方式是,(1)将所有样本两两配对,例如有$N$个样本,则一共可以组成$N imes (N-1)/2$对;(2)排除其中无法判断谁先出现感兴趣事件的配对(两个实例都没有发生事件),得到剩余的对数$M$;(3)在剩下的$M$对中,预测结果与实际结果一致的配对数$K$,即预测的生存$S(X_A)<S(X_B)$(或者说风险率$R(X_A)>R(X_B)$),实际的$T_A<T_B$,即为一致;(4)则$C-index=frac{K}{M}$。其可以形式化为如下公式,

    $$frac{1}{M}sum_{i:delta_i=1}sum_{j:T_i<T_j}Ileft[S(T_i,X_i)<S(T_j,X_j) ight ],$$

    其中$I[C]$表示若$C$为真,则$I[C]=1$,否则$I[C]=0$。$delta_i=1$表示至少要有一个实例发生了事件,$T_i<T_j$表示对$i$和$j$配对的要求,即防止$i$和$j$颠倒算了两次。

    0.2 - Linear Survial Models

    线性生存模型是把cox模型中的$h(x)$采用线性函数$hat{h}_{eta}(x)=eta^Tx$进行建模,可以定义为,

    $$L_c(eta)=prod_{i:E_i=1}frac{exp(hat{h}_{eta}(x_i))}{sum_{jin Re(T_i)}exp(hat{h}_{eta}(x_j))},$$

    其中$T_i,E_i,x_i$分别表示事件事件、事件指标、第$i$个基准数据。上述式子定义在一组可观察到事件发生的患者上$E_i=1$,风险集合$Re (t)={i:T_igeq t}$表示在时刻$t$仍然处于风险的患者集合。

    0.3 - NonLinear Survial Models

    即$hat{h}_{ heta}(x)$由非线性模型进行建模。

    1 - 方法

    1.0 - DeepSurv

    DeepSurv是一个多层感知机,模型的预测输出是一个值,代表患者的健康风险,其损失函数定义为,

    $$l( heta):=-sum_{i:E_i=1}left(hat{h}_{ heta}(x_i)-logsum_{jinRe(T_i)}e^{hat{h}_{ heta}(x_j)} ight ),$$

    文中将DeepSurv设计成了一个深度结构(可能有多层隐藏层),并且加入了权重衰减正则化、ReLU激活、batch normalization、SELU、dropout、SGD、Adam、梯度裁剪、学习率调整策略等当时比较新的技术。

    1.1 - 治疗推荐系统

    在一项临床研究中,患者根据其相关的预后特征和所接受的治疗具有不同程度的风险。文中把这个假设概括为,假设研究中的患者被分到$n$个治疗组$ au in {0,1,cdots,n-1}$中的一个,每一个治疗方案$i$具有独立的风险函数$h_i(x)$。总的来说,风险函数变成了,

    $$lambda(t;x| au=i)=lambda_0(t)cdot e^{h_i(x)},$$

    基于上述的假设,每一个个体拥有一样的初始风险函数$lambda_0(t)$,我们可以用采用不同资料方案的风险率的对数来对比同一个体接受两种治疗方案的对比,其推导为,

    $$rec_{ij}(x)=logleft(frac{lambda(t;x| au=i)}{lambda(t;x| au=j)} ight )=logleft(frac{lambda_0(t)cdot e^{h_i(x)}}{lambda_0(t)cdot e^{h_j(x)}} ight )=h_i(x)-h_j(x),$$

    如果$rec_{ij}>0$,则说明$i$方案比$j$方案风险高,应该选择$j$方案,反之则反。

    2 - 结果

    我复现了文章中第4节在几个数据集上的结果,模型和训练的参数有稍微的调整,参数配置可以自己在配置文件里面修改,结果如下表所示。(代码托管在我的Github上,如果有帮助记得点星星嗨

      Simulated Linear Simulated Nonlinear WHAS SUPPORT METRABRIC Simulated Treatment Rotterdam & GBSG
    Paper 0.774019 0.648902 0.862620 0.618308 0.643374 0.582774 0.668402
    Ours 0.778607 0.652048 0.841484 0.618107 0.643453 0.552648 0.673290

    3 - 参考资料

    https://github.com/czifan/DeepSurv.pytorch

    https://link.springer.com/article/10.1186/s12874-018-0482-1

    https://github.com/jaredleekatzman/DeepSurv

  • 相关阅读:
    c++Primer再学习(1)
    c++Primer再学习练习Todo
    感悟(一)
    新目标《C++程序设计原理与实践》
    C++Primer再学习(4)
    开篇
    C++Primer再学习(3)
    C++实现的单例模式的解惑
    使用springboot缓存图片
    springboot h2 database
  • 原文地址:https://www.cnblogs.com/CZiFan/p/12674144.html
Copyright © 2011-2022 走看看