一文快速入手:多实例学习

翻译自Medium:https://medium.com/swlh/multiple-instance-learning-c49bd21f5620

导读

当涉及到在医学领域中应用计算机视觉时,大多数任务涉及到:

(1) 用于诊断的图像分类任务

(2) 识别和分离病变区域的分割任务

然而,在病理学癌症检测中,这并不总是可能的。获取标签既费时又费力。此外,病理切片的分辨率最高可达200000 x 100000像素,并且它们不适合在内存中进行分类,因为例如,ImageNet仅使用224 x 224像素进行训练。下采样通常不是一个选项,因为我们试图检测一个微小的区域,例如从300×300像素区域(图1中的几个点)变化的癌区域。

图一:来自patient_ 004 _ node _ 004(cameloyon 17)的幻灯片

在这种情况下,我们可以使用多实例学习(Multiple Instance Learning),这是一种弱监督学习方法,它采用一组包含许多实例的标记包,而不是接收一组标记实例。

假设我们有病理切片和每张切片的标签。因为我们不能在整个幻灯片上训练分类器,所以我们将每个幻灯片分成小块,在GPU上一次只处理几个小块。然而,我们不知道每个图块的标签,因此我们需要多实例学习。在MIL框架中,幻灯片是“包”,切片是“实例”。通过使用它,我们能够节省标记工作,并利用弱标记数据。

当我们有患者的病理切片时,我们希望预测大切片是否包含癌细胞,或者缩小患者是否有恶性细胞,多实例学习是一个很好的选择,因为医生不需要分割单个细胞或标记每个切片。只有整张幻灯片需要标签。

一般来说,多实例学习可以处理分类问题、回归问题、排序问题和聚类问题,但我们这里主要关注分类问题。

