GAN在单细胞上的应用

下面申小忱的投稿

最近看到了几篇论文,通过GAN网络与单细胞分析相结合,进而得出一些新的思路。和大家分享一下,看过这篇推文相信大家一定会对GAN的原理,GAN的作用,以及GAN在单细胞方面的简单的作用有一个全方面的了解。市面上对于这方面的推文几乎很少,所以笔者也仅仅是根据自己的理解和看过的论文来撰写这篇推文,如果大家有兴趣,可以留言来做一个深入的交流讨论!

xjtu 申小忱

我将以这篇文章为一个引领,通过这篇论文的实验结果来为大家指明一条学习路径。以及我会基于pytorch 做一些简单的实验让大家深入体会一下GAN的魅力~~~ 首先给大家介绍一下这篇论文主要做了一件什么样的简单的事情。
这篇论文作者说现在测单细胞很贵啦!单细胞数据很少啦 !样本很珍贵啦 ! (当然作者说的很官方。。)
所以呢我们能不能基于已知的单细胞数据集生成一些细胞扩充一下我们的数据大小。比如我们测了1万个细胞,然后我们生成一些细胞。这样细胞数量就变多了哇!!又比如你关注某一亚群细胞。。她们可能比较少,你也可以同样专注于生成那些细胞同样可以。这也就是GAN所做的事情。

说了这么多什么是GAN啊!!!!

GAN全成生成式对抗网络。原理其实很简单的。大家不用怕~~~我先给大家普及一些初中高中知识。。

首先大家应该都知道什么决策树,随机森林,再不济总知道LASSO回归,逻辑回归,岭回归吧。那这些所谓的模型有一个共同的特点都是从属性x(基因表达量)预测标签y(是否患病或者预后咋样..)这是一种条件概率分布。已知条件推结论。

而GAN则与之不同是一种生成式模型任务是得到属性x且类别为y时的联合概率。这是一种联合概率分布。所以这种生成式模型学习的是一种分布。

像这种生成式网络现在主流有两个方向。一个就是各种GAN变体啦,还有就是VAE(变分自编码器)。因为这篇推文主要针对于GAN,所以想了解VAE的童鞋也可以去看看相应的论文。说实话蛮多的。。

这张图片也就是GAN究竟是什么样子的啦!GAN主要由两部分构成一个是生成器(generator),一个是判别器(discriminator)。大致流程就是,你输入随机噪声(一般服从高斯分布)然后生成器根据随机噪声生成一个东西例如你是研究cv的就是生成一张图片。然后判别器学习真实的图片和生成器生成的图片,真实的图片化为1类,生成器生成的图片化为0类。

生成器的目的就是使自己生成的东西尽量划分到1类。判别器的目的就是找出生成器生成的图片和真实图片的差异进而把生成器生成的图片划分到0类。周而复始,直到判别器无法分辨给出的概率是1/2。也就是达到了生成器生成的图片和真实图片高度类似。

这就是一个生成器和判别器对抗(博弈)的过程。同样回到单细胞中,你生成的就不是图片了而生成的一个个细胞,道理是一样的。你可能会说那怎么能一样!!!细胞是一行数据。图片是一个二维的啊。。。(⊙o⊙)…回答这类问题就是你给他按照统一规则reshape一下不就好啦??原文也是这么做哒~~ 这

里主要是给大家提供一个思路,我也不太可能帮大家想出一个课题!很多细节上的问题,我觉得问别人是解决不了的,要多看论文多体会。。。

回到GAN中。知道了GAN的大致做法。那我就要给大家介绍一下GAN的价值函数。以及如何求解。

这是一个minmax函数,在这个函数中D(判别器)的目的是最大化价值函数V,而G(生成器)的目的是针对特定的D去最小化价值函数V。

data:真实数据  z:随机噪声

