TensorFlow 2.0 U-Net 实现语义分割(城市景观数据集 Cityscapes)
1.数据集
Cityscapes 评测数据集在 2015 年由奔驰公司推动发布,是目前公认的机器视觉领域内最具权威性和专业性的图像分割数据集之一,其包含了 5000 张精细标注的图像和 20000 张粗略标注的图像,涵盖 50 个城市的不同场景、不同背景、不同街景,以及 30 类涵盖地面、建筑、交通标志、自然、天空、人和车辆等的物体标注。
城市街景图
对应的mask
选用数据集中以 _gtFine_labelIds.png 结尾的图片作为语义分割的 mask
2.代码
1.导入相应的库
import os
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import glob
import time
import pickle
import cv2
2.加载训练和验证图片
# Glob patterns for the Cityscapes RGB images and their fine-annotation
# label-ID masks (one mask per image, in a parallel directory tree).
train_image_path = "./cityscapes/Cityspaces/images/train/*/*.png"
train_label_path = "./cityscapes/Cityspaces/gtFine/train/*/*_gtFine_labelIds.png"
val_image_path = "./cityscapes/Cityspaces/images/val/*/*.png"
val_label_path = "./cityscapes/Cityspaces/gtFine/val/*/*_gtFine_labelIds.png"

# glob.glob returns paths in an arbitrary, filesystem-dependent order, yet
# images and masks are paired purely by list position further down.  Sorting
# both lists makes the pairing deterministic: an image and its mask share the
# same "<city>/<city>_<seq>_<frame>_" prefix and differ only in a suffix that
# is constant within each list, so they end up at the same index.
train_images = sorted(glob.glob(train_image_path))
train_labels = sorted(glob.glob(train_label_path))
val_images = sorted(glob.glob(val_image_path))
val_labels = sorted(glob.glob(val_label_path))

# A count mismatch means the positional pairing above is wrong; stop early
# instead of silently training on misaligned image/mask pairs.
if len(train_images) != len(train_labels) or len(val_images) != len(val_labels):
    raise ValueError(
        "image/mask count mismatch: train %d vs %d, val %d vs %d"
        % (len(train_images), len(train_labels), len(val_images), len(val_labels))
    )
3.初始设置
# Training hyper-parameters.
BATCH_SIZE = 32
BUFFER_SIZE = 300  # shuffle-buffer size, in samples
EPOCHS = 60

train_count = len(train_images)
val_count = len(val_images)
# Floor division: a trailing partial batch is not counted as a step.
train_step_per_epoch = train_count // BATCH_SIZE
val_step_per_epoch = val_count // BATCH_SIZE

