27 | 使用PyTorch完成医疗图像识别大项目:实现端到端模型方案

接下来需要再做一些工作,并把我们前面搞好的模型串起来,形成一个端到端的解决方案。这个方案如下,首先是从原始的CT数据出发进行图像分割,识别可能是结节的体素,并对这些体素区域进行分组,然后用这些分割出的候选结节信息进行分类,首先是区分这是否是一个结节,针对是结节的,再区分这是否是一个恶性结节,这样就完成了整个模型框架。


image.png

由于我们之前训练的两个模型使用的训练数据是不一样的,我们直接获取了标注的结节信息作为分类模型的训练集,而在实际中,我们需要对分割模型的结果进行分类。这就存在数据泄露的问题。也就说在分类模型的训练集中可能有些数据是分割模型的验证集,反过来,在分类模型的验证集里面可能有分割模型的训练集数据。所以之前压根就没保存模型,就是为了在这里重新训练一下。获取LunaDataset跟之前还是一样的,有区别的是从segmentationDataset中获取标注数据并分割为训练集和验证集。

重新训练分类模型

先为这一章构建数据缓存。

run('test14ch.prepcache.LunaPrepCacheApp')

然后训练100个epoch,由于下调了数据样本量,训练集里面的正负样本各2.5w条,验证集保持原样正样本154条,负样本5w+。所以训练起来还算可以,差不多10分钟一个epoch

run('test14ch.training.ClassificationTrainingApp', f'--epochs=100', 'nodule-nonnodule')

看到第一个epoch结果,效果还不是很好,不过程序没问题,就在这里跑着好了,我就去睡觉了。


image.png

一觉醒来,已经70+epoch,训练集上的准确率已经99%+,验证集上对于正样本的准确率也达到了94%+,不过中体的precision还是比较低的,因为两个类别的样本量差距太大了,有很多负样本被归为了阳性结果,不过这问题不大,我们主要是能把真的阳性筛出来就好了。


image.png

到了80个epoch,训练集效果基本没变了,验证集上阳性准确率下降了一点。
image.png

直接跳到100个epoch,可以看到在训练集上的效果又提升了一丢丢,但是验证集上,尤其是验证集的负样本准确率下降了不少,这不符合我们的预期,说明模型有点过拟合了。


image.png

我们就用这里面的最佳模型作为我们最后系统中需要使用的分类模型。(第95epoch的模型,这时候有最好的f1 score)

连接分割和分类模型

新建一个代码,就叫结节分析:nodule_analysis.py,它最核心的地方就是下面这段。

#取一个uid
        for _, series_uid in series_iter:
#然后获取对应的CT数据
            ct = getCt(series_uid)
#紧接着跑分割模型
            mask_a = self.segmentCt(ct, series_uid)
#给分割模型预测到的结节数据进行分组
            candidateInfo_list = self.groupSegmentationOutput(
                series_uid, ct, mask_a)
#最后跑分类模型,决定是不是结节
            classifications_list = self.classifyCandidates(
                ct, candidateInfo_list)

其中,分割部分代码

    def segmentCt(self, ct, series_uid):
#预测不需要更新,关闭自动梯度计算
        with torch.no_grad():
#用来存储输出结果
            output_a = np.zeros_like(ct.hu_a, dtype=np.float32)
#初始化数据加载器
            seg_dl = self.initSegmentationDl(series_uid) 
#遍历整个CT
            for input_t, _, _, slice_ndx_list in seg_dl:
#发送到GPU
                input_g = input_t.to(self.device)
#运行分割模型
                prediction_g = self.seg_model(input_g)
#把结果存起来
                for i, slice_ndx in enumerate(slice_ndx_list):
                    output_a[slice_ndx] = prediction_g[i].cpu().numpy()
#构建掩码结果
            mask_a = output_a > 0.5
            mask_a = morphology.binary_erosion(mask_a, iterations=1)
#返回
        return mask_a

接下来跟分割的结果分组。这里使用了一个scipy.ndimage.measurements的方法,measurements.label用来标记连通区域。举个简单的例子,如下左图是我们的掩码结果,measurements.label的功能就是去识别这里面有多少是连通的,并用一个标记去修改它们的值。对于左上角,是第一个连通区域,那么里面的值都改为1,中间这块是第二个连通区域,里面的值都改为2,依次类推,就变成了右侧的样子。 measurements.center_of_mass则是用来计算每个连通区域的中心点坐标。