其中D的目标,log是一个单挑递增的函数,也就是最大化D(x)和 1-D(G(z))而 G的目标就是在特定D条件下使D(G(z))尽量趋近于1.

在D取最优解时可得:

也就是

在判别器取最优解时,生成器的最优解就是
此时 价值函数的值为

这里再拓展一下,大家可以看到我上面给大家的示意图,有一个label对吧。这就是一个GAN变体,除了随机噪声同时输入了一个条件label。这就是下一篇论文了 conditional generative adversarial nets。在原有价值函数上进行改进输入一个条件。

这也就是 cGAN 在原有的GAN上面加入类别标签。这也就是这篇单细胞文章中所涉及的 cscGAN。这个的主要目的也就是生成你想要类别的固定数据,比如你就想生成CD4, NK等等。。这样你就可以输入固定条件生成一些你想要的细胞类别。

上面是简要的文章背景普及,接下里我们来看看这篇单细胞论文的一些结果。

首先这是作者用所谓的scGAN(也就是训练的普通的GAN啦)。来学习每个细胞,进而画出学习出来的整个数据集的一个tSNE的降维聚类图。可以看到无论整体形状还是某一种基因在总体细胞的分布,总体上差异不是很大,作者是以NKG7举的例子。其实随着GAN的发展,想做到比这个优秀,还是较为容易的。

作者同时为了说明生成的细胞和原来测序的细胞之间没有很大的差异,用了一个简单的随机森林分类器来做一个分类。

可以看到整体上还是比较趋近于0.5的,说明无法分开,进而证明生成的细胞和真的细胞差异还是很小的,用弱分类器还是较难区分的。

接着作者又设计了一个cscGAN也就是我上面提到的加入一个label来做。来生成固定类别的细胞。

选择其中两类,一个2类一个6类。看到conditional GAN 表现的同样很好。可以生成对应聚类的细胞群。

接下来给大家介绍一下cGAN的代码,由于我自己写的还有不对的地方请大家多多指教。因为原文是用TensorFlow写的,而我平时还是用torch框架的。我就简单的给大家演示一下基于手写数字集的条件生成对抗网络大概的流程。

import argparse
import os
import numpy as np
import torch
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
opt = parser.parse_known_args()[0]
print("opt =", opt)

首先导入一些库还有我自己的一些参数设置(这个其实无所谓)。你要是勤劳后面一个一个设置超参数是一样的。。

img_shape = (opt.channels, opt.img_size, opt.img_size)
print("img_shape =", img_shape)
# img_shape = (1, 28, 28)

因为mnist 数据集是灰度图像只有一个通道且长宽皆为28*28。

import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
#os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),

batch_size=opt.batch_size,
    shuffle=True,
)

导入数据,为了大家方便复现,所以我才用mnist,因为这数据很小的。也就6万张图片(训练集)。同时我做了简单的transform操作就是都resize 成28*28 然后归一化一下。

from torch.autograd import Variable
import matplotlib.pyplot as plt

def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  
        plt.imshow(img[:, :, 0], cmap="gray")
    else:
        plt.imshow(img, cmap="gray")
    plt.show()
    
mnist = datasets.MNIST("../../data/mnist")

for i in range(3):
    sample = mnist[i][0]
    label = mnist[i][1]
    show_img(np.array(sample), trans=False)
    print("label =", label, '\n')

这一步大家可以看到图片样式看看数据样子嘛。

也就是这样的灰度图。

接下来造一个生成器和判别器。因为我做的是一个cGAN嘛。就是输入随机噪声和一个label.然后label也就是一个0-9 10个label。把它们one-hot encoding一下 然后cat到随机噪声上就可以啦~~

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

def forward(self, noise, labels):
        # Concatenate label embedding and image to produce input
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *img_shape)
        return img
    
generator = Generator()

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

