换脸火了,我用 python 快速入门生成模型

机器学习算法与Python学习

作为沟通学习的平台,发布机器学习与数据挖掘、深度学习、Python实战的前沿与动态,欢迎机器学习爱好者的加入,希望帮助你在AI领域更好的发展,期待与你相遇!
86篇原创内容
公众号

点击 机器学习算法与Python学习 选择加星标

精彩内容不迷路

出品 | AI科技大本营(ID:rgznai100)

引言:

近几年来,GAN生成对抗式应用十分火热,不论是抖音上大火的“蚂蚁牙黑”还是B站上的“复原老旧照片”以及换脸等功能,都是基于GAN生成对抗式的模型。但是GAN算法对于大多数而言上手较难,故今天我们将使用最少的代码,简单入门“生成对抗式网络”,实现用GAN生成数字。

其中生成的图片效果如下可见:

1. 模型建立

1.1 环境要求
本次环境使用的是python3.6.5+windows平台
主要用的库有:
  • OS模块用来对本地文件读写删除、查找到等文件操作

  • numpy模块用来矩阵和数据的运算处理,其中也包括和深度学习框架之间的交互等

  • Keras模块是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化 。在这里我们用来搭建网络层和直接读取数据集操作,简单方便

  • Matplotlib模块用来可视化训练效果等数据图的制作

1.2 GAN简单介绍
GAN 由生成器 (Generator)和判别器 (Discriminator) 两个网络模型组成,这两个模型作用并不相同,而是相互对抗。我们可以很简单的理解成,Generator是造假的的人,Discriminator是负责鉴宝的人。正是因为生成模型和对抗模型的相互对抗关系才称之为生成对抗式。
那我们为什么不适用VAE去生成模型呢,又怎么知道GAN生成的图片会比VAE生成的更优呢?问题就在于VAE模型作用是使得生成效果越相似越好,但事实上仅仅是相似却只是依葫芦画瓢。而 GAN 是通过 discriminator 来生成目標,而不是像 VAE线性般的学习。
这个项目里我们目标是训练神经网络生成新的图像,这些图像与数据集中包含的图像尽可能相近,而不是简单的复制粘贴。神经网络学习什么是图像的“本质”,然后能够从一个随机的数字数组开始创建它。其主要思想是让两个独立的神经网络,一个产生器和一个鉴别器,相互竞争。生成器会创建与数据集中的图片尽可能相似的新图像。判别器试图了解它们是原始图片还是合成图片。
1.3 模型初始化
在这里我们初始化需要使用到的变量,以及优化器、对抗式模型等。
def __init__(self, width=28, height=28, channels=1):
    self.width = width
    self.height = height
    self.channels = channels
    self.shape = (self.width, self.height, self.channels)
    self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)
    self.G = self.__generator()
    self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)
    self.D = self.__discriminator()
    self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
    self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
    self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)
1.4 生成器模型的搭建
这里我们尽可能简单的搭建一个生成器模型,3个完全连接的层,使用sequential标准化。神经元数分别是256,512,1024等:
def __generator(self):
        ''' Declare generator '''
        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.width  * self.height * self.channels, activation='tanh'))
        model.add(Reshape((self.width, self.height, self.channels)))
        return model
