[pytorch]FixMatch 代码详解(超详细)
FixMatch 代码详解-训练过程
上一篇大概讲了数据加载的过程,这一篇更进一步,分析一下训练是怎样进行的
上一篇链接:
参数
所有的参数我都默认使用作者给出的例子:
其运行时每个参数的值如下:
然后我们将这些参数代入,看看每一步是怎样运行的.
数据产生
首先,是产生带标签和不带标签数据的索引,其在cifar.py文件中的代码分析见上篇
结果如下,不带标签的数据使用了全部50000张训练图片;而4000个带标签样本的索引被重复了ceil(64×1024/4000)=17次(即expand_labels,只是索引重复,并不是数据增强),扩展之后为68000个
让我们看一下图片的变化
⾸先,是不带任何变化的原始数据图像:python train .py --dataset cifar10 --num -labeled 4000 --arch wideresnet --batch -size 64 --lr 0.03 --expand -labels --seed 5 --out results /cifar10@4000.51INFO - __main__ -  {'T': 1, 'amp': False , 'arch': 'wideresnet', 'batch_size': 64, 'dataset': 'cifar10', 'device': device (type='cuda', index =0), 'ema_decay'1base_dataset = datasets .CIFAR10(        './CIFAR10', train =True , download =True )labels = base_dataset .targets label_per_class = 4000 // 10labels = np .array (labels )labeled_idx = []# unlabeled data: all data (github/kekmodel/FixMatch-pytorch/issues/10)unlabeled_idx = np .array (range (len (labels )))for  i in  range (10):    idx = np .where (labels == i )[0]    idx = np .random .choice (idx , label_per_class , False )    labeled_idx .extend (idx )labeled_idx = np .array (labeled_idx )print ('number labeled_idx =',len (labeled_idx ))assert  len (labeled_idx ) == 4000if  True  or  4000 < 64:    num_expand_x = math .ceil (        64 * 1024 / 4000)  #16.384 = 17    labeled_idx = np .hstack ([labeled_idx for  _ in  range (num_expand_x )])np .random .shuffle (labeled_idx )print ('number labeled_idx = ',len (labeled_idx ))print ('number unlabeled_idx =', len (unlabeled_idx ))train_labeled_idxs = labeled_idx train_unlabeled_idxs = unlabeled_idx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25number labeled_idx = 4000number labeled_idx =  68000number unlabeled_idx = 50000
1
2
3
然后,我们使用不带数据增强的变换,也就是作者对验证集使用的图像变换. ToTensor()把像素值从0-255缩放到0-1之间,而后面的transforms.Normalize()再对每个通道做(x - mean) / std的标准化;使用CIFAR-10的均值和标准差后,数值大约落在[-1.99, 2.13]之间,而不是(-1,1). 注意图片大小没有变化,只是我截图的时候放大了图片.
然后我们看看带标签数据的图片所使用的数据增强(两次)train_labeled_dataset = CIFAR10SSL (        './data', train_labeled_idxs , train =True ,        transform =transforms .ToTensor ())train_iter = iter (train_labeled_dataset )
1
2
3
# Visualize one sample. Each run pulls the next item from train_iter,
# so re-running shows a different image.
imgs, label = next(train_iter)
# `image` must be created before its size can be printed.
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size)  # (32, 32) -- PIL reports (width, height)
image.show()
print(label)
1
transform和convert的区别2
3
4
5
# Validation-style transform: no augmentation, only ToTensor + Normalize.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transform_val)
train_iter = iter(train_labeled_dataset)
imgs, label = next(train_iter)
# `image` must be created before its size can be printed.
# NOTE: imgs is normalized, so many values fall outside [0, 1];
# ToPILImage does not undo Normalize, so the displayed colours will not
# match the raw image.
image = transforms.ToPILImage()(imgs).convert('RGB')
print(image.size)  # (32, 32)
image.show()
print(label)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
对于不带标签的数据,我们有两种数据增强:弱增强和强增强. 强增强操作参见论文中的描述.cifar10_mean = (0.4914, 0.4822, 0.4465)cifar10_std = (0.2471, 0.2435, 0.2616)transform_labeled = transforms .Compose ([    transforms .RandomHorizontalFlip (), #Horizontally flip the given image randomly with a given probability.    transforms .RandomCrop (size =32,                          padding =int (32*0.125),                          padding_mode ='reflect'),    transforms .ToTensor (),    transforms .Normalize (mean =cifar10_mean , std =cifar10_std )])train_labeled_dataset = CIFAR10SSL (        './data', train_labeled_idxs , train =True ,        transform =transform_labeled )train_iter = iter (train_labeled_dataset )imgs , label = next (train_iter )print (image .size ) # (32, 32)image = transforms .ToPILImage ()(imgs ).convert ('RGB')image .show ()print (label ) # 2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class TransformFixMatch(object):
    """Build the (weak, strong) pair of views FixMatch uses for one image.

    Both views start with a random horizontal flip and a reflect-padded
    random crop to 32x32 (padding 4). The strong view additionally goes
    through ``RandAugmentMC(n=2, m=10)``. Each view is then converted with
    ``ToTensor`` and normalized with the given per-channel ``mean``/``std``.
    """

    def __init__(self, mean, std):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect')])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        """Return ``(normalized_weak_view, normalized_strong_view)`` for ``x``."""
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)


# Strong-augmentation operations (defined in randaugment.py).
def fixmatch_augment_pool():
    """Return the ``(op, max_v, bias)`` triples that RandAugmentMC samples from.

    ``max_v`` and ``bias`` are forwarded unchanged to ``op``; their exact
    meaning is defined by each op's implementation (not shown here). Ops
    listed with ``None`` for both presumably ignore the magnitude -- verify
    against the op definitions.
    """
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs


class RandAugmentMC(object):
    """RandAugment variant used by FixMatch, followed by a fixed Cutout.

    On each call, ``n`` ops are drawn with replacement from
    ``fixmatch_augment_pool()``. Each drawn op is applied with probability
    0.5, at a magnitude ``v`` drawn uniformly from ``[1, m]``. Finally
    ``CutoutAbs(img, 16)`` is always applied. The 16 comes from a hard-coded
    32-pixel image side, so this assumes CIFAR-sized inputs.

    Raises:
        AssertionError: if ``n < 1`` or ``m`` is outside ``[1, 10]``.
    """

    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            # np.random.randint excludes its upper bound, so ``m + 1`` is
            # needed to reach magnitude ``m``. ``randint(1, m)`` never
            # sampled ``m`` and raised ValueError for the permitted m == 1.
            v = np.random.randint(1, self.m + 1)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32*0.5))
        return img
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
弱增强的图像结果(两次):
强增强的结果(运行四次):
所以,产生的带标签/不带标签/验证集的dataset类及dataloader如下:cifar10_mean =(0.4914,0.4822,0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
# Reuse the labeled indices only to preview what TransformFixMatch produces.
# Each item is ((weak_view, strong_view), label); the label is unused here.
train_labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
train_iter = iter(train_labeled_dataset)
(inputs_u_w, inputs_u_s), _ = next(train_iter)
# The views are tensors (ToTensor + Normalize), so ask for .shape.
# A bare `.size` is a method object and would print its repr instead.
print(inputs_u_s.shape)  # torch.Size([3, 32, 32])
image = transforms.ToPILImage()(inputs_u_s).convert('RGB')
image.show()
1
2
3
4
5
6
7
8
9
10
11
12
# Labeled split: the expanded (repeated) labeled indices with the
# flip + crop training transform.
labeled_dataset = CIFAR10SSL('./data', train_labeled_idxs, train=True,
                             transform=transform_labeled)  # len = 68000

# Unlabeled split: all training indices; each item yields the
# (weak, strong) view pair built by TransformFixMatch.
unlabeled_dataset = CIFAR10SSL('./data', train_unlabeled_idxs, train=True,
                               transform=TransformFixMatch(mean=cifar10_mean,
                                                           std=cifar10_std))
# len = 50000

# Test split: the CIFAR-10 test set with the validation transform.
test_dataset = datasets.CIFAR10(root='./data', train=False,
                                transform=transform_val, download=False)
# len = 10000
1
2
3
4
5
6
7
8
9
10
11

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。