image.png
    def groupSegmentationOutput(self, series_uid,  ct, clean_a):

        candidateLabel_a, candidate_count = measurements.label(clean_a)
        centerIrc_list = measurements.center_of_mass(
            ct.hu_a.clip(-1000, 1000) + 1001,
            labels=candidateLabel_a,
            index=np.arange(1, candidate_count+1),
        )
#把识别到的数据转化分类模型要用的数据。
        candidateInfo_list = []
        for i, center_irc in enumerate(centerIrc_list):
            center_xyz = irc2xyz(
                center_irc,
                ct.origin_xyz,
                ct.vxSize_xyz,
                ct.direction_a,
            )
            assert np.all(np.isfinite(center_irc)), repr(['irc', center_irc, i, candidate_count])
            assert np.all(np.isfinite(center_xyz)), repr(['xyz', center_xyz])
            candidateInfo_tup = \
                CandidateInfoTuple(False, False, False, 0.0, series_uid, center_xyz)
            candidateInfo_list.append(candidateInfo_tup)

        return candidateInfo_list

最后是分类模型。

    def classifyCandidates(self, ct, candidateInfo_list):
# 初始化dataloader
        cls_dl = self.initClassificationDl(candidateInfo_list)
        classifications_list = []
        for batch_ndx, batch_tup in enumerate(cls_dl):
            input_t, _, _, series_list, center_list = batch_tup
#发送到GPU上
            input_g = input_t.to(self.device)
            with torch.no_grad():
#运行分类模型
                _, probability_nodule_g = self.cls_model(input_g)
#这里还有一个分是否恶性的模型,现在我们还没开发,先留下这个位置
                if self.malignancy_model is not None:
                    _, probability_mal_g = self.malignancy_model(input_g)
                else:
                    probability_mal_g = torch.zeros_like(probability_nodule_g)

            zip_iter = zip(center_list,
                probability_nodule_g[:,1].tolist(),
                probability_mal_g[:,1].tolist())
#转换坐标
            for center_irc, prob_nodule, prob_mal in zip_iter:
                center_xyz = irc2xyz(center_irc,
                    direction_a=ct.direction_a,
                    origin_xyz=ct.origin_xyz,
                    vxSize_xyz=ct.vxSize_xyz,
                )
                cls_tup = (prob_nodule, prob_mal, center_xyz, center_irc)
                classifications_list.append(cls_tup)
        return classifications_list

我们的CT图像原本有大概3300w个体素,经过图像分割之后,留下大约100w个体素,通过给这些体素分组,可以得到大概1000个候选结节信息,然后对这些信息进行分类确认哪些是结节,哪些不是结节,经过这步之后还剩几十个确认是结节,最后一步是确认结节的性质,恶性的通常来说最多也就一两个。


image.png

这时候在回到main方法中,我们已经得到了模型的结果,

#这个cli_args.run_validation参数是用来判断是否跑验证集数据的,如果不是验证集数据,而是单个输入uid,那么就执行下面的信息显示功能
            if not self.cli_args.run_validation:
                print(f"found nodule candidates in {series_uid}:")
                for prob, prob_mal, center_xyz, center_irc in classifications_list:
                    if prob > 0.5:#如果我们找到的结节概率超过0.5,就输出信息到屏幕上,给医生看
                        s = f"nodule prob {prob:.3f}, "
                        if self.malignancy_model:
                            s += f"malignancy prob {prob_mal:.3f}, "
                        s += f"center xyz {center_xyz}"
                        print(s)
#这里输出混淆矩阵
            if series_uid in candidateInfo_dict:
                one_confusion = match_and_score(
                    classifications_list, candidateInfo_dict[series_uid]
                )
                all_confusion += one_confusion
                print_confusion(
                    series_uid, one_confusion, self.malignancy_model is not None
                )

        print_confusion(
            "Total", all_confusion, self.malignancy_model is not None
        )

用我们建好的模型来预测一下数据看看,运行速度还是挺快的

python -m test14ch.nodule_analysis 1.3.6.1.4.1.14519.5.2.1.6279.6001.592821488053137951302246128864

