概述
在训练机器学习模型中,一般会应用mini-batch的方法来遍历数据集,此外还需要对加载的数据集进行归一化等附加操作。如果能够构建可迭代的数据集模型,并且添加对数据集的附加操作处理接口,那么数据集的使用将非常方便。例如Pytorch提供的对数据集的非常优雅的处理方式:定义附加操作集合,加载数据集。这样即可非常优雅的遍历整个数据集。
接下来是对上述代码的源码学习,我们分析Pytorch提供的对数据集处理的源代码,来进一步的理解这种方式。
构建生成式数据集
Dataset
抽象类Dataset,子类要实现方法:__len__
, __getitem__
,使得Dataset对象支持len()和切片操作。
pytorch此模块官方代码见link
Sampler
要能以iter的形式获取dataset中的数据,需要采样器sampler,这里介绍SequentialSampler和RandomSampler。
首先来看Sampler,实现__iter__
和__len__
使得支持iter()和len()函数。
SequentialSampler为顺序采样,按照原顺序遍历数据集,采样返回数据集的下标(若数据集长度为3,则依次迭代返回0,1,2)。
RandomSampler为随机采样,采样返回数据集的下标的随机排序(若数据集长度为3,则迭代返回可能为[(0,1,2),(0,2,1),(1,0,2),(1,2,0),(2,0,1),(2,1,0)]。
源代码中还有SubsetRandomSampler和WeightedRandomSampler,有兴趣的同学点这里link
DataLoader && DataLoaderIter
将dataset和sampler组合在一起,构成dataloader,为让其支持迭代操作,还需定义DataLoaderIter(在此省略多进程的部分)。
DataLoader
DataLoaderIter
以上的处理完成了可迭代的生成式数据集的构建完整代码见link,接下来构建mnist的dataset。
MNIST Dateset
MNIST数据集继承上述介绍的Dataset类,需要实现__len__
和__getitem__
方法:这里的__len__
方法取决于MNIST数据集,60000的训练集和10000的测试集;getitem
方法使得数据集支持切片操作,并且在此可完成transform等附加操作。此处完整代码见torchvision提供的源码。
transfrom模块中定义了许多预处理操作,这里介绍几个简单常用的。
Compose
相当于是一个容器,顺序执行容器内部的操作。
ToTenser
Pytorch最常用的方法,将numpy.ndarray (H x W x C)[0, 255] 转成tensor (C x H x W) [0.0, 1.0]。
Normalize
Normalize是归一化常用的操作,具体见代码__doc__
总结
现在我们再来看这一段代码,是不是感觉特别的清晰:定义预处理操作ToTensor、Normalize,将数据限制到[-1, 1]范围;将mnist的dataset传递给dataloader构造loader生成模式,这样就能够在训练代码中使用for··in··语句直接加载数据。