⼿把⼿教你制作⾃⼰的CIFAR数据集(附项⽬源码)
从CIFAR数据集制作开始教你训练⾃⼰的分类模型
⽬录
参考CIFAR的格式制作⾃⼰的数据集
使⽤⾃⼰制作的数据集训练模型
参考CIFAR的格式制作⾃⼰的数据集
,记得给我留颗星星,下⾯是代码使⽤的详细教程
⾸先将所有图⽚按类别放在⽂件夹中,⽂件夹名为类别名。例如:存在20个类就分20个⽂件夹
将所有图⽚的路径提取到⼀个⽂件中,⽂件中每⾏包含图⽚路径和图⽚所属类别的索引(同时会⽣成图⽚类别和索引的对应关系)运⾏ get_filename.py ⽂件,⽣成图⽚路径+类别索引(data/cow_jpg.lst)和类别索引对应表(data/)
import os
def getFlist(path):
root_dirs =[]
for root, dirs, files in os.walk(path):
print('root_dir:', root)
print('sub_dirs:', dirs)
print('files:', files)
root_dirs.append(root)
print('root_dirs:', root_dirs[1:])
root_dirs = root_dirs[1:]
return root_dirs
def getChildList(root_dirs):
j =0
f =open('data/cow_jpg.lst','w')#⽣成⽂件路径和类别索引
if __name__ =='__main__':
resDir ='data'
f2 =open('data/','w')#⽣成类别和索引的对应表
root_dirs = getFlist(resDir)
k =0
for root_dir in root_dirs:
f2.write('%s %s\n'%(root_dir,k))
k = k+1
f2.close()
getChildList(root_dirs)python新建项目教程
print(root_dirs)
类别索引对应表就是将⽂件名所表⽰的类别与索引相对应,因为训练模型时不能以字符串作为类别名。例如:狗 0,猫 1,鸡3……
拆分训练集与测试集(这⾥我先将所有数据打乱,取⼀部分作为训练集,剩下的则作为测试集)
打开 make_test_batch.py ⽂件
import os
import random
f =open('data/cow_jpg.lst')#上⼀步⽣成的图⽚路径⽂件
list= f.readlines()
print(len(list))
random.shuffle(list)
print(list)
set_num =int(float(len(list))*0.2)
#0.2为拆分阈值,0.2则是前20%为测试集,剩下的是训练集
test_list =list[:set_num]
train_list =list[set_num:]
print('================')
print(len(test_list))
print(len(test_list))
print(len(train_list))
print(test_list and train_list)
f2 =open('data/cow_jpg_train.lst','w')
for i in train_list:
f2.write(i)
f3 =open('data/cow_jpg_test.lst','w')
for i in test_list:
f3.write(i)
f.close()
f2.close()
f3.close()
将上⼀步⽣成的图⽚路径+类别索引(data/cow_jpg.lst)⽂件填到第3⾏
设置拆分阈值,我设定了0.2为拆分阈值,其含义为前20%为测试集,剩下的是训练集
运⾏!
⽣成 test_batch (⽂件名:data/cow_jpg_test.lst)
和 train_batch (⽂件名:data/cow_jpg_train.lst)
制作CIFAR的batch
运⾏ demo.py ⽂件,其中将file_list参数填写上⼀步⽣成的图⽚路径⽂件(data/cow_jpg_train.lst)
填写拆分出的tran⽂件(cow_jpg_train.lst),bin_num 设置为4,就会⽣成四个batch_train⽂件
填写拆分出的tran⽂件(cow_jpg_test.lst),bin_num 设置为1,就会⽣成⼀个batch_test⽂件
共⽣成五个⽂件,与管⽅提供的CIFAR压缩包相同
制作CIFAR的.mate⽂件
在batch⽂件夹中新建⼀个a⽂件
打开 edit_mate.py ⽂件
填写每个batch包含的样本数量(num_cases_per_batch) ,这⾥我设置了2500因为我⼀共有10000个样
本,分了四个batch 将类别索引表⽂件()中的类别名(⽂件名不是索引)按顺序替换到第⼗⾏,应该是类别名,也就是⽂件名,例如:猫,狗,鸡,我这⾥是数字字符串
运⾏!
⽣成 a ⽂件
这样你就会得到:data_batch_0,…,test_a等三类⽂件,与官⽅的CIFAR数据集完全⼀致,下⾯我们以任何⼀个使⽤CIFAR数据集的模型为例,进⾏测试
使⽤⾃⼰制作的数据集训练模型
打开data_utils.py,到下⾯这段代码,将下载设置为否(download=False),不到就算了,跳过这步
if args.dataset =="cifar10":
trainset = datasets.CIFAR10(root="./data",
train=True,
download=False,
transform=transform_train)
testset = datasets.CIFAR10(root="./data",
train=False,
download=False,
transform=transform_test)if args.local_rank in[-1,0]else None
跑⼀下模型 train.py ,报错
Traceback (most recent call last):
File "/workspace/ViT-pytorch-main/train.py", line 347, in <module>
main()
File "/workspace/ViT-pytorch-main/train.py", line 342, in main
train(args, model)
File "/workspace/ViT-pytorch-main/train.py", line 158, in train
train_loader, test_loader = get_loader(args)
File "/workspace/ViT-pytorch-main/utils/data_utils.py", line 31, in get_loader
transform=transform_train)
File "/opt/conda/envs/ViT/lib/python3.6/site-packages/torchvision/datasets/cifar.py", line 93, in __init__
self._load_meta()
File "/opt/conda/envs/ViT/lib/python3.6/site-packages/torchvision/datasets/cifar.py", line 99, in _load_meta      ' You can use download=True to download it'
RuntimeError: Dataset metadata file not found or corrupted. You can use download=True to download it 原因是模型没有到CIFAR⽂件,因为CIFAR函数中⾃带完整性验证(check_integrity),关闭即可。关闭CIFAR源码中的⽂件完整性验证
根据报错信息到CIFAR源码的位置
我这⾥是/opt/conda/envs/ViT/lib/python3.6/site-packages/torchvision/datasets/cifar.py
打开源码,注释掉如下⼏⾏
if not check_integrity(path, a['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if not check_integrity(fpath, md5):
return False
到你的程序代码中的num_classes = 10,将10修改为你的类别数量
pycharm中可以 crtl + shift + F 搜索 num_classes
运⾏成功!
cifar10Dataset-master

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。