# Let tf.data tune map parallelism and prefetch depth at runtime.
# (The original "perimental.AUTOTUNE" was a truncated name and a NameError.)
auto = tf.data.experimental.AUTOTUNE
4.创建训练集和验证集
# Read a PNG image from the dataset.
def read_png(path, channels=3):
    """Read a PNG file and decode it into a uint8 image tensor.

    Args:
        path: file path (Python str or scalar string tensor, e.g. an element
            of a ``tf.data.Dataset`` of paths).
        channels: number of channels to decode to; 3 for the RGB street
            images, 1 for the label-ID masks.

    Returns:
        A uint8 tensor of shape (height, width, channels).
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_png(img, channels=channels)
    return img
# Data augmentation: joint resize + random crop of image and mask.
def crop_img(img, label, resize_size=(280, 280), crop_size=(256, 256)):
    """Resize an image/mask pair, then take the same random crop from both.

    Args:
        img: 3-channel image tensor of shape (H, W, 3).
        label: 1-channel mask tensor of shape (H, W, 1) with the same H, W
            and dtype as ``img`` (both come from ``read_png`` in this file).
        resize_size: (height, width) both tensors are resized to first.
        crop_size: (height, width) of the random crop; must not exceed
            ``resize_size``.

    Returns:
        (img, label) cropped to ``crop_size``: image with 3 channels, mask
        with 1 channel.
    """
    # Stack along channels so a single random_crop cuts image and mask at
    # exactly the same location (channels 0-2 = image, channel 3 = mask).
    concat_img = tf.concat([img, label], axis=-1)
    # Nearest-neighbour never invents new mask values (class IDs stay valid)
    # and keeps the input dtype instead of converting to float32.
    concat_img = tf.image.resize(concat_img, resize_size,
                                 method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    crop = tf.image.random_crop(concat_img, [crop_size[0], crop_size[1], 4])
    return crop[:, :, 0:3], crop[:, :, 3:]
# Image normalization.
def normal(img, label):
    """Normalize one image/mask pair for training.

    The image is cast to float32 and mapped linearly from [0, 255] onto
    [-1, 1] (x / 127.5 - 1); the mask is cast to int32 so its values can be
    used directly as class indices.
    """
    as_float = tf.cast(img, tf.float32)
    return as_float / 127.5 - 1, tf.cast(label, tf.int32)
# Load one training sample.
def load_image_train(img_path, label_path):
    """Load an image/mask pair with random crop and flip augmentation.

    Meant to be used with ``tf.data.Dataset.map`` over (image path, mask
    path) pairs.

    Returns:
        (img, label): float32 image scaled to [-1, 1] and int32 mask, both
        cropped by ``crop_img`` (256x256 with its default arguments).
    """
    img = read_png(img_path)
    label = read_png(label_path, channels=1)
    img, label = crop_img(img, label)
    # Flip image and mask together with probability 0.5 so they stay
    # aligned. Inside Dataset.map, AutoGraph converts this tensor-valued
    # `if` into graph control flow.
    if tf.random.uniform(()) > 0.5:
        img = tf.image.flip_left_right(img)
        label = tf.image.flip_left_right(label)
    img, label = normal(img, label)
    return img, label
# Load one validation sample.
def load_image_val(img_path, label_path):
    """Load an image/mask pair resized to 256x256, without augmentation.

    Returns:
        (img, label): float32 image scaled to [-1, 1] of shape
        (256, 256, 3) and int32 mask of shape (256, 256, 1).
    """
    img = read_png(img_path)
    label = read_png(label_path, channels=1)
    img = tf.image.resize(img, (256, 256))
    # The mask holds class IDs. The default bilinear resize would blend
    # neighbouring IDs at region boundaries into other, meaningless IDs, so
    # nearest-neighbour is required (same method crop_img uses for training).
    label = tf.image.resize(label, (256, 256),
                            method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    img, label = normal(img, label)
    return img, label
# Shuffle the training file list once. The same permutation is applied to
# both arrays so every image stays paired with its mask.
index = np.random.permutation(len(train_images))
train_images = np.array(train_images)[index]
train_labels = np.array(train_labels)[index]

# Build the training and validation datasets from (image path, mask path).
dataset_train = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
dataset_val = tf.data.Dataset.from_tensor_slices((val_images, val_labels))

dataset_train = dataset_train.map(load_image_train, num_parallel_calls=auto)
dataset_val = dataset_val.map(load_image_val, num_parallel_calls=auto)

# No .cache() on the training pipeline: load_image_train applies a random
# crop and flip, and caching its output would replay the first epoch's
# augmentation on every later epoch.
dataset_train = (dataset_train
                 .repeat()
                 .shuffle(BUFFER_SIZE)
                 .batch(BATCH_SIZE)
                 .prefetch(auto))
# Validation preprocessing is deterministic, so caching it is safe.
dataset_val = dataset_val.cache().batch(BATCH_SIZE).prefetch(auto)
5.抽取数据集中一张图片进行可视化
# Show the first sample of one training batch: image on the left, mask on
# the right.
for image, label in dataset_train.take(1):
    plt.figure(figsize=(10, 10))
    panels = (('image', image), ('label', label))
    for position, (caption, batch) in enumerate(panels, start=121):
        plt.subplot(position)
        plt.title(caption)
        plt.imshow(tf.keras.preprocessing.image.array_to_img(batch[0]))
执行结果
左边是原始图片,右边是对应的 mask 标签
6.创建U-Net模型
def _conv_block(x, filters):
    """Apply two rounds of 3x3 Conv2D (ReLU, 'same' padding) + BatchNorm.

    Height and width are preserved; the channel count becomes ``filters``.
    """
    for _ in range(2):
        x = tf.keras.layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
        x = tf.keras.layers.BatchNormalization()(x)
    return x


def _up_block(x, skip, filters):
    """Decoder stage: upsample ``x`` 2x, concatenate ``skip``, then refine.

    A 2x2 transposed conv with stride 2 doubles height and width and outputs
    ``filters`` channels (followed by BatchNorm). The encoder feature map
    ``skip`` of the same resolution is concatenated in front along the
    channel axis, and ``_conv_block`` maps the result back to ``filters``
    channels.
    """
    x = tf.keras.layers.Conv2DTranspose(filters, 2, strides=2, padding='same',
                                        activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.concatenate([skip, x], axis=-1)
    return _conv_block(x, filters)


def create_model(num_classes=34, input_shape=(256, 256, 3)):
    """Build the U-Net segmentation model.

    Args:
        num_classes: number of output channels, one per class. The default 34
            is the number of label classes in this dataset's label-ID masks.
        input_shape: (height, width, channels) of the input images. Height
            and width should be divisible by 16, so that the four 2x poolings
            and the four 2x upsamplings produce matching sizes for the skip
            concatenations.

    Returns:
        A ``tf.keras.Model`` mapping (None, H, W, C) images to
        (None, H, W, num_classes) per-pixel softmax probabilities. The argmax
        over the last axis gives each pixel's predicted class.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    def pool(t):
        return tf.keras.layers.MaxPooling2D(padding='same')(t)

    # Encoder: each level halves H and W and doubles the channels.
    # Shape comments assume the default 256x256x3 input.
    x = _conv_block(inputs, 64)          # (None, 256, 256, 64)
    x1 = _conv_block(pool(x), 128)       # (None, 128, 128, 128)
    x2 = _conv_block(pool(x1), 256)      # (None, 64, 64, 256)
    x3 = _conv_block(pool(x2), 512)      # (None, 32, 32, 512)
    x4 = _conv_block(pool(x3), 1024)     # (None, 16, 16, 1024) bottleneck

    # Decoder: upsample and fuse with the encoder map of the same resolution.
    x6 = _up_block(x4, x3, 512)          # (None, 32, 32, 512)
    x8 = _up_block(x6, x2, 256)          # (None, 64, 64, 256)
    x10 = _up_block(x8, x1, 128)         # (None, 128, 128, 128)
    x12 = _up_block(x10, x, 64)          # (None, 256, 256, 64)

    # 1x1 conv to per-class scores; softmax over the class axis.
    output = tf.keras.layers.Conv2D(num_classes, 1, padding='same',
                                    activation='softmax')(x12)
    return tf.keras.Model(inputs=inputs, outputs=output)
