CenterNet Learning Notes

The CenterNet Algorithm

Conventional object detection algorithms lay out a large set of prior (anchor) boxes and then adjust them to obtain the predicted boxes.

CenterNet takes a different approach: when building the model it treats each object as a single point, namely the center point of the object's bounding box, so one object is determined by one feature point.

The CenterNet detector uses keypoint estimation to find the center point and regresses the remaining object properties from it.

I. Prediction

1. Introduction to the backbone network

CenterNet can be built on several backbone feature extractors, typically Hourglass Network, DLANet, or ResNet. The Hourglass Network used by CenterNet has a very large parameter count (about 190 million parameters), and DLANet has no readily available Keras implementation, so this article uses ResNet as the example.

ResNet50 is built from two basic blocks, the Conv Block and the Identity Block. The Conv Block's input and output dimensions differ, so consecutive Conv Blocks cannot be chained directly; its role is to change the dimensions of the network. The Identity Block's input and output dimensions are the same, so it can be stacked repeatedly to deepen the network.

The Conv Block is implemented as follows:


from keras import layers
from keras.layers import (Activation, BatchNormalization, Conv2D, Conv2DTranspose,
                          Dropout, MaxPooling2D, ZeroPadding2D)
from keras.regularizers import l2


def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):

    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), strides=strides,
               name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same',
               name=conv_name_base + '2b', use_bias=False)(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c', use_bias=False)(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)

    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=conv_name_base + '1', use_bias=False)(input_tensor)
    shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut)

    x = layers.add([x, shortcut])
    x = Activation('relu')(x)
    return x

The Identity Block is implemented as follows:


def identity_block(input_tensor, kernel_size, filters, stage, block):

    filters1, filters2, filters3 = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = Conv2D(filters1, (1, 1), name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BatchNormalization(name=bn_name_base + '2a')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters2, kernel_size, padding='same', name=conv_name_base + '2b', use_bias=False)(x)
    x = BatchNormalization(name=bn_name_base + '2b')(x)
    x = Activation('relu')(x)

    x = Conv2D(filters3, (1, 1), name=conv_name_base + '2c', use_bias=False)(x)
    x = BatchNormalization(name=bn_name_base + '2c')(x)

    x = layers.add([x, input_tensor])
    x = Activation('relu')(x)
    return x

Both blocks are residual structures. When the input image is 512x512x3, the overall feature map shapes change as follows:

512,512,3 -> 256,256,64 -> 128,128,64 -> 128,128,256 -> 64,64,512 -> 32,32,1024 -> 16,16,2048

ResNet network implementation:


def ResNet50(inputs):
    # 512x512x3
    x = ZeroPadding2D((3, 3))(inputs)
    # 256,256,64
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=False)(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)

    # 256,256,64 -> 128,128,64
    x = MaxPooling2D((3, 3), strides=(2, 2), padding="same")(x)

    # 128,128,64 -> 128,128,256
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    # 128,128,256 -> 64,64,512
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    # 64,64,512 -> 32,32,1024
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    # 32,32,1024 -> 16,16,2048
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    return x
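
A quick way to sanity-check this shape progression (a usage sketch, assuming the Keras imports above):

from keras.layers import Input
from keras.models import Model

inputs = Input(shape=(512, 512, 3))
model = Model(inputs, ResNet50(inputs))
model.summary()  # last feature map: (None, 16, 16, 2048)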


2. Obtaining a high-resolution feature map from the initial features

The last feature layer of ResNet50, obtained in the previous step, has shape (16,16,2048).

CenterNet upsamples this feature layer with three successive transposed convolutions (deconvolutions) to produce a higher-resolution output. The three deconvolutions have 256, 128, and 64 output channels respectively.

Each deconvolution doubles the height and width of the feature layer, so the final feature layer is 128x128x64.

The final predictions are obtained from this feature layer:

x = Dropout(rate=0.5)(x)
#-------------------------------#
#   Decoder
#-------------------------------#
num_filters = 256
# 16, 16, 2048  ->  32, 32, 256 -> 64, 64, 128 -> 128, 128, 64
for i in range(3):
    # upsampling: each transposed convolution doubles the height and width
    x = Conv2DTranspose(num_filters // pow(2, i), (4, 4), strides=2, use_bias=False, padding='same',
                        kernel_initializer='he_normal',
                        kernel_regularizer=l2(5e-4))(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

3. The CenterNet head: obtaining predictions from the features

This feature layer effectively divides the whole image into 128x128 regions, each containing one feature point. If an object's center falls inside a region, that object is assigned to the region's feature point (concretely, the feature point at the top-left corner of the region).

From this feature layer, three convolutional heads are applied:

  1. Heatmap prediction: the convolution has num_classes output channels, giving a result of (128,128,num_classes). It indicates whether an object is present at each heatmap point, and which class it belongs to.
  2. Center offset prediction: the convolution has 2 output channels, giving a result of (128,128,2), the offset of each object's center relative to its heatmap point.
  3. Width-height prediction: the convolution has 2 output channels, giving a result of (128,128,2), the predicted width and height of each object.

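    # hm header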
    y1 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x)
    y1 = BatchNormalization()(y1)
    y1 = Activation('relu')(y1)
    y1 = Conv2D(num_classes, 1, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4), activation='sigmoid')(y1)

    # wh header
    y2 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x)
    y2 = BatchNormalization()(y2)
    y2 = Activation('relu')(y2)
    y2 = Conv2D(2, 1, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(y2)

    # reg header
    y3 = Conv2D(64, 3, padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(x)
    y3 = BatchNormalization()(y3)
    y3 = Activation('relu')(y3)
    y3 = Conv2D(2, 1, kernel_initializer='he_normal', kernel_regularizer=l2(5e-4))(y3)

4. Decoding the predictions

The feature layer divides the image into 128x128 feature points; each feature point is responsible for predicting objects whose centers fall in the region to its lower-right.

Decoding involves three operations:

  1. Apply the center offset: use the reg prediction to shift each feature point's coordinates.
  2. Use the shifted center together with the wh width-height prediction to obtain the top-left and bottom-right corners of the predicted box.
  3. The resulting predicted boxes can then be drawn on the image.

A non-maximum suppression step is also applied to keep boxes of the same class from piling up.

It is implemented with max pooling: a 3x3 kernel scans the heatmap, and only the highest-scoring point within each local window is kept.


import tensorflow as tf
from keras.layers import MaxPooling2D


def nms(heat, kernel=3):
    hmax = MaxPooling2D((kernel, kernel), strides=1, padding='SAME')(heat)
    heat = tf.where(tf.equal(hmax, heat), heat, tf.zeros_like(heat))
    return heat

def topk(hm, max_objects=100):
    # hm: the heatmap
    # heatmap non-maximum suppression: a 3x3 max pool keeps only the local maxima
    hm = nms(hm)
    # (b, h * w * c)
    b, h, w, c = tf.shape(hm)[0], tf.shape(hm)[1], tf.shape(hm)[2], tf.shape(hm)[3]
    # flatten everything to (b, h * w * c)
    hm = tf.reshape(hm, (b, -1))
    # (b, k), (b, k)
    scores, indices = tf.math.top_k(hm, k=max_objects, sorted=True)

    # recover the grid coordinates and class ids
    class_ids = indices % c
    xs = indices // c % w
    ys = indices // c // w
    indices = ys * w + xs
    return scores, indices, class_ids, xs, ys


def decode(hm, wh, reg, max_objects=100, num_classes=20):
    scores, indices, class_ids, xs, ys = topk(hm, max_objects=max_objects)
    # batch size
    b = tf.shape(hm)[0]
    
    # (b, h * w, 2)
    reg = tf.reshape(reg, [b, -1, 2])
    # (b, h * w, 2)
    wh = tf.reshape(wh, [b, -1, 2])
    length = tf.shape(wh)[1]

    # compute flattened indices into the (b * h * w) dimension
    batch_idx = tf.expand_dims(tf.range(0, b), 1)
    batch_idx = tf.tile(batch_idx, (1, max_objects))
    full_indices = tf.reshape(batch_idx, [-1]) * tf.to_int32(length) + tf.reshape(indices, [-1])
                    
    # gather the parameters of the top-k boxes
    topk_reg = tf.gather(tf.reshape(reg, [-1,2]), full_indices)
    topk_reg = tf.reshape(topk_reg, [b, -1, 2])
    
    topk_wh = tf.gather(tf.reshape(wh, [-1,2]), full_indices)
    topk_wh = tf.reshape(topk_wh, [b, -1, 2])

    # compute the offset-adjusted centers
    topk_cx = tf.cast(tf.expand_dims(xs, axis=-1), tf.float32) + topk_reg[..., 0:1]
    topk_cy = tf.cast(tf.expand_dims(ys, axis=-1), tf.float32) + topk_reg[..., 1:2]

    # (b,k,1) (b,k,1)
    topk_x1, topk_y1 = topk_cx - topk_wh[..., 0:1] / 2, topk_cy - topk_wh[..., 1:2] / 2
    # (b,k,1) (b,k,1)
    topk_x2, topk_y2 = topk_cx + topk_wh[..., 0:1] / 2, topk_cy + topk_wh[..., 1:2] / 2
    # (b,k,1)
    scores = tf.expand_dims(scores, axis=-1)
    # (b,k,1)
    class_ids = tf.cast(tf.expand_dims(class_ids, axis=-1), tf.float32)
    # (b,k,6)
    detections = tf.concat([topk_x1, topk_y1, topk_x2, topk_y2, scores, class_ids], axis=-1)
    return detections

5. Drawing the boxes on the original image

The third decoding step yields the positions of the predicted boxes on the image, and these boxes have already been filtered, so they can be drawn directly onto the picture to produce the final result, as sketched below.
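
Note that the decoded coordinates live on the 128x128 output grid, so they must be multiplied by the output stride (4) to land on the 512x512 input image. A minimal drawing sketch; draw_detections and its parameters are illustrative names, not part of the original code:

import cv2

def draw_detections(image, detections, class_names, score_threshold=0.3, stride=4):
    # detections: (k, 6) array of [x1, y1, x2, y2, score, class_id]
    # in feature-map coordinates, as returned by decode()
    for x1, y1, x2, y2, score, cls in detections:
        if score < score_threshold:
            continue
        # scale from the 128x128 grid back to the input resolution
        x1, y1, x2, y2 = [int(v * stride) for v in (x1, y1, x2, y2)]
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label = '%s %.2f' % (class_names[int(cls)], score)
        cv2.putText(image, label, (x1, max(y1 - 5, 0)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
    return image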

II. Training

1. Processing the ground-truth boxes

Since in CenterNet the feature point whose lower-right region contains an object's center is responsible for predicting that object, training requires establishing the relationship between the ground-truth boxes and the feature points.

The relationship between ground-truth boxes and feature points:

  1. Find the center of each ground-truth box, and through it the corresponding feature point.
  2. Build the target heatmap according to the box's class: in the channel for that class, the value at the corresponding feature point is set to 1, and the values at nearby feature points fall off following a Gaussian distribution (see the helper sketch after this list).
  3. Set the reg center offset and wh width-height targets for that feature point.
  4. Compare the predictions against these targets and adjust the network by backpropagating the gradients.
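
The generator below relies on a few helpers that are not part of this excerpt: rand (uniform sampling), preprocess_image (input normalization), shuffle (from Python's random module), plus gaussian_radius and draw_gaussian for building the heatmap targets. The sketch below follows the reference CenterNet implementation; treat it as an assumption about what the original repository provides rather than a verbatim copy.

import numpy as np

def rand(a=0, b=1):
    # uniform random number in [a, b)
    return np.random.rand() * (b - a) + a

def gaussian_radius(det_size, min_overlap=0.7):
    # smallest Gaussian radius such that a box whose corners are shifted
    # by that radius still has IoU >= min_overlap with the ground truth
    height, width = det_size

    a1 = 1
    b1 = height + width
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
    r1 = (b1 + sq1) / 2

    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
    r2 = (b2 + sq2) / 2

    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
    r3 = (b3 + sq3) / 2
    return min(r1, r2, r3)

def gaussian2D(shape, sigma=1):
    # unnormalized 2D Gaussian kernel of the given size
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

def draw_gaussian(heatmap, center, radius, k=1):
    # splat a Gaussian peak onto one class channel, clipped at the borders;
    # np.maximum keeps the larger value where peaks overlap
    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)

    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[0:2]

    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap
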
class Generator(object):
    def __init__(self,batch_size,train_lines,val_lines,
                input_size,num_classes,max_objects=100):
        
        self.batch_size = batch_size
        self.train_lines = train_lines
        self.val_lines = val_lines
        self.input_size = input_size
        self.output_size = (int(input_size[0]/4) , int(input_size[1]/4))
        self.num_classes = num_classes
        self.max_objects = max_objects
        
    def get_random_data(self, annotation_line, input_shape, random=True, jitter=.3, hue=.1, sat=1.5, val=1.5, proc_img=True):
        '''Random preprocessing for real-time data augmentation'''
        line = annotation_line.split()
        image = Image.open(line[0])
        iw, ih = image.size
        h, w = input_shape
        box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])

        # resize image
        new_ar = w/h * rand(1-jitter,1+jitter)/rand(1-jitter,1+jitter)
        scale = rand(0.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        # place image
        dx = int(rand(0, w-nw))
        dy = int(rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # flip image or not
        flip = rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # distort image
        hue = rand(-hue, hue)
        sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat)
        val = rand(1, val) if rand()<.5 else 1/rand(1, val)
        x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
        x[..., 0] += hue*360
        x[..., 0][x[..., 0]>360] -= 360
        x[..., 0][x[..., 0]<0] += 360
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:,:, 0]>360, 0] = 360
        x[:, :, 1:][x[:, :, 1:]>1] = 1
        x[x<0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255


        # correct boxes
        box_data = np.zeros((len(box),5))
        if len(box)>0:
            np.random.shuffle(box)
            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
            if flip: box[:, [0,2]] = w - box[:, [2,0]]
            box[:, 0:2][box[:, 0:2]<0] = 0
            box[:, 2][box[:, 2]>w] = w
            box[:, 3][box[:, 3]>h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
            box_data = np.zeros((len(box),5))
            box_data[:len(box)] = box
        if len(box) == 0:
            return image_data, []

        if (box_data[:,:4]>0).any():
            return image_data, box_data
        else:
            return image_data, []

    def generate(self, train=True):
        while True:
            if train:
                # shuffle the training data
                shuffle(self.train_lines)
                lines = self.train_lines
            else:
                shuffle(self.val_lines)
                lines = self.val_lines
            batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], self.input_size[2]), dtype=np.float32)

            batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes), dtype=np.float32)
            batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
            batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
            batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
            batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
            
            b = 0
            for annotation_line in lines:  
                img,y=self.get_random_data(annotation_line,self.input_size[0:2])

                if len(y)!=0:
                    boxes = np.array(y[:,:4],dtype=np.float32)
                    boxes[:,0] = boxes[:,0]/self.input_size[1]*self.output_size[1]
                    boxes[:,1] = boxes[:,1]/self.input_size[0]*self.output_size[0]
                    boxes[:,2] = boxes[:,2]/self.input_size[1]*self.output_size[1]
                    boxes[:,3] = boxes[:,3]/self.input_size[0]*self.output_size[0]

                for i in range(len(y)):
                    bbox = boxes[i].copy()
                    bbox = np.array(bbox)
                    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, self.output_size[1] - 1)
                    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, self.output_size[0] - 1)
                    cls_id = int(y[i,-1])
                    
                    h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
                    if h > 0 and w > 0:
                        ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
                        ct_int = ct.astype(np.int32)
                        
                        # draw the Gaussian peak on the class heatmap
                        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
                        radius = max(0, int(radius))
                        batch_hms[b, :, :, cls_id] = draw_gaussian(batch_hms[b, :, :, cls_id], ct_int, radius)
                        
                        batch_whs[b, i] = 1. * w, 1. * h
                        # compute the center offset (the fractional part lost by rounding)
                        batch_regs[b, i] = ct - ct_int
                        # set the mask to 1 to mark a valid object and exclude the padded zeros
                        batch_reg_masks[b, i] = 1
                        # flattened index (row * width + column), i.e. element ct_int[0] of row ct_int[1]
                        batch_indices[b, i] = ct_int[1] * self.output_size[1] + ct_int[0]

                batch_images[b] = preprocess_image(img)
                b = b + 1
                if b == self.batch_size:
                    b = 0
                    yield [batch_images, batch_hms, batch_whs, batch_regs, batch_reg_masks, batch_indices], np.zeros((self.batch_size,))

                    batch_images = np.zeros((self.batch_size, self.input_size[0], self.input_size[1], 3), dtype=np.float32)

                    batch_hms = np.zeros((self.batch_size, self.output_size[0], self.output_size[1], self.num_classes),
                                        dtype=np.float32)
                    batch_whs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
                    batch_regs = np.zeros((self.batch_size, self.max_objects, 2), dtype=np.float32)
                    batch_reg_masks = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)
                    batch_indices = np.zeros((self.batch_size, self.max_objects), dtype=np.float32)

2. Computing the loss from the processed ground truth and the corresponding predictions

The loss is computed in three parts:

  1. the heatmap loss
  2. the reg center offset loss
  3. the wh width-height loss

The heatmap loss takes the form of a focal loss, while the reg center offset and wh width-height losses use a plain L1 loss.
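
Concretely, the heatmap loss is the penalty-reduced pixel-wise focal loss from the CenterNet paper, with $\alpha = 2$ and $\beta = 4$ (matching the exponents in the code below):

$$
L_{hm} = \frac{-1}{N}\sum_{xyc}
\begin{cases}
(1 - \hat{Y}_{xyc})^{\alpha} \log(\hat{Y}_{xyc}) & \text{if } Y_{xyc} = 1 \\
(1 - Y_{xyc})^{\beta} \, \hat{Y}_{xyc}^{\alpha} \log(1 - \hat{Y}_{xyc}) & \text{otherwise}
\end{cases}
$$

where $\hat{Y}$ is the predicted heatmap, $Y$ the Gaussian-smoothed target, and $N$ the number of positive keypoints; the division by $N$ corresponds to the tf.cond branch in focal_loss below.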


def focal_loss(hm_pred, hm_true):
    # masks for the positive and negative samples
    pos_mask = tf.cast(tf.equal(hm_true, 1), tf.float32)
    # everything below 1 is a negative sample
    neg_mask = tf.cast(tf.less(hm_true, 1), tf.float32)
    neg_weights = tf.pow(1 - hm_true, 4)

    pos_loss = -tf.log(tf.clip_by_value(hm_pred, 1e-7, 1.)) * tf.pow(1 - hm_pred, 2) * pos_mask
    neg_loss = -tf.log(tf.clip_by_value(1 - hm_pred, 1e-7, 1.)) * tf.pow(hm_pred, 2) * neg_weights * neg_mask

    num_pos = tf.reduce_sum(pos_mask)
    pos_loss = tf.reduce_sum(pos_loss)
    neg_loss = tf.reduce_sum(neg_loss)

    cls_loss = tf.cond(tf.greater(num_pos, 0), lambda: (pos_loss + neg_loss) / num_pos, lambda: neg_loss)
    return cls_loss


def reg_l1_loss(y_pred, y_true, indices, mask):
    b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
    k = tf.shape(indices)[1]

    y_pred = tf.reshape(y_pred, (b, -1, c))
    length = tf.shape(y_pred)[1]
    indices = tf.cast(indices, tf.int32)

    # compute flattened indices into the (b * h * w) dimension
    batch_idx = tf.expand_dims(tf.range(0, b), 1)
    batch_idx = tf.tile(batch_idx, (1, k))
    full_indices = (tf.reshape(batch_idx, [-1]) * tf.to_int32(length) +
                    tf.reshape(indices, [-1]))
    # gather the predictions at the ground-truth positions
    y_pred = tf.gather(tf.reshape(y_pred, [-1,c]),full_indices)
    y_pred = tf.reshape(y_pred, [b, -1, c])

    mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 2))
    # masked L1 loss
    total_loss = tf.reduce_sum(tf.abs(y_true * mask - y_pred * mask))
    reg_loss = total_loss / (tf.reduce_sum(mask) + 1e-4)
    return reg_loss


def loss(args):
    #---------------------------------------------------------------------------
    # hm_pred:  predicted heatmap          (batch_size, output_size[0], output_size[1], num_classes)
    # wh_pred:  predicted width-height     (batch_size, output_size[0], output_size[1], 2)
    # reg_pred: predicted center offset    (batch_size, output_size[0], output_size[1], 2)
    # hm_true:  ground-truth heatmap       (batch_size, output_size[0], output_size[1], num_classes)
    # wh_true:  ground-truth width-height  (batch_size, max_objects, 2)
    # reg_true: ground-truth center offset (batch_size, max_objects, 2)
    # reg_mask: mask of the valid objects  (batch_size, max_objects)
    # indices:  ground-truth flat indices  (batch_size, max_objects)
    #---------------------------------------------------------------------------
    hm_pred, wh_pred, reg_pred, hm_true, wh_true, reg_true, reg_mask, indices = args
    hm_loss = focal_loss(hm_pred, hm_true)
    wh_loss = 0.1 * reg_l1_loss(wh_pred, wh_true, indices, reg_mask)
    reg_loss = reg_l1_loss(reg_pred, reg_true, indices, reg_mask)
    total_loss = hm_loss + wh_loss + reg_loss
    # total_loss = tf.Print(total_loss,[hm_loss,wh_loss,reg_loss])
    return total_loss
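
For completeness, here is how this loss is typically wired into a trainable Keras model in public Keras CenterNet implementations: the ground-truth tensors enter the graph as extra inputs, the loss is computed by a Lambda layer, and the model is compiled with an identity loss, which is why the generator above yields np.zeros((batch_size,)) as a dummy target. The single strided Conv2D below is only a placeholder for the real ResNet50 + decoder + heads; treat the whole block as an illustrative sketch.

from keras.layers import Conv2D, Input, Lambda
from keras.models import Model

num_classes, max_objects = 20, 100

image_input = Input(shape=(512, 512, 3))
# placeholder backbone: one strided conv standing in for ResNet50 + decoder
feat = Conv2D(64, 3, strides=4, padding='same', activation='relu')(image_input)
y1 = Conv2D(num_classes, 1, activation='sigmoid')(feat)  # hm head
y2 = Conv2D(2, 1)(feat)                                  # wh head
y3 = Conv2D(2, 1)(feat)                                  # reg head

hm_input = Input(shape=(128, 128, num_classes))
wh_input = Input(shape=(max_objects, 2))
reg_input = Input(shape=(max_objects, 2))
reg_mask_input = Input(shape=(max_objects,))
index_input = Input(shape=(max_objects,))

loss_ = Lambda(loss, name='centernet_loss')(
    [y1, y2, y3, hm_input, wh_input, reg_input, reg_mask_input, index_input])
model = Model(inputs=[image_input, hm_input, wh_input, reg_input,
                      reg_mask_input, index_input], outputs=loss_)
model.compile(optimizer='adam', loss=lambda y_true, y_pred: y_pred)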


Supplementary notes on the TensorFlow functions used in the code

tf.zeros_like creates a tensor of zeros with the same shape as the given tensor. tf.equal compares two tensors element-wise, yielding True where the elements match and False where they differ. tf.where(condition, x, y) selects elements from x where condition is True and from y where it is False.

tf.shape returns the shape of a tensor; tf.reshape changes a tensor's shape.

tf.expand_dims inserts a new dimension at the given position; tf.tile repeats a tensor along its dimensions; tf.cast converts a tensor to another data type.

tf.concat concatenates tensors along a given axis.
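
As a small illustration of how these functions combine in decode() and reg_l1_loss(), here is the batch-flattened gather pattern as a standalone sketch with toy shapes (not taken from the original code):

import tensorflow as tf

# gather k entries per image from a (b, length, c) tensor using flat indices
b, length, c, k = 2, 8, 2, 3
values = tf.reshape(tf.range(b * length * c, dtype=tf.float32), (b, length, c))
indices = tf.constant([[0, 3, 7], [1, 2, 5]])  # (b, k) positions within each image

batch_idx = tf.expand_dims(tf.range(0, b), 1)  # (b, 1)
batch_idx = tf.tile(batch_idx, (1, k))         # (b, k): [[0,0,0],[1,1,1]]
# offset each image's indices by its start in the flattened batch dimension
full_indices = tf.reshape(batch_idx, [-1]) * length + tf.reshape(indices, [-1])

gathered = tf.gather(tf.reshape(values, [-1, c]), full_indices)
gathered = tf.reshape(gathered, [b, k, c])     # (b, k, c)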

Model subclassing

All models in Keras inherit from the Model class; subclassing Model is one way to define a custom model.
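
A minimal subclassing sketch (the layer sizes here are hypothetical, chosen only to illustrate the pattern):

import keras

class TinyHead(keras.Model):
    # define sublayers in __init__ and the forward pass in call()
    def __init__(self, num_classes):
        super(TinyHead, self).__init__()
        self.conv = keras.layers.Conv2D(64, 3, padding='same', activation='relu')
        self.out = keras.layers.Conv2D(num_classes, 1, activation='sigmoid')

    def call(self, inputs):
        return self.out(self.conv(inputs))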

np.arange() generates an array of evenly spaced values over a given interval.
