[Pytorch]PyTorchDataloader⾃定义数据读取
整理⼀下看到的⾃定义数据读取的⽅法,较好的有⼀下三篇⽂章, 其实⾃定义的⽅法就是把现有数据集的train和test分别⽤ 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变。
之前刚开始⽤的时候,写Dataloader遇到不少坑。⽹上有⼀些教程 分为all images in one folder 和 each class one folder。后⾯的那种写的⼈⽐较多,我写⼀下前⾯的这种,程式化的东西,每次不同的任务改⼏个参数就好。
等训练的时候写⼀篇⽂章把2333
⼀.已有的东西
举例⼦:⽤kaggle上的⼀个dog breed的数据集为例。数据⽂件夹⾥⾯有三个⼦⽬录
test: ⼏千张图⽚,没有标签,测试集
train: 10222张狗的图⽚,全是jpg,⼤⼩不⼀,有长有宽,基本都在400×300以上
labels.csv : excel表格, 图⽚名称+品种名称
<img src="pic4.zhimg/v2-6128d817c09f05fe3bdfe05e1f84a92f_b.jpg" data-caption="" data-size="normal" data-rawwidth="496" data-rawheight="85" class="origin_image zh-lightbox-thumb" width="496" data-original="pic4.zhimg/v2-6128d817c09f05fe3bdfe05e1f84a92f_r.jpg">
我喜欢先⽤pandas把表格信息读出来看⼀看
import pandas as pd
import numpy as np
df = pd.read_csv('./dog_breed/labels.csv')
print(df.info())
print(df.head())
<img src="pic1.zhimg/v2-9d680235e5ff00c3f869a3eab4630ca4_b.jpg" data-caption="" d
ata-size="normal" data-rawwidth="731" data-rawheight="265" class="origin_image zh-lightbox-thumb" width="731" data-original="pic1.zhimg/v2-9d680235e5ff00c3f869a3eab4630ca4_r.jpg">
看到,⼀共有10222个数据,id对应的是图⽚的名字,但是没有后缀 .jpg。 breed对应的是⽝种。
⼆.预处理
我们要做的事情是:
1)得到⼀个长 list1 : ⾥⾯是每张图⽚的路径
2)另外⼀个长list2: ⾥⾯是每张图⽚对应的标签(整数),顺序要和list1对应。
3)把这两个list切分出来⼀部分作为验证集
1)看看⼀共多少个breed,把每种breed名称和⼀个数字编号对应起来:
from pandas import Series,DataFrame
breed = df['breed']
breed_np = Series.as_matrix(breed)
print(type(breed_np) )
print(breed_np.shape) #(10222,)
#看⼀下⼀共多少不同种类
breed_set = set(breed_np)
print(len(breed_set)) #120
#构建⼀个编号与名称对应的字典,以后输出的数字要变成名字的时候⽤:
breed_120_list = list(breed_set)
dic = {}
for i in range(120):
dic[ breed_120_list[i] ] = i
2)处理id那⼀列,分割成两段:
file = Series.as_matrix(df["id"])
print(file.shape)
import os
file = [i+".jpg" for i in file]
file = [os.path.join("./dog_breed/train",i) for i in file ]
file_train = file[:8000]
file_test = file[8000:]
print(file_train)
np.save( "file_train.npy" ,file_train )
np.save( "file_test.npy" ,file_test )
⾥⾯就是图⽚的路径了
<img src="pic3.zhimg/v2-b740e480301df1fded91c92090065736_b.jpg" data-caption="" data-size="normal" data-rawwidth="1076" data-rawheight="113" class="origin_image zh-lightbox-thumb" width="1076"
data-original="pic3.zhimg/v2-b740e480301df1fded91c92090065736_r.jpg">
3)处理breed那⼀列,分成两段:
breed = Series.as_matrix(df["breed"])
print(breed.shape)
number = []
for i in range(10222):
number.append( dic[ breed[i] ] )
number = np.array(number)
number_train = number[:8000]
number_test = number[8000:]
np.save( "number_train.npy" ,number_train )
np.save( "number_test.npy" ,number_test )
三.Dataloader
我们已经有了图⽚路径的list,target编号的list。填到Dataset类⾥⾯就⾏了。
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
#transforms.Scale(256),
#transforms.CenterCrop(224),
transforms.ToTensor(),
normalize
])
def default_loader(path):
img_pil = Image.open(path)
img_pil = size((224,224))
img_tensor = preprocess(img_pil)
return img_tensor
#当然出来的时候已经全都变成了tensor
class trainset(Dataset):
def __init__(self, loader=default_loader):
#定义好 image 的路径
self.images = file_train
self.target = number_train
self.loader = loader
def __getitem__(self, index):
fn = self.images[index]
img = self.loader(fn)
target = self.target[index]
return img,target
def __len__(self):
return len(self.images)
我们看⼀下代码,⾃定义Dataset只需要最下⾯⼀个class,继承⾃Dataset类。有三个私有函数
def __init__(self, loader=default_loader):
这个⾥⾯⼀般要初始化⼀个loader(代码见上⾯),⼀个images_path的列表,⼀个target的列表
def __getitem__(self, index):
这⾥吗就是在给你⼀个index的时候,你返回⼀个图⽚的tensor和target的tensor,使⽤了loader⽅法,经过 归⼀化,剪裁,类型转化,从图像变成tensor
def __len__(self):
return你所有数据的个数
这三个综合起来看呢,其实就是你告诉它你所有数据的长度,它每次给你返回⼀个shuffle过的index,以这个⽅式遍历数据集,通过
__getitem__(self, index)返回⼀组你要的(input,target)
四.使⽤
实例化⼀个dataset,然后⽤Dataloader 包起来
train_data = trainset()
trainloader = DataLoader(train_data, batch_size=4,shuffle=True)
<img src="pic1.zhimg/v2-8bbf753ce61d9a2cf0082b003b67d03c_b.jpg" data-caption="" data-size="normal" data-rawwidth="615" data-rawheight="181" class="origin_image zh-lightbox-thumb" width="615" data-original="pic1.zhimg/v2-8bbf753ce61d9a2cf0082b003b67d03c_r.jpg">
在上⼀篇博客中介绍了如何⽤PyTorch训练⼀个图像分类模型,建议先看懂那篇博客后再看这篇博客。在那份代码中,采⽤
torchvision.datasets.ImageFolder这个接⼝来读取图像数据,该接⼝默认你的训练数据是按照⼀个类
别存放在⼀个⽂件夹下。但是有些情况下你的图像数据不是这样维护的,⽐如⼀个⽂件夹下⾯各个类别的图像数据都有,同时⽤⼀个对应的标签⽂件,⽐如txt⽂件来维护图像和标签的对应关系,在这种情况下就不能⽤torchvision.datasets.ImageFolder来读取数据了,需要⾃定义⼀个数据读取接⼝。另外这篇博客最后还顺带介绍如何保存模型和多GPU训练。
怎么做呢?
先来看看torchvision.datasets.ImageFolder这个类是怎么写的,主要代码如下,想详细了解的可以看:。
看起来很复杂,其实⾮常简单。继承的类是torch.utils.data.Dataset,主要包含三个⽅法:初始化__init__,获取图像__getitem__,数据集数量 __len__。__init__⽅法中先通过find_classes函数得到分类的类别名(classes)和类别名与数字类别的映射关系字典
(class_to_idx)。然后通过make_dataset函数得到imags,这个imags是⼀个列表,其中每个值是⼀个tuple,每个tuple包含两个元素:图像路径和标签。剩下的就是⼀些赋值操作了。在__getitem__⽅法中最重要的就是 img = self.loader(path)这⾏,表⽰数据读取,可以从__init__⽅法中看出self.loader采⽤的是default_loader,这个default_loader的核⼼就是⽤python的PIL库的Image模块来读取图像数据。
class ImageFolder(data.Dataset):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
python怎么读文件夹下的文件夹root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
Attributes:
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
< = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
path, target = self.imgs[index]
img = self.loader(path)
ansform is not None:
img = ansform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.imgs)
稍微看下default_loader函数,该函数主要分两种情况调⽤两个函数,⼀般采⽤pil_loader函数。
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论