『计算机视觉』Mask-RCNN_训练⽹络其⼆:train⽹络结构损
失函数
resizedGithub地址:
⼀、training⽹络简介
流程和inference⼤部分⼀致,在下图中我们将之前inference就介绍过的分类、回归和掩码⽣成流程压缩到⼀个块中,以便其他部分更为清晰。⽽两者主要不同之处为:
1. ⽹络输⼊:输⼊tensor增加到了7个之多(图上画出的6个以及image_meta),⼤部分是计算Loss的标签前置
2. 损失函数:添加了5个损失函数,2个⽤于RPN计算,2个⽤于最终分类回归instance,1个⽤于掩码损失计算
3. 原始标签处理:推理⽹络中,Proposeal筛选出来的rpn_rois直接⽤于⽣成分类回归以及掩码信息,⽽training中这些候选区需
要和图⽚标签信息进⾏运算,⽣成有训练价值的输出,进⾏后⾯的⽣成以及损失函数计算
⾸先初始化并载⼊预训练参数(下节会介绍本部分相关操作),
然后经由下⾯⼏⾏代码,即可进⾏训练,
⽹络输⼊
build函数在train⽅法中被调⽤(),涉及巨多预处理函数设计,需要的时候⾃⾏进⼊train⽅法查看(更确切的说是在data_generator⽅法,由train调⽤),
- images: [batch, H, W, C]
- image_meta: [batch, (meta data)] Image details. See compose_image_meta()
- rpn_match: [batch, N] Integer (1=positive anchor, -1=negative, 0=neutral)
- rpn_bbox: [batch, N, (dy, dx, log(dh), log(dw))] Anchor bbox deltas.
- gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs
- gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)]
- gt_masks: [batch, height, width, MAX_GT_INSTANCES]. The height and width
are those of the image unless use_mini_mask is True, in which
case they are defined in MINI_MASK_SHAPE.
原始标签处理
然后我们在开篇流程图中标注了⼀个名为"检测⽬标处理"的框,对应代码如下:
# Generate detection targets
# Subsamples proposals and generates target outputs for training
# Note that proposal class IDs, gt_boxes, and gt_masks are zero
# padded. Equally, returned rois and targets are zero padded.
rois, target_class_ids, target_bbox, target_mask =\
DetectionTargetLayer(config, name="proposal_targets")([
target_rois, input_gt_class_ids, gt_boxes, input_gt_masks])
其⽬的是将原始的图像信息input和proposal们进⾏计算融合,输出可以⽤于Loss计算的标准的格式,⽂档很清晰,
"""Subsamples proposals and generates target box refinement, class_ids,
and masks for each.
Inputs:
proposals: [batch, N, (y1, x1, y2, x2)] in normalized coordinates. Might
be zero padded if there are not enough proposals.
gt_class_ids:[batch, MAX_GT_INSTANCES] Integer class IDs.
gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized
coordinates.
gt_masks: [batch, height, width, MAX_GT_INSTANCES] of boolean type
Returns: Target ROIs and corresponding class IDs, bounding box shifts,
and masks.
rois: [batch, TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized
coordinates
target_class_ids: [batch, TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
target_deltas: [batch, TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw)]
target_mask: [batch, TRAIN_ROIS_PER_IMAGE, height, width]
Masks cropped to bbox boundaries and resized to neural
network output size.
Note: Returned arrays might be zero padded if not enough target ROIs.
"""
这个处理之后,结构同inference中的介绍,
mrcnn_class_logits, mrcnn_class, mrcnn_bbox =\
fpn_classifier_graph(rois, mrcnn_feature_maps, input_image_meta,
config.POOL_SIZE, config.NUM_CLASSES,
train_bn=config.TRAIN_BN,
fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)
mrcnn_mask = build_fpn_mask_graph(rois, mrcnn_feature_maps,
input_image_meta,
config.MASK_POOL_SIZE,
config.NUM_CLASSES,
train_bn=config.TRAIN_BN)
损失函数
然后就是损失函数了(浩浩荡荡10来⾏……),注意output_rois这⼀⾏,我们之前就提过,keras中接收tf的Tensor只能作为class的初始化参数,⽽不能作为⽹络数据流,所以这⾥加了⼀层封装,
output_rois = KL.Lambda(lambda x: x * 1, name="output_rois")(rois)
# Losses
rpn_class_loss = KL.Lambda(lambda x: rpn_class_loss_graph(*x), name="rpn_class_loss")(
[input_rpn_match, rpn_class_logits])
rpn_bbox_loss = KL.Lambda(lambda x: rpn_bbox_loss_graph(config, *x), name="rpn_bbox_loss")(
[input_rpn_bbox, input_rpn_match, rpn_bbox])
class_loss = KL.Lambda(lambda x: mrcnn_class_loss_graph(*x), name="mrcnn_class_loss")(
[target_class_ids, mrcnn_class_logits, active_class_ids])
bbox_loss = KL.Lambda(lambda x: mrcnn_bbox_loss_graph(*x), name="mrcnn_bbox_loss")(
[target_bbox, target_class_ids, mrcnn_bbox])
mask_loss = KL.Lambda(lambda x: mrcnn_mask_loss_graph(*x), name="mrcnn_mask_loss")(
[target_mask, target_class_ids, mrcnn_mask])
⼆、损失函数简介
RPN分类损失
我们先看⼀下RPN真实标签⽣成函数中的⼀段注释,
# Match anchors to GT Boxes
# If an anchor overlaps a GT box with IoU >= 0.7 then it's positive.
# If an anchor overlaps a GT box with IoU < 0.3 then it's negative.
# Neutral anchors are those that don't match the conditions above,
# and they don't influence the loss function.
# However, don't keep any GT box unmatched (rare, but happens). Instead,
# match it to the closest anchor (even if its max IoU is < 0.3).
然后看本损失函数,
def rpn_class_loss_graph(rpn_match, rpn_class_logits):
"""RPN anchor classifier loss.
rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
-1=negative, 0=neutral anchor.
rpn_class_logits: [batch, anchors, 2]. RPN classifier logits for FG/BG.
"""
# Squeeze last dim to simplify
rpn_match = tf.squeeze(rpn_match, -1)
# Get anchor classes. Convert the -1/+1 match to 0/1 values.
anchor_class = K.cast(K.equal(rpn_match, 1), tf.int32)
# Positive and Negative anchors contribute to the loss,
# but neutral anchors (match value = 0) don't.
indices = tf._equal(rpn_match, 0))
# Pick rows that contribute to the loss and filter out the rest.
rpn_class_logits = tf.gather_nd(rpn_class_logits, indices)
anchor_class = tf.gather_nd(anchor_class, indices)
# Cross entropy loss
loss = K.sparse_categorical_crossentropy(target=anchor_class,
output=rpn_class_logits,
from_logits=True)
loss = K.switch(tf.size(loss) > 0, K.mean(loss), tf.constant(0.0))
return loss
真实标签有{1, 0, -1}三种,logits结果在0~1分布,⽽在RPN分类结果中,真实标签为0的anchors不参与损失函数的构建,所以我们将标签为0的真实标签剔除,然后将-1标签转换为0进⾏交叉熵计算。
RPN回归损失
def rpn_bbox_loss_graph(config, target_bbox, rpn_match, rpn_bbox):
"""Return the RPN bounding box loss graph.
config: the model config object.
target_bbox: [batch, max positive anchors, (dy, dx, log(dh), log(dw))].
Uses 0 padding to fill in unsed bbox deltas.
rpn_match: [batch, anchors, 1]. Anchor match type. 1=positive,
-1=negative, 0=neutral anchor.
rpn_bbox: [batch, anchors, (dy, dx, log(dh), log(dw))]
"""
# input_rpn_bbox, input_rpn_match, rpn_bbox
# Positive anchors contribute to the loss, but negative and
# neutral anchors (match value of 0 or -1) don't.
rpn_match = K.squeeze(rpn_match, -1) # [batch, anchors]
indices = tf.where(K.equal(rpn_match, 1))
# Pick bbox deltas that contribute to the loss
rpn_bbox = tf.gather_nd(rpn_bbox, indices) # [n, 4]
# Trim target bounding box deltas to the same length as rpn_bbox.
batch_counts = K.sum(K.cast(K.equal(rpn_match, 1), tf.int32), axis=1) # 1标签数⽬
# target_bbox: [batch, max positive anchors, (dy, dx, log(dh), log(dw))]
# rpn_match: [batch]
target_bbox = batch_pack_graph(target_bbox, batch_counts,
config.IMAGES_PER_GPU)
loss = smooth_l1_loss(target_bbox, rpn_bbox)
loss = K.switch(tf.size(loss) > 0, K.mean(loss), tf.constant(0.0))
return loss
仅仅真实标签为1的类参与回归运算,
1. 对于target_bbox,虽然对每张图⽚其框数⼀致且和rpn_match的第⼆维度相等,但是对于图⽚i只有前⾯的N i个框有意义(⽽不是和
anchors⼀⼀对应的),后⾯为0填充,N i值等于对应图⽚的rpn_match等于1的数⽬
2. (推测)target_bbox中bbox坐标的排列顺序等于rpn_match中的标识顺序,所以使⽤rpn_match索引出rpn_bbox对应1的位置后
直接和target_bbox的前N i运算即可
def batch_pack_graph(x, counts, num_rows):
"""Picks different number of values from each row
in x depending on the values in counts.
x: [batch, max positive anchors, (dy, dx, log(dh), log(dw))]
counts: [batch]
"""
outputs = []
for i in range(num_rows):
outputs.append(x[i, :counts[i]])
at(outputs, axis=0)
损失函数使⽤smooth_l1_loss(坐标回归都⽤这个?)
def smooth_l1_loss(y_true, y_pred):
"""Implements Smooth-L1 loss.
y_true and y_pred are typically: [N, 4], but could be any shape.
"""
diff = K.abs(y_true - y_pred)
less_than_one = K.cast(K.less(diff, 1.0), "float32")
loss = (less_than_one * 0.5 * diff**2) + (1 - less_than_one) * (diff - 0.5)
return loss
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系QQ:729038198,我们将在24小时内删除。
发表评论