# Build the U-Net once (the original called create_model() twice, building
# and discarding a full model) and print its layer-by-layer summary.
model = create_model()
model.summary()
执行结果
Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 1792        input_1[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 256, 256, 64) 256         conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 256, 256, 64) 36928       batch_normalization[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 256, 256, 64) 256         conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 128, 128, 64) 0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 128, 128, 128 73856       max_pooling2d[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128, 128, 128 512         conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 128 147584      batch_normalization_2[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 128 512         conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 64, 64, 128)  0           batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 256)  295168      max_pooling2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 256)  1024        conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 256)  590080      batch_normalization_4[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 256)  1024        conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 32, 32, 256)  0           batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 32, 32, 512)  1180160     max_pooling2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 32, 32, 512)  2048        conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 32, 32, 512)  2359808     batch_normalization_6[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 32, 32, 512)  2048        conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 16, 16, 512)  0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 16, 16, 1024) 4719616     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 16, 16, 1024) 4096        conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 1024) 9438208     batch_normalization_8[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 16, 16, 1024) 4096        conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 32, 32, 512)  2097664     batch_normalization_9[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 32, 32, 512)  2048        conv2d_transpose[0][0]
__________________________________________________________________________________________________
tf_op_layer_concat (TensorFlowO [(None, 32, 32, 1024 0           batch_normalization_7[0][0]
                                                                 batch_normalization_10[0][0]
__________________________________________________________________________________________________

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。