Training on the CityScapes Dataset with Segmentation Transformer (SETR) (PyTorch version)
The official Segmentation Transformer source code is built on the MMSegmentation framework, which makes it inconvenient to read and learn from; if you want to use the official version, you can skip this post.
Here we use a reimplementation of Segmentation Transformer that someone published on GitHub.
Let's start adapting the source code.
1. Download the Segmentation Transformer source code
After downloading the source code, extract it into the DeeplabV3 folder.
2. Modify datasets.py
First, prepare the CityScapes label images as described in the previous DeeplabV3 post (a sketch of that preprocessing follows).
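If you haven't done that step yet, the gist is to convert each *_gtFine_labelIds.png into a single-channel image of train IDs, with the 19 CityScapes evaluation classes as 0-18 and everything else folded into class 19 (20 classes in total). A minimal sketch of that idea, assuming cityscapesScripts is installed and using hypothetical paths; it mirrors the spirit of the earlier post's preprocessing, not its exact script:

import os
import cv2
import numpy as np
from cityscapesscripts.helpers.labels import labels  # pip install cityscapesscripts

# lookup table from raw label id -> train id (ignore classes fold into 19):
id_to_trainid = np.full(256, 19, dtype=np.uint8)
for label in labels:
    if label.id >= 0 and 0 <= label.trainId < 19:
        id_to_trainid[label.id] = label.trainId

gt_dir = "./data/cityscapes/gtFine/train/jena/"  # hypothetical path; loop over all cities
out_dir = "./data/cityscapes/meta/label_imgs/"   # hypothetical path
os.makedirs(out_dir, exist_ok=True)
for file_name in os.listdir(gt_dir):
    if file_name.endswith("_gtFine_labelIds.png"):
        img_id = file_name.split("_gtFine_labelIds.png")[0]
        label_img = cv2.imread(gt_dir + file_name, -1)  # (1024, 2048) raw label ids
        cv2.imwrite(out_dir + img_id + ".png", id_to_trainid[label_img])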
SETR comes with three decoder designs; here we use the simplest one, the Naive decoder, in the form of the SETR_Naive_S model. Reading the source shows that the CityScapes images used for training are 768*768.
First, change the number of classes to 20.
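After changing the class count, it is worth sanity-checking the model with a dummy 768*768 batch. A minimal sketch, assuming the reimplementation exposes a SETR_Naive_S constructor that accepts the number of classes; the actual import path and signature depend on the version you downloaded, so treat both as placeholders:

import torch
# hypothetical import -- match it to the file layout of the repo you downloaded:
from SETR import SETR_Naive_S

num_classes = 20  # 19 CityScapes classes + 1 background/unlabeled class
model = SETR_Naive_S(num_classes=num_classes)

x = torch.randn(1, 3, 768, 768)  # dummy batch at the CityScapes training size
with torch.no_grad():
    out = model(x)
print(out.shape)  # expect (1, 20, 768, 768): per-pixel class logits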
Next comes the datasets.py part.
Start by modifying the DatasetTrain class.
For the data augmentation I kept only the random flip; the "randomly scale the img and the label" part and the "random crop from the img and label" part I commented out. You can adjust these to your own needs, but make sure the returned image is 768*768 (see the crop sketch below if you re-enable cropping).
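If you later re-enable the scale/crop augmentation, the crop must still hand the network a 768*768 image. A minimal sketch of such a crop (mirroring the commented-out block in the listing below, using the np already imported in datasets.py; it assumes the scaled image is at least 768 pixels on each side):

def random_crop_768(img, label_img):
    # img: (h, w, 3), label_img: (h, w), with h >= 768 and w >= 768
    h, w = label_img.shape
    start_x = np.random.randint(low=0, high=(w - 768 + 1))
    start_y = np.random.randint(low=0, high=(h - 768 + 1))
    img = img[start_y:start_y + 768, start_x:start_x + 768]              # (768, 768, 3)
    label_img = label_img[start_y:start_y + 768, start_x:start_x + 768]  # (768, 768)
    return img, label_img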
Make the matching changes in the DatasetVal class,
and likewise in the DatasetSeq class.
The resulting datasets.py is as follows:
# camera-ready
import torch
import torch.utils.data
import numpy as np
import cv2
import os
train_dirs = ["jena/", "zurich/", "weimar/", "ulm/", "tubingen/", "stuttgart/",
              "strasbourg/", "monchengladbach/", "krefeld/", "hanover/",
              "hamburg/", "erfurt/", "dusseldorf/", "darmstadt/", "cologne/",
              "bremen/", "bochum/", "aachen/"]
val_dirs = ["frankfurt/", "munster/", "lindau/"]
test_dirs = ["berlin", "bielefeld", "bonn", "leverkusen", "mainz", "munich"]
class DatasetTrain(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/train/"
        self.label_dir = cityscapes_meta_path + "/label_imgs/"

        self.img_h = 1024
        self.img_w = 2048
        # SETR_Naive_S takes 768x768 inputs (the original repo resized to 512x1024):
        self.new_img_h = 768
        self.new_img_w = 768
        # self.new_img_h = 512
        # self.new_img_w = 1024

        self.examples = []
        for train_dir in train_dirs:
            train_img_dir_path = self.img_dir + train_dir

            file_names = os.listdir(train_img_dir_path)
            for file_name in file_names:
                img_id = file_name.split("_leftImg8bit.png")[0]

                img_path = train_img_dir_path + file_name

                label_img_path = self.label_dir + img_id + ".png"

                example = {}
                example["img_path"] = img_path
                example["label_img_path"] = label_img_path
                example["img_id"] = img_id
                self.examples.append(example)

        self.num_examples = len(self.examples)
    def __getitem__(self, index):
        example = self.examples[index]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img without interpolation (want the image to still match
        # label_img, which we resize below):
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        label_img_path = example["label_img_path"]
        label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))
        # resize label_img without interpolation (want the resulting image to
        # still only contain pixel values corresponding to an object class):
        label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h),
                               interpolation=cv2.INTER_NEAREST) # (shape: (768, 768))

        # flip the img and the label with 0.5 probability:
        flip = np.random.randint(low=0, high=2)
        if flip == 1:
            img = cv2.flip(img, 1)
            label_img = cv2.flip(label_img, 1)
        ########################################################################
        # randomly scale the img and the label (commented out; kept for reference):
        ########################################################################
        # scale = np.random.uniform(low=0.7, high=2.0)
        # new_img_h = int(scale*self.new_img_h)
        # new_img_w = int(scale*self.new_img_w)

        # resize img without interpolation (want the image to still match
        # label_img, which we resize below):
        # img = cv2.resize(img, (new_img_w, new_img_h),
        #                  interpolation=cv2.INTER_NEAREST) # (shape: (new_img_h, new_img_w, 3))

        # resize label_img without interpolation (want the resulting image to
        # still only contain pixel values corresponding to an object class):
        # label_img = cv2.resize(label_img, (new_img_w, new_img_h),
        #                        interpolation=cv2.INTER_NEAREST) # (shape: (new_img_h, new_img_w))
        ########################################################################

        # # # # # # # # debug visualization START
        # print (scale)
        # print (new_img_h)
        # print (new_img_w)
        #
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        ########################################################################
        # select a random crop from the img and label (commented out; the 768
        # lines replace the original 256x256 crop):
        ########################################################################
        # start_x = np.random.randint(low=0, high=(new_img_w - 256))
        # end_x = start_x + 256
        # start_y = np.random.randint(low=0, high=(new_img_h - 256))
        # end_y = start_y + 256
        # start_x = np.random.randint(low=0, high=(new_img_w - 768))
        # end_x = start_x + 768
        # start_y = np.random.randint(low=0, high=(new_img_h - 768))
        # end_y = start_y + 768
        # img = img[start_y:end_y, start_x:end_x] # (shape: (768, 768, 3))
        # label_img = label_img[start_y:end_y, start_x:end_x] # (shape: (768, 768))
        ########################################################################

        # # # # # # # # debug visualization START
        # print (img.shape)
        # print (label_img.shape)
        #
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END
        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))
        label_img = torch.from_numpy(label_img) # (shape: (768, 768))

        return (img, label_img)

    def __len__(self):
        return self.num_examples
class DatasetVal(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/val/"
        self.label_dir = cityscapes_meta_path + "/label_imgs/"

        self.img_h = 1024
        self.img_w = 2048
        # match the 768x768 training resolution:
        self.new_img_h = 768
        self.new_img_w = 768

        self.examples = []
        for val_dir in val_dirs:
            val_img_dir_path = self.img_dir + val_dir

            file_names = os.listdir(val_img_dir_path)
            for file_name in file_names:
                img_id = file_name.split("_leftImg8bit.png")[0]

                img_path = val_img_dir_path + file_name

                label_img_path = self.label_dir + img_id + ".png"
                label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))

                example = {}
                example["img_path"] = img_path
                example["label_img_path"] = label_img_path
                example["img_id"] = img_id
                self.examples.append(example)

        self.num_examples = len(self.examples)
    def __getitem__(self, index):
        example = self.examples[index]

        img_id = example["img_id"]

        img_path = example["img_path"]
        img = cv2.imread(img_path, -1) # (shape: (1024, 2048, 3))
        # resize img without interpolation (want the image to still match
        # label_img, which we resize below):
        img = cv2.resize(img, (self.new_img_w, self.new_img_h),
                         interpolation=cv2.INTER_NEAREST) # (shape: (768, 768, 3))

        label_img_path = example["label_img_path"]
        label_img = cv2.imread(label_img_path, -1) # (shape: (1024, 2048))
        # resize label_img without interpolation (want the resulting image to
        # still only contain pixel values corresponding to an object class):
        label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h),
                               interpolation=cv2.INTER_NEAREST) # (shape: (768, 768))
        # # # # # # # # debug visualization START
        # cv2.imshow("test", img)
        # cv2.waitKey(0)
        #
        # cv2.imshow("test", label_img)
        # cv2.waitKey(0)
        # # # # # # # # debug visualization END

        # normalize the img (with the mean and std for the pretrained ResNet):
        img = img/255.0
        img = img - np.array([0.485, 0.456, 0.406])
        img = img/np.array([0.229, 0.224, 0.225]) # (shape: (768, 768, 3))
        img = np.transpose(img, (2, 0, 1)) # (shape: (3, 768, 768))
        img = img.astype(np.float32)

        # convert numpy -> torch:
        img = torch.from_numpy(img) # (shape: (3, 768, 768))
        label_img = torch.from_numpy(label_img) # (shape: (768, 768))

        return (img, label_img, img_id)

    def __len__(self):
        return self.num_examples
class DatasetSeq(torch.utils.data.Dataset):
    def __init__(self, cityscapes_data_path, cityscapes_meta_path, sequence):
        self.img_dir = cityscapes_data_path + "/leftImg8bit/demoVideo/stuttgart_" + sequence + "/"
        # self.img_dir = cityscapes_data_path + "/leftImg8bit/" + sequence + "/"

        self.img_h = 1024
        self.img_w = 2048
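As a quick check that the dataset now returns 768*768 tensors, you can wire DatasetTrain and DatasetVal into DataLoaders; the paths below are placeholders for your own CityScapes layout:

from torch.utils.data import DataLoader
from datasets import DatasetTrain, DatasetVal  # the file modified above

# hypothetical paths -- point them at your CityScapes images and label_imgs:
train_dataset = DatasetTrain(cityscapes_data_path="./data/cityscapes",
                             cityscapes_meta_path="./data/cityscapes/meta")
val_dataset = DatasetVal(cityscapes_data_path="./data/cityscapes",
                         cityscapes_meta_path="./data/cityscapes/meta")

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2)

imgs, label_imgs = next(iter(train_loader))
print(imgs.shape)       # torch.Size([2, 3, 768, 768])
print(label_imgs.shape) # torch.Size([2, 768, 768])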