self.model = nn.Sequential(
            nn.Linear(opt.n_classes + int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

def forward(self, img, labels):
        
        d_in = torch.cat((img.view(img.size(0),-1),self.label_embedding(labels)), -1)
        validity = self.model(d_in)
        return validity
    
discriminator = Discriminator()

接下来定义一下我的损失函数 我这里用二分类交叉熵损失来简单实现~~

adversarial_loss = torch.nn.BCELoss()

把定义的生成器和判别器放到GPU上,同时定义优化器,我这里有adam。你也可以用SGD,RMsprop 等等啦。因为用了adam初始学习率就要设小一点我这里设置为0.002

cuda = True if torch.cuda.is_available() else False
print("cuda_is_available =", cuda)
if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
print("learning_rate =", opt.lr)
#learning_rate = 0.0002

接下来再定义一个输出函数。我准备让生成器依次输出0-9。输出10遍。看看效果

from torchvision.utils import save_image

def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in range(n_row)])
    labels = Variable(LongTensor(labels))
    gen_imgs = generator(z, labels)
    save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)
    #这里nrow 取10

接下来就是迭代训练了。我这里设置迭代200次(详见最开始我的超参数设定)输入随机噪声服从高斯分布

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in list(enumerate(dataloader)):
    # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = Variable(labels.type(LongTensor))
        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
        batch_size = imgs.shape[0]
        gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))
        optimizer_G.zero_grad()
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False) # 为1时判定为真
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False) # 为0时判定为假
        # Generate a batch of images
        gen_imgs = generator(z, gen_labels)
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)
        g_loss.backward()
        optimizer_G.step()
        optimizer_D.zero_grad()

# Loss for real images
        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

# Loss for fake images
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)

d_fake_loss = adversarial_loss(validity_fake, fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        #print("real_loss =", d_real_loss, '\n')
        #print("fake_loss =", d_fake_loss, '\n')
        #print("d_loss =", d_loss, '\n')

d_loss.backward()
        optimizer_D.step()
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            os.makedirs("images", exist_ok=True)
            sample_image(n_row=10, batches_done=batches_done)

os.makedirs("model", exist_ok=True) # 保存模型
            torch.save(generator, 'model/generator.pkl') 
            torch.save(discriminator, 'model/discriminator.pkl')

我这里设定400个batch就打印一下结果。可以让大家看一下渐变过程还是蛮有意思的。
可以看到这是一个逐渐博弈的过程。进而生成的图像越来越好。同样你把每一个手写数字换成单细胞每个细胞的数据一样可以的得到很好的结果。因为大家知道同样的cluster的细胞在某些marker上表达的程度是一致的也就是像手写数字一样的在某些位置像素值特别高趋于255就是白色,表达值低就像素很小趋于0就是黑色。大同小异有没有!!!
通过本篇推文笔者还是希望和更多有志之士进行交流,取长补短日益精进自己的能力素养。希望大家能多多指出我的不足,我一定虚心向大家学习!希望大家能多多与我交流。

写在文末

虽然说上面的代码都是复制粘贴即可运行,但是如果要更好地完成上面的图表,通常是需要掌握5个R包,分别是: scater,monocle,Seurat,scran,M3Drop,需要熟练掌握它们的对象,:一些单细胞转录组R包的对象 而且分析流程也大同小异:

  • step1: 创建对象
  • step2: 质量控制
  • step3: 表达量的标准化和归一化
  • step4: 去除干扰因素(多个样本整合)
  • step5: 判断重要的基因
  • step6: 多种降维算法
  • step7: 可视化降维结果
  • step8: 多种聚类算法
  • step9: 聚类后找每个细胞亚群的标志基因
  • step10: 继续分类

单细胞转录组数据分析的标准降维聚类分群,并且进行生物学注释后的结果。可以参考前面的例子:人人都能学会的单细胞聚类分群注释 ,我们演示了第一层次的分群。

如果你对单细胞数据分析还没有基础认知,可以看基础10讲:

单细胞服务列表

(0)

相关推荐