在这篇文章中,我将通过一个基于 MNIST 数据集的简单示例来解释 MIL 如何工作。如果你不熟悉 MNIST 数据集,这里有一个[关于 MNIST 数据集](https://www.kaggle.com/ngbolin/mnist-dataset-digit-recognizer)的[Kaggle 竞赛](https://www.kaggle.com/ngbolin/mnist-dataset-digit-recognizer)的链接,你可以看看。

MNIST数据集简介

MNIST数据集是一个手写数字的大型数据库,每个图像都有一个从0到9的标签。它有6万张图像的训练集和1万张图像的测试集。每个的尺寸是28 x 28的灰度图。

图 2: Minst 手写分类数据集

多实例学习的问题简述

一个袋子里的xi每个实例都有一个标签yi。我们将包的标签定义为:

Y = 1,如果存在 yi ==1

Y = 0,如果对于每个yi,yi == 0

在MNIST数据集上应用多元线性回归的流程

作为概述,我们将首先在实例标记的数据集上预训练 ResNet 模型,然后将袋子标记的数据集馈送到模型中并提取特征。最后,我们对它们应用 MIL。
步骤 1:将原始 MNIST 数据集拆分为袋标记集以进行适当的 MIL 训练和实例标记集。
创建 MIL-MNIST 玩具数据集:
MNIST 数据集是分类任务的标准基准。为了使它成为一个 MIL 问题,我们需要首先通过将几个数字(实例)分组到一个包中来构建 MIL-MNIST 数据集。
在我们的 MNIST 示例中,如果一个实例的标签为“1”,我们会将袋子标签分配为“1”;如果除“1”之外的所有实例标签都是“0-9”,那么我们将袋子标签分配为“0”。
如下图所示,红色填充的袋子的袋子标签为“1”,蓝色填充的袋子的袋子标签为“0”。

图 3:袋子和实例标签

我们将每个图像随机放入一个包中,每个包包含 3 到 7 个实例。为了节省内存,我们使用索引来表示图像(如下图)。

    def data_generation(instance_index_label: List[Tuple]) -> List[Dict]: ''' bags: {key1: [ind1, ind2, ind3], key2: [ind1, ind2, ind3, ind4, ind5], ... } bag_lbls: {key1: 0, key2: 1, ... } ''' bag_size = np.random.randint(3,7,size=len(instance_index_label)//5) data_cp = copy.copy(instance_index_label) np.random.shuffle(data_cp) bags = {} bags_per_instance_labels = {} bags_labels = {} for bag_ind, size in enumerate(bag_size): bags[bag_ind] = [] bags_per_instance_labels[bag_ind] = [] try: for _ in range(size): inst_ind, lbl = data_cp.pop() bags[bag_ind].append(inst_ind) # simplfy, just use a temporary variable instead of bags_per_instance_labels bags_per_instance_labels[bag_ind].append(lbl) bags_labels[bag_ind] = bag_label_from_instance_labels(bags_per_instance_labels[bag_ind]) except: break return bags, bags_labels

    生成包标签:

      def bag_label_from_instance_labels(instance_labels):    return int(any(((x==1) for x in instance_labels)))

      第 2 步:对 MNIST 数据集的 2 个部分进行预训练

      1. 构造一个2D卷积神经网络,kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)

      2. 训练 5 个 epoch,批大小为 256

      3. 保存模型

        import torchfrom torchvision.models.resnet import ResNet, BasicBlockclass MnistResNet(ResNet): def __init__(self): super(MnistResNet, self).__init__(BasicBlock, [2, 2, 2, 2], num_classes=10) self.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) def forward(self, x): return torch.softmax(super(MnistResNet, self).forward(x), dim=-1)

        第 3 步:加载预训练模型并从最后一层提取特征

        1. 将其余数据拆分为训练、验证和测试集

        2. 获取训练、验证和测试集的特征

        3. 获取 bag_indices 和 bag_labels

        4. 使用基于索引的特征映射 bag_indices 并创建 bag_features

        为了摆脱最后一层:

          model = MnistResNet()model.load_state_dict(torch.load('mnist_state.pt'))body = nn.Sequential(*list(model.children()))# extract the last layermodel = body[:9]# the model we will usemodel.eval()

          提取特征

          下面的代码展示了我们如何从数据生成函数中获取包索引和包特征:

            bag_indices, bag_labels = data_generation(instance_index_label)bag_features = {kk: torch.Tensor(feature_array[inds]) for kk, inds in bag_indices.items()}

            袋子索引、袋子标签和袋子特征如下所示:

            图 7:带图像索引的袋子索引

            图 8:袋子标签

            图 9:袋子特征

            第 4 步:在 bag_features 和 bag_labels 上训练 MIL 模型并在测试集上进行评估

            由于每个包都有不同数量的实例,我们需要在将张量放入模型之前将它们填充到相同的大小。

            多实例学习模型:

            该算法执行三个步骤。它们中的任何一个都可以是固定函数或可优化函数(神经网络):

            1. 将实例转换为低维嵌入。(固定的)

            2. 通过置换不变聚合函数传递嵌入。(可优化)

            3. 转化为包概率。(可优化)

            图 9:MIL-MNIST 玩具数据集上的 MIL 图

            一般来说,工作流程如下:

            图 10:病理切片上的 MIL 算法框架图(参见参考文献 #5)

            为简单起见,我们将步骤 1 固定为固定。对于第 2 步,虽然我们仍然可以使用固定函数,例如 max 或 mean,但为了启用可以通过反向传播端到端学习的参数优化,我们使用神经网络作为聚合函数。对于第 3 步,我们还希望使用反向传播来优化参数。

            1. 线性层和 LeakyReLu

              class NoisyAnd(torch.nn.Module):    def __init__(self, a=10, dims=[1,2]):        super(NoisyAnd, self).__init__()#         self.output_dim = output_dim        self.a = a        self.b = torch.nn.Parameter(torch.tensor(0.01))        self.dims =dims        self.sigmoid = nn.Sigmoid()    def forward(self, x):#         h_relu = self.linear1(x).clamp(min=0)        mean = torch.mean(x, self.dims, True)        res = (self.sigmoid(self.a * (mean - self.b)) - self.sigmoid(-self.a * self.b)) / (              self.sigmoid(self.a * (1 - self.b)) - self.sigmoid(-self.a * self.b))        return res    class NN(torch.nn.Module):    def __init__(self, n=512, n_mid = 1024,                 n_out=1, dropout=0.2,                 scoring = None,                ):        super(NN, self).__init__()        self.linear1 = torch.nn.Linear(n, n_mid)        self.non_linearity = torch.nn.LeakyReLU()        self.linear2 = torch.nn.Linear(n_mid, n_out)        self.dropout = torch.nn.Dropout(dropout)        if scoring:            self.scoring = scoring        else:            self.scoring = torch.nn.Softmax() if n_out>1 else torch.nn.Sigmoid()            def forward(self, x):        z = self.linear1(x)        z = self.non_linearity(z)        z = self.dropout(z)        z = self.linear2(z)        y_pred = self.scoring(z)        return y_pred    class LogisticRegression(torch.nn.Module):    def __init__(self, n=512, n_out=1):        super(LogisticRegression, self).__init__()        self.linear = torch.nn.Linear(n, n_out)        self.scoring = torch.nn.Softmax() if n_out>1 else torch.nn.Sigmoid()    def forward(self, x):        z = self.linear(x)        y_pred = self.scoring(z)        return y_pred    def regularization_loss(params,                        reg_factor = 0.005,                        reg_alpha = 0.5):    params = [pp for pp in params if len(pp.shape)>1]    l1_reg = nn.L1Loss()    l2_reg = nn.MSELoss()    loss_reg =0    for pp in params:        loss_reg+=reg_factor*((1-reg_alpha)*l1_reg(pp, target=torch.zeros_like(pp)) +\                           reg_alpha*l2_reg(pp, target=torch.zeros_like(pp)))    return loss_reg

              注意:我们设置 n = 7*512,其中 7 是一个包中的实例数,512 是每个特征的大小。

              2. 聚合函数:AttensionSoftmax

                class SoftMaxMeanSimple(torch.nn.Module): def __init__(self, n, n_inst, dim=0): ''' if dim==1: given a tensor `x` with dimensions [N * M], where M -- dimensionality of the featur vector (number of features per instance) N -- number of instances initialize with `AggModule(M)` returns: - weighted result: [M] - gate: [N] if dim==0: ... ''' super(SoftMaxMeanSimple, self).__init__() self.dim = dim self.gate = torch.nn.Softmax(dim=self.dim) self.mdl_instance_transform = nn.Sequential( nn.Linear(n, n_inst), nn.LeakyReLU(), nn.Linear(n_inst, n), nn.LeakyReLU(), ) def forward(self, x): z = self.mdl_instance_transform(x) if self.dim==0: z = z.view((z.shape[0],1)).sum(1) elif self.dim==1: z = z.view((1, z.shape[1])).sum(0) gate_ = self.gate(z) res = torch.sum(x* gate_, self.dim) return res, gate_
                class AttentionSoftMax(torch.nn.Module): def __init__(self, in_features = 3, out_features = None): ''' given a tensor `x` with dimensions [N * M], where M -- dimensionality of the featur vector (number of features per instance) N -- number of instances initialize with `AggModule(M)` returns: - weighted result: [M] - gate: [N] ''' super(AttentionSoftMax, self).__init__() self.otherdim = '' if out_features is None: out_features = in_features self.layer_linear_tr = nn.Linear(in_features, out_features) self.activation = nn.LeakyReLU() self.layer_linear_query = nn.Linear(out_features, 1) def forward(self, x): keys = self.layer_linear_tr(x) keys = self.activation(keys) attention_map_raw = self.layer_linear_query(keys)[...,0] attention_map = nn.Softmax(dim=-1)(attention_map_raw) result = torch.einsum(f'{self.otherdim}i,{self.otherdim}ij->{self.otherdim}j', attention_map, x) return result, attention_map

                3. 中间以LeakyReLu为激活函数,dropout,sigmoid为最终激活函数的神经网络:

                  class MIL_NN(torch.nn.Module):    def __init__(self, n=512,                   n_mid=1024,                  n_classes=1,                  dropout=0.1,                 agg = None,                 scoring=None,                ):        super(MIL_NN, self).__init__()        self.agg = agg if agg is not None else AttentionSoftMax(n)                if n_mid == 0:            self.bag_model = LogisticRegression(n, n_classes)        else:            self.bag_model = NN(n, n_mid, n_classes, dropout=dropout, scoring=scoring)            def forward(self, bag_features, bag_lbls=None):        '''        bag_feature is an aggregated vector of 512 features        bag_att is a gate vector of n_inst instances        bag_lbl is a vector a labels        figure out batches        '''        bag_feature, bag_att, bag_keys = list(zip(*[list(self.agg(ff.float())) + [idx]                                                    for idx, ff in (bag_features.items())]))        bag_att = dict(zip(bag_keys, [a.detach().cpu() for a  in bag_att]))        bag_feature_stacked = torch.stack(bag_feature)        y_pred = self.bag_model(bag_feature_stacked)        return y_pred, bag_att, bag_keys

                  4. 优化器:SGD

                  5. 损失函数:BCELoss

                  6. 准确度:~0.99

                  结论

                  我们使用 MIL 在 MNIST 数据集上获得了大约 0.99 的准确率,这是一个令人满意的结果。如果我们愿意,我们可以使用更复杂的聚合函数作为我们的中间转换,并构建更复杂的 NN 模型用于最终转换到包级别。结果还表明,MIL 是一个很好的工具,可以节省标记工作并利用弱标记数据。

                  Jupyter 笔记本演示链接:

                  https://github.com/lsheng23/Practicum/blob/master/MIL_MNIST/end_to_end_mnist_MIL.ipynb
                  (0)

                  相关推荐

                  • 搞懂Vision Transformer 原理和代码,看这篇技术综述就够了(三)

                    作者丨科技猛兽 来源丨极市平台 审核丨邓富城 极市导读 本文为详细解读Vision Transformer的第三篇,主要解读了两篇关于Transformer在识别任务上的演进的文章:DeiT与VT.它 ...

                  • 【生成模型】简述概率密度函数可处理流模型

                    本期将介绍第二种非常优雅的生成模型-流模型,它也是一种概率密度函数可处理的生成模型.本文将对其原理进行介绍,并对nice模型的源码进行讲解. 作者&编辑 | 小米粥 1 流模型 这是一种想法比 ...

                  • 项目实践 | 基于YOLO-V5实现行人社交距离风险提示(文末获取完整源码)

                    由于YOLO V5的作者现在并没有发表论文,因此只能从代码的角度理解它的工作.YOLO V5的网络结构图如下: 1.与YOLO V4的区别 Yolov4在Yolov3的基础上进行了很多的创新.比如输入 ...

                  • PyTorch 源码解读之 torch.autograd

                      磐创AI分享   来源 | GiantPandaCV 作者 | OpenMMLab 来源 | https://zhuanlan.zhihu.com/p/321449610 前言 本篇笔记以介绍 p ...

                  • 小白学PyTorch | 4 构建模型三要素与权重初始化

                    文章目录: 1 模型三要素 2 参数初始化 3 完整运行代码 4 尺寸计算与参数计算 这篇文章内容不多,比较基础,里面的代码块可以复制到本地进行实践,以加深理解. 喜欢的话,可以给公众号加一个星标,点 ...

                  • PyTorch Lightning工具学习

                    来源 | GiantPandaCV 编辑 | pprp [导读]Pytorch Lightning是在Pytorch基础上进行封装的库(可以理解为keras之于tensorflow),为了让用户能够脱 ...

                  • 深度学习之PyTorch实战(3)

                    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

                  • 一文快速了解安全阀,值得收藏!

                    通常,普通的阀门只是起到一个开关的作用,而安全阀除了有开关的作用外,还可以起到保护设备.防止事故发生的作用,是特种设计上非常重要的安全附件.今天,小编就来带大家一起了解安全阀的有关知识. 一.安全阀分 ...

                  • 洛阳市冷知识20条,一文快速了解洛阳这座城市

                    洛阳市,河南省下辖的地级市,河洛文化和华夏文明的发源地,重要的工业城市和优秀旅游城市,中原经济区和中原城市群的副中心城市. 洛阳位于洛水之北,故名洛阳:因境内有伊.洛两水,也称伊洛.洛阳地处中原,雄踞 ...

                  • 减震技术丨一文快速读懂位移放大型阻尼器的理论研究及工程应用

                    摘要 /ABSTRACT/ 传统的位移相关型阻尼器在结构变形较小时往往难以充分发挥工作性能.为此,可采用特殊的构造形式和安装方法,将结构的变形或位移放大数倍后再输入阻尼器,进而提高阻尼器的工作效率和耗 ...

                  • 隶书《曹全碑》简易教程,让书法快速入手

                    原创太一智慧书画艺术2019-02-06 07:32:17 <曹全碑>是汉代隶书的经典代表作之一,其艺术风格清秀飘逸,字体俊美,是隶书书法中的典型之作,它的结体是蚕头燕尾,点画细朗,书写容 ...

                  • 学易心得:初学奇门的常用思路【快速入手】

                    释缘易学 弘扬中国易经传统文化,研究.传承和发扬传统易经文化精髓:传播正确的命运观.人生观和价值观,竭尽所能为促进中华民族"文化自信"贡献绵薄之力. 96篇原创内容 公众号 奇门遁 ...

                  • 一文快速掌握多巴胺常用泵速与配制公式

                    多巴胺主要用于治疗心源性休克.各种原因导致的轻度血压下降,如何快速计算其使用剂量是一个难点,丁香公开课讲师,中山大学附属孙逸仙纪念医院心内科副主任医师麦憬霆总结了多巴胺的常用泵速与配制公式. 不同剂量 ...

                  • 建盏火了之后市场愈发复杂,如何快速入手鉴别真伪?

                    建盏成为国非遗和地理保护标志之后,埋土里几百年的旧招牌一跃成为金字招牌,浑水摸鱼的也多了,网络上一搜,写作建盏却和建盏只有"黑色"这一个共同点的很多. 想要快速分辨,可以从釉色.胎 ...

                  • 一文快速了解中国5000年历史

                    历史是人类最好的镜子,唯有透过历史,才能看清楚人类文明的一点一滴.只要不忘记走过的道路,就可以少几次跌倒. 中国有过辉煌,也有过没落,但中国始终存在.中国历史是一个国家的历史,是一个民族的历史,也是每 ...

                  • premiere 【PR】零基础快速入手详细教程

                    传说的清风2018-11-27 21:06 premiere [PR]零基础快速入手详细教程. 方法 1 打开PR. 2 点击新建文件. 3 设置参数. 4 PR主界面有下图5个主要区域. 5 双击素 ...