1.5 判别器模型的搭建
在这里同样简单搭建判别器网络层,和生成器模型类似:
def __discriminator(self):
    ''' Declare discriminator '''
    model = Sequential()
    model.add(Flatten(input_shape=self.shape))
    model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(np.int64((self.width * self.height * self.channels)/2)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()
    return model
1.6 对抗式模型的搭建
这里是较为难理解的部分。让我们创建一个对抗性模型,简单来说这只是一个后面跟着一个鉴别器的生成器。注意,在这里鉴别器的权重被冻结了,所以当我们训练这个模型时,生成器层将不受影响,只是向上传递梯度。代码很简单如下:
def __stacked_generator_discriminator(self):
    self.D.trainable = False
    model = Sequential()
    model.add(self.G)
    model.add(self.D)
    return model
2. 模型的训练使用

2.1 模型的训练
在这里,我们并没有直接去训练生成器。而是通过对抗性模型间接地训练它。我们将噪声传递给了对抗模型,并将所有从数据库中获取的图像标记为负标签,而它们将由生成器生成。
对真实图像进行预先训练的鉴别器把不能合成的图像标记为真实图像,所犯的错误将导致由损失函数计算出的损失越来越高。这就是反向传播发挥作用的地方。由于鉴别器的参数是冻结的,在这种情况下,反向传播不会影响它们。相反,它会影响生成器的参数。所以优化对抗性模型的损失函数意味着使生成的图像尽可能的相似,鉴别器将识别为真实的。这既是生成对抗式的神奇之处!
故训练阶段结束时,我们的目标是对抗性模型的损失值很小,而鉴别器的误差尽可能高,这意味着它不再能够分辨出差异。
最终在我门的训练结束时,鉴别器损失约为0.73。考虑到我们给它输入了50%的真实图像和50%的合成图像,这意味着它有时无法识别假图像。这是一个很好的结果,考虑到这个例子绝对不是优化的结果。要知道确切的百分比,我可以在编译时添加一个精度指标,这样它可能得到很多更好的结果实现更复杂的结构的生成器和判别器。
代码如下,这里legit_images是指原始训练的图像,而syntetic_images是生成的图像。:
def train(self, X_train, epochs=20000, batch = 32, save_interval = 100):
    for cnt in range(epochs):
        ## train discriminator
        random_index = np.random.randint(0, len(X_train) - np.int64(batch/2))
        legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels)
        gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100))
        syntetic_images = self.G.predict(gen_noise)
        x_combined_batch = np.concatenate((legit_images, syntetic_images))
        y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1))))
        d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
        # train generator
        noise = np.random.normal(0, 1, (batch, 100))
        y_mislabled = np.ones((batch, 1))
        g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled)
        print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
        if cnt % save_interval == 0:
            self.plot_images(save2file=True, step=cnt)
2.2 可视化
使用matplotlib来可视化模型训练效果。
def plot_images(self, save2file=False, samples=16, step=0):
    ''' Plot and generated images '''
    if not os.path.exists('./images'):
        os.makedirs('./images')
    filename = './images/mnist_%d.png' % step
    noise = np.random.normal(0, 1, (samples, 100))
    images = self.G.predict(noise)
    plt.figure(figsize=(10, 10))
    for i in range(images.shape[0]):
        plt.subplot(4, 4, i+1)
        image = images[i, :, :, :]
        image = np.reshape(image, [self.height, self.width])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    if save2file:
        plt.savefig(filename)
        plt.close('all')
    else:
        plt.show()
3. 使用方法

考虑到代码较少,下述代码复制粘贴即可运行。
# -*- coding: utf-8 -*-
import os
import numpy as np
from IPython.core.debugger import Tracer
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam
import matplotlib.pyplot as plt
plt.switch_backend('agg')   # allows code to run without a system DISPLAY
class GAN(object):
    ''' Generative Adversarial Network class '''
    def __init__(self, width=28, height=28, channels=1):
        self.width = width
        self.height = height
        self.channels = channels
        self.shape = (self.width, self.height, self.channels)
        self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)
        self.G = self.__generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)
        self.D = self.__discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
        self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
        self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)
    def __generator(self):
        ''' Declare generator '''
        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.width  * self.height * self.channels, activation='tanh'))
        model.add(Reshape((self.width, self.height, self.channels)))
        return model
    def __discriminator(self):
        ''' Declare discriminator '''
        model = Sequential()
        model.add(Flatten(input_shape=self.shape))
        model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(np.int64((self.width * self.height * self.channels)/2)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()
        return model
    def __stacked_generator_discriminator(self):
        self.D.trainable = False
        model = Sequential()
        model.add(self.G)
        model.add(self.D)
        return model
    def train(self, X_train, epochs=20000, batch = 32, save_interval = 100):
        for cnt in range(epochs):
            ## train discriminator
            random_index = np.random.randint(0, len(X_train) - np.int64(batch/2))
            legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels)
            gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100))
            syntetic_images = self.G.predict(gen_noise)
            x_combined_batch = np.concatenate((legit_images, syntetic_images))
            y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1))))
            d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
            # train generator
            noise = np.random.normal(0, 1, (batch, 100))
            y_mislabled = np.ones((batch, 1))
            g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled)
            print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
            if cnt % save_interval == 0:
                self.plot_images(save2file=True, step=cnt)
    def plot_images(self, save2file=False, samples=16, step=0):
        ''' Plot and generated images '''
        if not os.path.exists('./images'):
            os.makedirs('./images')
        filename = './images/mnist_%d.png' % step
        noise = np.random.normal(0, 1, (samples, 100))
        images = self.G.predict(noise)
        plt.figure(figsize=(10, 10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.height, self.width])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()
        if save2file:
            plt.savefig(filename)
            plt.close('all')
        else:
            plt.show()
