[PyTorch] FixMatch Code Walkthrough (Very Detailed)
FixMatch Code Walkthrough: The Training Process
The previous post gave a rough overview of data loading. This post goes a step further and analyzes how training proceeds.
Link to the previous post:
Parameters
All parameters default to the example given by the author:
python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5
At runtime, the parameters take the following values (output truncated):
INFO - __main__ - {'T': 1, 'amp': False, 'arch': 'wideresnet', 'batch_size': 64, 'dataset': 'cifar10', 'device': device(type='cuda', index=0), 'ema_decay'
We then plug these values in and see how each step runs.
Data generation
First, indices are generated for the labeled and unlabeled data. The corresponding code in cifar.py was analyzed in the previous post.
import math

import numpy as np
from torchvision import datasets

base_dataset = datasets.CIFAR10('./CIFAR10', train=True, download=True)
labels = base_dataset.targets
label_per_class = 4000 // 10
labels = np.array(labels)
labeled_idx = []
# unlabeled data: all data (github/kekmodel/FixMatch-pytorch/issues/10)
unlabeled_idx = np.array(range(len(labels)))
for i in range(10):
    # sample label_per_class indices of class i without replacement
    idx = np.where(labels == i)[0]
    idx = np.random.choice(idx, label_per_class, False)
    labeled_idx.extend(idx)
labeled_idx = np.array(labeled_idx)
print('number labeled_idx =', len(labeled_idx))
assert len(labeled_idx) == 4000
# --expand-labels is set, so this branch always runs
if True or 4000 < 64:
    num_expand_x = math.ceil(64 * 1024 / 4000)  # ceil(16.384) = 17
    labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
np.random.shuffle(labeled_idx)
print('number labeled_idx =', len(labeled_idx))
print('number unlabeled_idx =', len(unlabeled_idx))
train_labeled_idxs = labeled_idx
train_unlabeled_idxs = unlabeled_idx
The result is shown below. The unlabeled set uses all 50000 training images, while the labeled indices grow to 68000 after expansion (the 4000 sampled indices repeated 17 times).
number labeled_idx = 4000
number labeled_idx = 68000
number unlabeled_idx = 50000
Let's look at how the images change.
First, the raw image with no transform applied:
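from torchvision import transforms

# Raw image: only ToTensor() is applied.
# CIFAR10SSL is the dataset class from cifar.py (see the previous post).
train_labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=transforms.ToTensor())
train_iter = iter(train_labeled_dataset)

# Visualization: run this part repeatedly to get different images
imgs, label = next(train_iter)
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size)  # (32, 32)
image.show()
print(label)
(See: the difference between transform and convert)
Next, we apply the transform without data augmentation, which is the one the author uses for the validation set. ToTensor() rescales pixel values from 0-255 to 0-1, and the transforms.Normalize() that follows standardizes each channel with the CIFAR-10 mean and std, so values end up roughly in [-2.0, 2.1] (the range would be exactly [-1, 1] only with mean 0.5 and std 0.5). Note that the image size does not change; the image only looked larger because I enlarged it when taking the screenshot.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)])
train_labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=transform_val)
train_iter = iter(train_labeled_dataset)
imgs, label = next(train_iter)
# normalized values fall outside [0, 1], so the displayed colors are distorted
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size)  # (32, 32)
image.show()
print(label)
Next, let's look at the data augmentation applied to the labeled images (run twice): a random horizontal flip followed by a random 32x32 crop taken after 4 pixels of reflect padding.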
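The labeled-image transform pads each 32x32 image by int(32*0.125) = 4 pixels with padding_mode='reflect' before taking a random 32x32 crop. As a rough illustration of what reflect padding does, the sketch below mirrors a single row of values without repeating the edge value and then counts the possible horizontal crop positions. reflect_pad_row is a hypothetical helper written for this example and is not part of torchvision.

```python
def reflect_pad_row(row, pad):
    """Mirror `pad` values on each side without repeating the edge value.

    Hypothetical helper for illustration only; requires pad < len(row).
    """
    left = row[1:pad + 1][::-1]
    right = row[-pad - 1:-1][::-1]
    return left + row + right


row = [1, 2, 3, 4]
padded = reflect_pad_row(row, pad=2)
print(padded)  # [3, 2, 1, 2, 3, 4, 3, 2]

# FixMatch setting: 32 pixels per row, padded by int(32 * 0.125) = 4 per side
size = 32
pad = int(size * 0.125)
padded_width = len(reflect_pad_row(list(range(size)), pad))
num_offsets = padded_width - size + 1  # horizontal positions a 32-wide crop can take
print(pad, padded_width, num_offsets)  # 4 40 9
```

The same happens along the vertical axis, so a crop can shift by up to 4 pixels in each direction, and the border that comes into view is filled with mirrored image content rather than black pixels.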
# Transform applied to the labeled training images
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_labeled = transforms.Compose([
    # Horizontally flip the given image randomly with a given probability.
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(size=32,
                          padding=int(32*0.125),
                          padding_mode='reflect'),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)])
train_labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=transform_labeled)
train_iter = iter(train_labeled_dataset)
imgs, label = next(train_iter)
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size)  # (32, 32)
image.show()
print(label)  # 2
For the unlabeled data there are two kinds of augmentation: weak and strong. The strong augmentation operations are the ones described in the paper.
class TransformFixMatch(object):
    def __init__(self, mean, std):
        # weak: flip + padded random crop
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect')])
        # strong: the weak pipeline followed by RandAugmentMC
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)

# Strong augmentation operations, defined in randaugment.py
def fixmatch_augment_pool():
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs

class RandAugmentMC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        # pick n operations (with replacement) from the pool
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32*0.5))
        return img
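To make the sampling in RandAugmentMC.__call__ concrete, here is a small standalone sketch of the same control flow, with operation names in place of the real PIL operations. sample_ops and the shortened op_names list are hypothetical and exist only for this illustration. Python's rng.randint(1, m - 1) stands in for np.random.randint(1, m), which excludes its upper bound. The sketch shows that random.choices draws with replacement (the same operation can be picked twice), that each picked operation is applied with probability 0.5, and that with m=10 the magnitude v lies in 1..9.

```python
import random

# Stand-in names for a few entries of the augmentation pool (illustrative subset)
op_names = ["AutoContrast", "Brightness", "Rotate", "ShearX", "Solarize"]


def sample_ops(rng, n=2, m=10):
    """Mimic the selection logic of RandAugmentMC.__call__ without touching images."""
    applied = []
    for name in rng.choices(op_names, k=n):  # drawn with replacement
        v = rng.randint(1, m - 1)            # same range as np.random.randint(1, m)
        if rng.random() < 0.5:               # each picked op is applied half the time
            applied.append((name, v))
    applied.append(("CutoutAbs", 16))        # cutout always runs last: int(32 * 0.5)
    return applied


rng = random.Random(0)
samples = [sample_ops(rng) for _ in range(1000)]
magnitudes = [v for s in samples for name, v in s if name != "CutoutAbs"]
avg_applied = sum(len(s) - 1 for s in samples) / len(samples)
print(min(magnitudes), max(magnitudes))
print(avg_applied)  # expected value is n * 0.5 = 1.0 ops per call
```

One consequence worth noting: because of the coin flip, a "strong" sample sometimes receives no pool operation at all and differs from the weak one only by the cutout.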
# Visualize a strongly augmented sample produced by TransformFixMatch
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
train_labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
train_iter = iter(train_labeled_dataset)
# the transform returns a (weak, strong) pair for every image
(inputs_u_w, inputs_u_s), _ = next(train_iter)
image = transforms.ToPILImage()(inputs_u_s).convert('RGB')
print(image.size)  # (32, 32)
image.show()
Weak augmentation results (two runs):
Strong augmentation results (four runs):
So the dataset objects (and, from them, the dataloaders) for the labeled, unlabeled, and validation data are created as follows:
labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=transform_labeled)
# len = 68000
unlabeled_dataset = CIFAR10SSL(
    './data', train_unlabeled_idxs, train=True,
    transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
# len = 50000
test_dataset = datasets.CIFAR10(
    './data', train=False, transform=transform_val, download=False)
# len = 10000
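The lengths of the labeled and unlabeled datasets follow directly from the index-generation step. The short sketch below reproduces that arithmetic in plain Python. The variable names are chosen for this example, and reading the literal 1024 as the number of training steps per epoch is an assumption, since the index code only shows the number itself.

```python
import math

num_labeled = 4000        # --num-labeled
batch_size = 64           # --batch-size
steps = 1024              # literal from the index code; assumed to be steps per epoch
num_train_images = 50000  # size of the CIFAR-10 training split

# The 4000 labeled indices are tiled until they cover batch_size * steps samples
num_expand_x = math.ceil(batch_size * steps / num_labeled)
labeled_len = num_expand_x * num_labeled
unlabeled_len = num_train_images  # every training image is used as unlabeled data

print(num_expand_x)   # 17
print(labeled_len)    # 68000
print(unlabeled_len)  # 50000
print(labeled_len >= batch_size * steps)  # True
```

Under that assumption, the expansion guarantees that the labeled dataset alone can supply 1024 batches of 64 images without running out.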