from keras_resnet.models import ResNet18, ResNet34, ResNet50 from keras.layers import Input, Conv2DTranspose, BatchNormalization, ReLU, Conv2D, Lambda, MaxPooling2D, Dropout from keras.layers import UpSampling2D, Concatenate from keras.models import Model from keras.initializers import normal, constant, zeros from keras.regularizers import l2 import keras.backend as K import tensorflow as tf from losses import loss def nms(heat, kernel=3): hmax = tf.nn.max_pool2d(heat, (kernel, kernel), strides=1, padding='SAME') heat = tf.where(tf.equal(hmax, heat), heat, tf.zeros_like(heat)) return heat def topk(hm, max_objects=100): hm = nms(hm) # (b, h * w * c) b, h, w, c = tf.shape(hm)[0], tf.shape(hm)[1], tf.shape(hm)[2], tf.shape(hm)[3] # hm2 = tf.transpose(hm, (0, 3, 1, 2)) # hm2 = tf.reshape(hm2, (b, c, -1)) hm = tf.reshape(hm, (b, -1)) # (b, k), (b, k) scores, indices = tf.nn.top_k(hm, k=max_objects) # scores2, indices2 = tf.nn.top_k(hm2, k=max_objects) # scores2 = tf.reshape(scores2, (b, -1)) # topk = tf.nn.top_k(scores2, k=max_objects) class_ids = indices % c xs = indices // c % w ys = indices // c // w indices = ys * w + xs return scores, indices, class_ids, xs, ys def evaluate_batch_item(batch_item_detections, num_classes, max_objects_per_class=20, max_objects=100, iou_threshold=0.5, score_threshold=0.1): batch_item_detections = tf.boolean_mask(batch_item_detections, tf.greater(batch_item_detections[:, 4], score_threshold)) detections_per_class = [] for cls_id in range(num_classes): # (num_keep_this_class_boxes, 4) score 大于 score_threshold 的当前 class 的 boxes class_detections = tf.boolean_mask(batch_item_detections, tf.equal(batch_item_detections[:, 5], cls_id)) nms_keep_indices = tf.image.non_max_suppression(class_detections[:, :4], class_detections[:, 4], max_objects_per_class, iou_threshold=iou_threshold) class_detections = K.gather(class_detections, nms_keep_indices) detections_per_class.append(class_detections) batch_item_detections = K.concatenate(detections_per_class, axis=0) def filter(): # nonlocal batch_item_detections _, indices = tf.nn.top_k(batch_item_detections[:, 4], k=max_objects) batch_item_detections_ = tf.gather(batch_item_detections, indices) return batch_item_detections_ def pad(): # nonlocal batch_item_detections batch_item_num_detections = tf.shape(batch_item_detections)[0] batch_item_num_pad = tf.maximum(max_objects - batch_item_num_detections, 0) batch_item_detections_ = tf.pad(tensor=batch_item_detections, paddings=[ [0, batch_item_num_pad], [0, 0]], mode='CONSTANT', constant_values=0.0) return batch_item_detections_ batch_item_detections = tf.cond(tf.shape(batch_item_detections)[0] >= 100, filter, pad) return batch_item_detections def decode(hm, wh, reg, max_objects=100, nms=True, num_classes=20, score_threshold=0.1): scores, indices, class_ids, xs, ys = topk(hm, max_objects=max_objects) b = tf.shape(hm)[0] # (b, h * w, 2) reg = tf.reshape(reg, (b, -1, tf.shape(reg)[-1])) # (b, h * w, 2) wh = tf.reshape(wh, (b, -1, tf.shape(wh)[-1])) # (b, k, 2) topk_reg = tf.gather(reg, indices, batch_dims=1) # (b, k, 2) topk_wh = tf.cast(tf.gather(wh, indices, batch_dims=1), tf.float32) topk_cx = tf.cast(tf.expand_dims(xs, axis=-1), tf.float32) + topk_reg[..., 0:1] topk_cy = tf.cast(tf.expand_dims(ys, axis=-1), tf.float32) + topk_reg[..., 1:2] scores = tf.expand_dims(scores, axis=-1) class_ids = tf.cast(tf.expand_dims(class_ids, axis=-1), tf.float32) topk_x1 = topk_cx - topk_wh[..., 0:1] / 2 topk_x2 = topk_cx + topk_wh[..., 0:1] / 2 topk_y1 = topk_cy - topk_wh[..., 1:2] / 2 topk_y2 = topk_cy + topk_wh[..., 1:2] / 2 # (b, k, 6) detections = tf.concat([topk_x1, topk_y1, topk_x2, topk_y2, scores, class_ids], axis=-1) if nms: detections = tf.map_fn(lambda x: evaluate_batch_item(x[0], num_classes=num_classes, score_threshold=score_threshold), elems=[detections], dtype=tf.float32) return detections def centernet(num_classes, backbone='resnet50', input_size=512, max_objects=100, score_threshold=0.1, nms=True): assert backbone in ['resnet18', 'resnet34', 'resnet50'] output_size = input_size // 4 image_input = Input(shape=(None, None, 3)) hm_input = Input(shape=(output_size, output_size, num_classes)) wh_input = Input(shape=(max_objects, 2)) reg_input = Input(shape=(max_objects, 2)) reg_mask_input = Input(shape=(max_objects,)) index_input = Input(shape=(max_objects,)) if backbone == 'resnet18': resnet = ResNet18(image_input, include_top=False, freeze_bn=True) elif backbone == 'resnet34': resnet = ResNet34(image_input, include_top=False, freeze_bn=True) else: resnet = ResNet50(image_input, include_top=False, freeze_bn=True) # C5 (b, 16, 16, 512) C2, C3, C4, C5 = resnet.outputs C5 = Dropout(rate=0.5)(C5) C4 = Dropout(rate=0.4)(C4) C3 = Dropout(rate=0.3)(C3) C2 = Dropout(rate=0.2)(C2) x = C5 # decoder x = Conv2D(256, 1, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(UpSampling2D()(x)) x = BatchNormalization()(x) x = ReLU()(x) x = Concatenate()([C4, x]) x = Conv2D(256, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x) x = BatchNormalization()(x) # (b, 32, 32, 512) x = ReLU()(x) x = Conv2D(128, 1, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(UpSampling2D()(x)) x = BatchNormalization()(x) x = ReLU()(x) x = Concatenate()([C3, x]) x = Conv2D(128, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x) x = BatchNormalization()(x) # (b, 64, 64, 128) x = ReLU()(x) x = Conv2D(64, 1, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(UpSampling2D()(x)) x = BatchNormalization()(x) x = ReLU()(x) x = Concatenate()([C2, x]) x = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x) x = BatchNormalization()(x) # (b, 128, 128, 512) x = ReLU()(x) # hm header y1 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x) y1 = BatchNormalization()(y1) y1 = ReLU()(y1) y1 = Conv2D(num_classes, 1, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4), activation='sigmoid')(y1) # wh header y2 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x) y2 = BatchNormalization()(y2) y2 = ReLU()(y2) y2 = Conv2D(2, 1, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(y2) # reg header y3 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x) y3 = BatchNormalization()(y3) y3 = ReLU()(y3) y3 = Conv2D(2, 1, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(y3) loss_ = Lambda(loss, name='centernet_loss')( [y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input]) model = Model(inputs=[image_input, hm_input, wh_input, reg_input, reg_mask_input, index_input], outputs=[loss_]) # detections = decode(y1, y2, y3) detections = Lambda(lambda x: decode(*x, max_objects=max_objects, score_threshold=score_threshold, nms=nms, num_classes=num_classes))([y1, y2, y3]) prediction_model = Model(inputs=image_input, outputs=detections) debug_model = Model(inputs=image_input, outputs=[y1, y2, y3]) return model, prediction_model, debug_model