输出的结果如下,总共发现19个结节,其中17个是假阳性,1个良性,1个恶性。不过这里我还没有把恶性分类器加上,看了一下代码,这个恶性标记应该是从原始的标注数据里来的?


image.png

识别恶性结节

这个地方我们先获取一份关于恶性肿瘤的标注信息。这个数据来自我们之前已经安装的LIDC工具包,如果你还没安装可以用下面这行shell安装

pip install pylidc

这里面有医生关于恶性结节的标注,每个医生会标注非常不可能、不可能、不确定、可疑、非常可疑几种情况,同时,对于一个结节可能会有多个医生进行标注。

ROC与AUC

这里插播一个小知识,就是需要学一个新的评估指标,ROC-AUC。要了解ROC曲线(Receiver Operating Characteristic curve),我们先回到混淆矩阵上来。


image.png

根据混淆矩阵,可以算出真阳性率(True Positive Rate,TPR)和假阳性率(False Positive Rate,FPR)
其中,


image.png

而ROC曲线就是假设我们对判断阳性取不同的阈值时以两个值为横纵坐标得到的一条曲线。下图中,我们假设使用一个结节的直径大小作为判断是否恶性结节的标准,那么该曲线就是当取不同的直径大小作为判断阈值时,所得到的ROC曲线。而AUC(Area Under ROC)即为ROC曲线下面的面积。ROC曲线是用来衡量分类器的分类能力,AUC表示,随机抽取一个正样本和一个负样本,分类器正确给出正样本的score高于负样本的概率。
image.png

因此,如果AUC越大,则表示模型的效果越好。

import torch
%matplotlib inline
from matplotlib import pyplot

import test14ch.dsets
import test14ch.model
#这里获取带有恶性结节标记的数据集
ds = test14ch.dsets.MalignantLunaDataset(val_stride=10, isValSet_bool=True)  
nodules = ds.ben_list + ds.mal_list
#获取是否恶性结节的状态和直径
is_mal = torch.tensor([n.isMal_bool for n in nodules]) 
diam  = torch.tensor([n.diameter_mm for n in nodules])
#恶性和良性数目
num_mal = is_mal.sum()  
num_ben = len(is_mal) - num_mal
#设置阈值,取结节直径的最大最小值,并分成100份
threshold = torch.linspace(diam.max(), diam.min(), steps=100)
#使用直径来判断是否恶性结节
predictions = (diam[None] >= threshold[:, None])  
计算真阳率和假阳率
tp_diam = (predictions & is_mal[None]).sum(1).float() / num_mal  
fp_diam = (predictions & ~is_mal[None]).sum(1).float() / num_ben
#计算auc
fp_diam_diff =  fp_diam[1:] - fp_diam[:-1]
tp_diam_avg  = (tp_diam[1:] + tp_diam[:-1])/2
auc_diam = (fp_diam_diff * tp_diam_avg).sum()
#fill用于后面绘图使用
fp_fill = torch.ones((fp_diam.shape[0] + 1,))
fp_fill[:-1] = fp_diam

tp_fill = torch.zeros((tp_diam.shape[0] + 1,))
tp_fill[:-1] = tp_diam

print(threshold)
print(fp_diam)
print(tp_diam)

for i in range(threshold.shape[0]):
    print(i, threshold[i], fp_diam[i], tp_diam[i])

pyplot.figure(figsize=(7,5), dpi=1200)
for i in [62, 88]:
    pyplot.scatter(fp_diam[i], tp_diam[i], color='red')
    print(f'diam: {round(threshold[i].item(), 2)}, x: {round(fp_diam[i].item(), 2)}, y: {round(tp_diam[i].item(), 2)}')
pyplot.fill(fp_fill, tp_fill, facecolor='#0077bb', alpha=0.25)
pyplot.plot(fp_diam, tp_diam, label=f'diameter baseline, AUC={auc_diam:.3f}')
pyplot.title(f'ROC diameter baseline, AUC={auc_diam:.3f}')
pyplot.ylabel('true positive rate')
pyplot.xlabel('false positive rate')
pyplot.savefig('roc_diameter_baseline.png')

最后绘制的图像就是我们上面展示过的图像,这个就作为我们预测是否恶性的baseline,接下来使用模型finetune来训练一个预测是否恶性的模型。


image.png
finetune

