基于PyTorch的深度学习6数据处理工具箱2
基于PyTorch的深度学习6——数据处理工具箱2
torchvision有4个功能模块:model、datasets、transforms和utils。
主要介绍如何使用datasets的ImageFolder处理自定义数据集,以及如何使用transforms对源数据进行预处理、增强等。下面将重点介绍transforms及ImageFolder。
transforms 提供了对PIL Image对象和Tensor对象的常用操作。
1)对PIL Image的常见操作如下。
• Scale/Resize:调整尺寸,长宽比保持不变。
• CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片,CenterCrop和RandomCrop在crop时是固定size,RandomResizedCrop则是random size的crop。
• Pad:填充。
• ToTensor:把一个取值范围是[0,255]的PIL.Image转换成Tensor。形状为(H,W,C)的Numpy.ndarray转换成形状为[C,H,W],取值范围是[0,1.0]的torch.FloatTensor。
• RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5。
• RandomVerticalFlip:图像随机垂直翻转。
• ColorJitter:修改亮度、对比度和饱和度。
2)对Tensor的常见操作如下。
• Normalize:标准化,即,减均值,除以标准差。
• ToPILImage:将Tensor转为PIL Image。
如果要对数据集进行多个操作,可通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential。以下为示例代码:
transforms.Compose([
#将给定的 PIL.Image 进行中心切割,得到给定的 size,
#size 可以是 tuple,(target_height, target_width)。
#size 也可以是一个 Integer,在这种情况下,切出来的图片形状是正方形。
transforms.CenterCrop(10),
#切割中心点的位置随机选取
transforms.RandomCrop(20, padding=0),
#把一个取值范围是 [0, 255] 的 PIL.Image 或者 shape 为 (H, W, C) 的 numpy.ndarray,
#转换为形状为 (C, H, W),取值范围是 [0, 1] 的 torch.FloatTensor
transforms.ToTensor(),
#规范化到[-1,1]
transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
])
还可以自己定义一个Python Lambda表达式,如将每个像素值加10,可表示为:transforms.Lambda(lambda x:x.add(10))。更多内容可参考官网:https://PyTorch.org/docs/stable/torchvision/transforms.html。
我们可以利用torchvision.datasets. ImageFolder 来直接构造出dataset
loader = datasets.ImageFolder(path)
loader = data.DataLoader(dataset)
ImageFolder会将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。下面我们利用ImageFolder读取不同目录下的图片数据,然后使用transforms进行图像预处理,预处理有多个,我们用compose把这些操作拼接在一起。然后使用DataLoader加载。对处理后的数据用torchvision.utils中的save_image保存为一个png格式文件,然后用Image.open打开该png文件,详细代码如下:
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt
%matplotlib inline
my_trans=transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/torchvision_data', transform=my_trans)
train_loader = data.DataLoader(train_data,batch_size=8,shuffle=True,)
for i_batch, img in enumerate(train_loader):
if i_batch == 0:
print(img[1])
fig = plt.figure()
grid = utils.make_grid(img[0])
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()
utils.save_image(grid,'test01.png')
break