if __name__ == '__main__':
    (X_train, _), (_, _) = mnist.load_data()
    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)
    gan = GAN()
    gan.train(X_train)
(0)

相关推荐

  • Tensorflow实战:Discuz验证码识别

    选择"星标"公众号 重磅干货,第一时间送达! 写在最前面 验证码是根据随机字符生成一幅图片,然后在图片中加入干扰象素,用户必须手动填入,防止有人利用机器人自动批量注册.灌水.发垃圾 ...

  • MaskRCNN 基于OpenCV DNN的目标检测与实例分割

    原文:MaskRCNN 基于OpenCV DNN的目标检测与实例分割 - AIUAI 这里主要记录基于 OpenCV 4.x DNN 模块和 TensorFlow MaskRCNN 开源模型的目标检测 ...

  • 通过深层神经网络生成音乐

    深度学习改善了我们生活的许多方面,无论是明显的还是微妙的.深度学习在电影推荐系统.垃圾邮件检测和计算机视觉等过程中起着关键作用. 尽管围绕深度学习作为黑匣子和训练难度的讨论仍在进行,但在医学.虚拟助理 ...

  • keras搭建多层LSTM时间序列预测模型

    参考基于 Keras 的 LSTM 时间序列分析--以苹果股价预测为例 ######################导入库##########################import osos.e ...

  • 解析Transformer模型

    ❝ GiantPandaCV导语:这篇文章为大家介绍了一下Transformer模型,Transformer模型原本是NLP中的一个Idea,后来也被引入到计算机视觉中,例如前面介绍过的DETR就是将 ...

  • 免费课《Python快速入门》

    最近一段时间,我按照章节录制<python快速入门>,视频由22段小视频,总时长128分钟. 新视频教程特点: 使用jupyter notebook 剪辑视频配上欢快的bgm,学起来悦耳赏 ...

  • 最全Python快速入门教程,满满都是干货

    程序员大牛 Python是面向对象,高级语言,解释,动态和多用途编程语言.Python易于学习,而且功能强大,功能多样的脚本语言使其对应用程序开发具有吸引力. Python的语法和动态类型具有其解释性 ...

  • python快速入门的方法?python教程推荐

    方法不对,努力白费,这是很多初学者学python从入门到放弃的主要原因. 所以,你需要用它来,一个星期的时间,快速入门python编程,哪怕你是零基础. 在之前我也曾多次介绍过它,因为它确实很简单,确 ...

  • Python快速入门

    第一篇:计算机核心基础 01 计算机核心基础 附录1-cpu详解 第二篇:编程语言 01 编程语言与Python介绍 第三篇:Python语法入门 01 Python语法入门之变量 02 Python ...

  • 如何快速入门Python编程?这19个语法是第一站!

    Python编程学习圈 5天前 很多人听说Python编程简单易学,前景好薪酬高,所以就想快点入门Python编程,有方法吗?有套路吗?当然有,不过要快速入门Python编程,我觉得这19个语法是第一 ...

  • Python语法快速入门视频课程

    插播一条广告 Python数据挖掘与文本分析&Stata应用能力提升与实证前沿云特训 Python部分上课时间为6月29日-7月2日,感兴趣的童鞋欢迎关注 <Python语法快速入门&g ...

  • Python爬虫快速入门

    周末这两天我又接着之前的劲儿<Python快速入门>,将python爬虫相关的知识点做了梳理,录屏.剪辑.上传到B站. 由于在公共区域录制,偶尔会有点吵,不过95%上的时间音质是很不错的. ...

  • Python进阶之NumPy快速入门(一)在里面输入conda install numpy命令

    (在里面输入conda install numpy命令) https://m.toutiao.com/is/JT3FDcW/ 前言 NumPy是Python的一个扩展库,负责数组和矩阵运行.相较于传统 ...

  • Python爬虫入门,快速抓取大规模数据(第二部分)

    通过第一部分的练习,我们已经有了一个可运行的爬虫.这一部分我们详细的看看如何使用BeautifulSoup从网页中提取我们需要的数据,学习的目标是能够使用BeautifulSoup从网页中提取任意的数 ...