这里使用的是之前的分类模型,我们在分类模型的基础上进行微调。下图显示了微调的方案,如果微调深度为1,那么只是把最后的全连接层重新训练,这里保持模型的主干和尾部都维持之前的权重不动,只把全连接层的权重重新初始化,然后把模型目标改成区分是否恶性结节,让模型去学习新的权重。如果微调深度为2,那么把主干最后一个卷积块的参数也都重置,再进行训练。


image.png

在训练分类的代码里,通过简单的修改就可以实现微调。

#判断是否开启微调模式
        if self.cli_args.finetune:
#加载模型
            d = torch.load(self.cli_args.finetune, map_location='cpu')
#获取所有层
            model_blocks = [
                n for n, subm in model.named_children()
                if len(list(subm.parameters())) > 0
            ]
#获取需要微调的层
            finetune_blocks = model_blocks[-self.cli_args.finetune_depth:]
            log.info(f"finetuning from {self.cli_args.finetune}, blocks {' '.join(finetune_blocks)}")
#加载保存的状态
            model.load_state_dict(
                {
                    k: v for k,v in d['model_state'].items()
                    if k.split('.')[0] not in model_blocks[-1]
                },
                strict=False,
            )
#只在需要finetune的层上进行梯度计算
            for n, p in model.named_parameters():
                if n.split('.')[0] not in finetune_blocks:
                    p.requires_grad_(False)

然后启动模型重新训练,这里我们使用区分恶性的数据集。可以看到这里使用的是MalignantLunaDataset,第一次只调整最后的全连接层,训练40个epoch。

run('test14ch.training.ClassificationTrainingApp', f'--epochs=40', '--malignant', '--dataset=MalignantLunaDataset',
    '--finetune=''D:/pytorchtest/data-unversioned/part2/models/p2ch14/cls_2022-06-27_21.59.28_nodule-nonnodule.best.state',
    'finetune-head')

这个训练速度稍微快一点,第一个epoch效果不怎么好,训练集上两个类别准确率67%,验证集上平均下来只有60%左右。


image.png

到40个epoch


image.png

由于中间我也没盯着,后来看保存的最佳模型,是第29个epoch的模型
[图片上传失败...(image-7370ad-1656594433704)]

根据代码可以看到,评判模型的得分,我们这里用的是auc,可以看到第29个epoch在验证集上的auc要略高一点,所以这里最佳模型是第29个epoch的模型。


image.png

输出了它的效果跟之前的AUC对比一下,结果发现训了半天模型还不如之前就用结节的直径来判断得到的AUC更好。
image.png

不行,把finetune深度改成2,又训了40轮。
run('test14ch.training.ClassificationTrainingApp', f'--epochs=40', '--malignant', '--dataset=MalignantLunaDataset',
    '--finetune=''D:/pytorchtest/data-unversioned/part2/models/p2ch14/cls_2022-06-27_21.59.28_nodule-nonnodule.best.state',
    '--finetune-depth=2',
    'finetune-depth2')

直接看结果,这里的best model竟然是第4个epoch的模型,可以看到训练集上的准确率基本在93+%,验证集恶性准确率比较低只有73%,良性的准确率为87%。


image.png

到40个epoch的时候,可以看到在训练集上准确率都99%+了,但是验证集效果下滑,出现了过拟合现象。


image.png

结果这个AUC是有提升了,但是还没有超越使用直径直接分类的效果。不如我们实际的模型就用直径直接分类好了。当然,还有很多优化方案我们可以尝试,比如说做模型集成,做样本的增强,给训练数据提取更多特征,比如说做一个平滑的标签,甚至是使用更复杂的模型等等,但是在实际的项目中,我们可能说“由于时间紧迫,我们的第一版就先上线”,毕竟当下的效果已经能够满足业务需求,并且整体的逻辑已经完成,终于可以给这个模型训练阶段画上一个句号。
image.png
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,569评论 4 363
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,499评论 1 294
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,271评论 0 244
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,087评论 0 209
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,474评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,670评论 1 222
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,911评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,636评论 0 202
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,397评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,607评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,093评论 1 261
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,418评论 2 254
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,074评论 3 237
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,092评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,865评论 0 196
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,726评论 2 276
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,627评论 2 270

推荐阅读更多精彩内容