详解飞桨框架数据管道
什么是数据管道?
在完成深度学习领域的任务时,我们最先面临的挑战就是数据处理,即需要将数据处理成模型能够"看懂"的语言,从而进行模型的训练。比如,在图像分类任务中,我们需要按格式处理图像数据与其对应的标签,然后才能将其输入到模型中,开始训练。在这个过程中,我们需要将图片数据从jpg、png或其它格式转换为numpy array的格式,然后对其进行一些加工,如重置大小、旋转变换、改变亮度等等,从而进行数据增强。所以,数据的预处理和加载方式很大程度上决定了模型最终的性能水平。传统框架常常包含着复杂的数据加载模式,多重的预处理操作常常会劝退许多人。而飞桨框架为了简化数据管道的流程,对数据管道相关的场景进行了流量交易高级封装,通过非常少量代码,即可实现数据的处理,更愉快的进行深度学习模型研发。
数据管道详解
在数据管道总共包含5个模块,分别是飞桨框架内置数据集、自定义数据集、数据增强、数据采样以及数据加载5个部分。
2.1 内置数据集
内置数据集介绍:
为了节约大家处理数据时所耗费的时间和精力,飞桨框架将一些我们常用到的数据集作为领域API对用户进行开放,用户通过调用paddle.vision.datasets和paddle.text.datasets即可直接使用领域API,这两个API内置包含了许多CV和NLP领域相关的常见数据集,具体如下:
import paddleimport numpy as np
paddle.__version__'2.0.0-rc1'
print('视觉相关数据集:', paddle.vision.datasets.__all__)
print('自然语言相关数据集:', paddle.text.datasets.__all__)
视觉相关数据集: ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']
自然语言相关数据集: ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16']
内置数据集使用:
为了方便大家理解,这里我演示一下如何使用内置的手写数字识别的数据集,其他数据集的使用方式也类似,大家可以动手试一下哦。具体可以见下面的代码,注意,我们通过使用mode参数用来标识训练集与测试集。调用数据集接口后,相应的API会自动下载数据集到本机缓存目录~/.cache/paddle/dataset。
import paddle.vision as vision
print("训练集下载中...")# 训练数据集
train_dataset = vision.datasets.MNIST(mode='train')print("训练集下载完成!")print("测试集下载中...")# 验证数据集
test_dataset = vision.datasets.MNIST(mode='test')print("测试集下载完成!")
训练集下载中...
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz
Begin to download
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz
Begin to download
........
Download finished
训练集下载完成!
测试集下载中...
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
Begin to download
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
Begin to download
..
Download finished
测试集下载完成!
内置数据集可视化:
通过上面的步骤,我们就定义好了训练集与测试集,接下来,让我们来看一下数据集的内容吧。
import numpy as np
import matplotlib.pyplot as plt
train_data_0, train_label_0 = np.array(train_dataset[0][0]), train_dataset[0][1]
train_data_0 = train_data_0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data_0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
train_data0 label is: [5]
从上例中可以看出,train_dataset 是一个 map-style 式的数据集,我们可以通过下标直接获取单个样本的图像数据与标签,从而进行可视化。
Note: map-style 是指可以通过下标的方式来获取指定样本,除此之外,还有 iterable-style 式的数据集,只能通过迭代的方式来获取样本,具体说明可以见下一节。
2.2 数据集定义
有同学提出虽然飞桨框架提供了许多领域数据集供我们使用,但是在实际的使用场景中,如果我们需要使用已有的数据来训练模型怎么办呢?别慌,飞桨也贴心地准备了 map-style 的 paddle.io.Dataset 基类 和 iterable-style 的 paddle.io.IterableDataset 基类 ,来完成数据集定义。此外,针对一些特殊的场景,飞桨框架也提供了 padde.io.TensorDataset 基类,可以直接处理 Tensor 数据为 dataset,一键完成数据集的定义。
让我们来看一下它们的使用方式吧~
paddle.io.Dataset的使用方式:
这个是我们最推荐使用的API,来完成数据的定义。使用 paddleio.Dataset,最后会返回一个 map-style 的 Dataset 类。可以用于后续的数据增强、数据加载等。而使用 padde.io.Dataset 也非常简单,只需要按格式完成以下四步即可。