How does FairMOT store ReID features in a deque?

一时失言乱红尘 2023-02-25 02:26

1. FairMOT code logic analysis

Going by my own understanding, I pieced the relevant code together in its logical order and annotated it below.

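```python
# opts.py
# Default value of the track_buffer argument
self.parser.add_argument('--track_buffer', type=int, default=30, help='tracking buffer')

# multitracker.py
# Two classes are involved. JDETracker.update() creates one STrack per detection,
# and each STrack initializes its own deque for storing features, with the deque
# size given by the buffer_size argument it receives.
from collections import deque

import numpy as np


class JDETracker(object):
    def __init__(self, opt, frame_rate=30):
        ...
        # track_buffer scaled by the frame rate; used as max_time_lost, i.e. how many
        # frames a lost track is kept before it is removed
        self.buffer_size = int(frame_rate / 30.0 * opt.track_buffer)
        self.max_time_lost = self.buffer_size
        ...

    def update(self, im_blob, img0):
        ...
        if len(dets) > 0:
            '''Detections'''
            # the last argument (30) becomes each STrack's buffer_size, i.e. its deque maxlen
            detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
                          (tlbrs, f) in zip(dets[:, :5], id_feature)]
        else:
            detections = []
        ...
        output_stracks = [track for track in self.tracked_stracks if track.is_activated]
        # Print the feature count (deque length) of every output ID. Each time an ID appears
        # in a frame, its new feature is appended to its deque; once the deque is full, the
        # oldest feature is dropped and the new one goes to the tail (first in, first out).
        print("detections feature : {}".format([len(i.features) for i in output_stracks]))


# STrack also holds the feature update logic
class STrack(BaseTrack):
    def __init__(self, tlwh, score, temp_feat, buffer_size):
        ...
        self.features = deque([], maxlen=buffer_size)
        ...

    def update_features(self, feat):
        feat /= np.linalg.norm(feat)  # L2-normalize the new feature
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            # exponential moving average of the appearance feature
            self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
        self.features.append(feat)  # bounded deque: evicts the oldest feature when full
        self.smooth_feat /= np.linalg.norm(self.smooth_feat)
```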
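Stripped of the rest of STrack, update_features maintains two things side by side: a bounded deque holding each recent feature, and a single exponentially smoothed feature. The sketch below reproduces just that bookkeeping so it can run without FairMOT. It is a sketch under assumptions: ToyTrack and l2_normalize are hypothetical names, plain Python lists stand in for NumPy arrays, and alpha = 0.9 is an assumed smoothing factor (the excerpt above does not show its value).

```python
import math
from collections import deque


def l2_normalize(vec):
    """Return a copy of vec scaled to unit L2 norm (the original normalizes in place)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


class ToyTrack:
    """Hypothetical stand-in for STrack's feature bookkeeping only."""

    def __init__(self, buffer_size=30, alpha=0.9):  # alpha is an assumed value
        self.features = deque([], maxlen=buffer_size)  # per-track FIFO of recent features
        self.smooth_feat = None                        # exponential moving average
        self.curr_feat = None
        self.alpha = alpha

    def update_features(self, feat):
        feat = l2_normalize(feat)
        self.curr_feat = feat
        if self.smooth_feat is None:
            self.smooth_feat = feat
        else:
            self.smooth_feat = [self.alpha * s + (1 - self.alpha) * f
                                for s, f in zip(self.smooth_feat, feat)]
        self.features.append(feat)  # drops the oldest feature once maxlen is reached
        self.smooth_feat = l2_normalize(self.smooth_feat)


track = ToyTrack(buffer_size=3)
for step in range(5):
    track.update_features([1.0, float(step)])
print(len(track.features))                       # 3
print(round(math.hypot(*track.smooth_feat), 6))  # 1.0
```

With buffer_size=3, five updates leave only the last three features in the deque, while smooth_feat still carries a decaying contribution from all five through the moving average.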

2. A collections.deque append/evict demo

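Below is a small test demo I wrote. The feature vectors are high-dimensional, so printing them directly makes it hard to see how the code handles the features stored in the queue. The goal of the demo is to see what happens when the number of appended elements exceeds the capacity the deque was initialized with.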

```python
# deque_test.py
from collections import deque


def main():
    features = deque([], maxlen=5)
    for i in range(10):
        features.append(i)
        print(features)


if __name__ == "__main__":
    main()
```

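Running it prints one line per iteration. The deque grows from deque([0], maxlen=5) to deque([0, 1, 2, 3, 4], maxlen=5); after that, each append drops the leftmost (oldest) element, so the following lines are deque([1, 2, 3, 4, 5], maxlen=5), deque([2, 3, 4, 5, 6], maxlen=5), and so on, ending with deque([5, 6, 7, 8, 9], maxlen=5).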
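What the demo shows is plain first-in, first-out eviction: a deque created with maxlen behaves like a list that, once full, drops its leftmost (oldest) element before each append. The check below compares the two step by step; manual_fifo_append is a hypothetical helper written only for this comparison, not something FairMOT uses.

```python
from collections import deque


def manual_fifo_append(buf, item, maxlen):
    """List-based equivalent of a bounded deque's append: drop the oldest item first."""
    if len(buf) == maxlen:
        buf.pop(0)  # O(n) on a list; the bounded deque evicts in O(1)
    buf.append(item)


dq = deque([], maxlen=5)
lst = []
for i in range(10):
    dq.append(i)
    manual_fifo_append(lst, i, 5)
    assert list(dq) == lst  # both keep the 5 most recent items in arrival order

print(list(dq))       # [5, 6, 7, 8, 9]
print(dq[0], dq[-1])  # 5 9  (oldest on the left, newest on the right)
```

The deque does the eviction in constant time, whereas list.pop(0) shifts every remaining element, which is one reason a bounded deque suits a per-frame feature buffer.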

3. Summary

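Putting sections 1 and 2 together: once an ID's feature count reaches the deque's capacity (the buffer_size passed to STrack, which update() sets to 30), each new feature replaces the oldest one in the deque.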

Let me go over the cases involved:

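(1) While an activated track stays active, the feature deque for its ID keeps being updated: until it reaches buffer_size, the count grows by 1 per frame; after that, the oldest feature is dropped and the ID's current feature is appended.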


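Note: my debug output for two frames illustrates both situations. IDs 6 and 7 had not yet reached buffer_size, so their feature counts each grew by 1; ID 5 had already reached buffer_size, so its count stayed the same even though its features kept being updated.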
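The pattern in the note can be reproduced with a small simulation: tracks matched in every frame each receive one feature per frame, so a track below capacity grows by one while a full track stays at capacity and only its contents rotate. The IDs and starting counts below are hypothetical, chosen to echo the note rather than taken from my run, and strings stand in for the embedding vectors.

```python
from collections import deque

BUFFER_SIZE = 30  # per-track deque capacity; update() passes 30 to STrack


def make_track(n_existing):
    """Hypothetical track history: a deque already holding n_existing older features."""
    return deque((f"old{k}" for k in range(n_existing)), maxlen=BUFFER_SIZE)


# Hypothetical IDs at different ages (feature counts chosen for illustration).
tracks = {5: make_track(30), 6: make_track(12), 7: make_track(3)}


def step(tracks, frame_id):
    """Every track is matched in this frame, so each deque receives one new feature."""
    for feats in tracks.values():
        feats.append(f"feat@{frame_id}")  # placeholder for an embedding vector
    return {tid: len(feats) for tid, feats in tracks.items()}


print(step(tracks, 1))  # {5: 30, 6: 13, 7: 4}
print(step(tracks, 2))  # {5: 30, 6: 14, 7: 5}
print(tracks[5][0])     # old2: the two oldest features of ID 5 were evicted
```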

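(2) When an activated track is lost, it is moved into lost_stracks, carrying its ID features with it. If it is found again (refind), the track and its features are moved from lost_stracks back into the tracked tracks (via refind_stracks).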


Note: at Frame 588, ID 19 was lost, and the detections feature printout listed only 4 IDs; at Frame 589, ID 19 was found again and the printout listed 5 IDs.
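Because the deque lives on the track object itself, moving a track between lists does not touch its features: a lost track keeps its history, and when it is refound the next feature is simply appended. If it is never refound, Step 5 of update() removes it once frame_id - end_frame exceeds max_time_lost. The sketch below walks through both outcomes. LifecycleTrack, observe, prune_lost and the plain tracked/lost/removed lists are hypothetical simplifications of FairMOT's state handling (which uses TrackState, joint_stracks and sub_stracks); only the removal condition and the buffer_size formula come from the code in this post, and a 30 fps frame rate is assumed.

```python
from collections import deque

TRACK_BUFFER = 30  # --track_buffer default from opts.py
FRAME_RATE = 30    # assumed video frame rate
# Same formula as in the tracker; a lost track may stay lost this many frames.
MAX_TIME_LOST = int(FRAME_RATE / 30.0 * TRACK_BUFFER)


class LifecycleTrack:
    """Hypothetical minimal track: ID, last frame seen, and its own feature deque."""

    def __init__(self, track_id, frame_id):
        self.track_id = track_id
        self.end_frame = frame_id
        self.features = deque([], maxlen=30)

    def observe(self, frame_id):
        """Record a match in frame_id (a string stands in for the embedding)."""
        self.features.append(f"feat@{frame_id}")
        self.end_frame = frame_id


def prune_lost(lost, removed, frame_id, max_time_lost=MAX_TIME_LOST):
    """Step 5 rule: remove tracks that have been lost for more than max_time_lost frames."""
    keep = []
    for track in lost:
        if frame_id - track.end_frame > max_time_lost:
            removed.append(track)
        else:
            keep.append(track)
    return keep


tracked, lost, removed = [], [], []

t = LifecycleTrack(19, frame_id=1)
for f in range(1, 5):  # matched in frames 1-4
    t.observe(f)
tracked.append(t)

tracked.remove(t)  # frame 5: no match, the same object moves to the lost list
lost.append(t)
print(len(lost[0].features))  # 4: the deque travels with the track

lost.remove(t)  # frame 6: re-identified, moved back, and the deque keeps growing
t.observe(6)
tracked.append(t)
print(len(t.features))  # 5

other = LifecycleTrack(23, frame_id=10)  # last seen in frame 10, then lost
lost = prune_lost([other], removed, frame_id=40)
print(len(lost), len(removed))  # 1 0  (40 - 10 = 30 is not > 30)
lost = prune_lost(lost, removed, frame_id=41)
print(len(lost), len(removed))  # 0 1  (41 - 10 = 31 > 30)
```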

Someone messaged me earlier asking for the annotated code for this part, so here it is.

The update function in FairMOT/src/lib/multitracker.py:

```python
def update(self, im_blob, img0):
    self.frame_id += 1
    print('================Frame {}==============='.format(self.frame_id))
    activated_stracks = []
    refind_stracks = []
    lost_stracks = []
    removed_stracks = []

    width = img0.shape[1]
    height = img0.shape[0]
    inp_height = im_blob.shape[2]
    inp_width = im_blob.shape[3]
    c = np.array([width / 2., height / 2.], dtype=np.float32)
    s = max(float(inp_width) / float(inp_height) * height, width) * 1.0
    meta = {'c': c, 's': s,
            'out_height': inp_height // self.opt.down_ratio,
            'out_width': inp_width // self.opt.down_ratio}

    ''' Step 1: Network forward, get detections & embeddings'''
    with torch.no_grad():
        output = self.model(im_blob)[-1]
        # heatmap, box width/height, and ReID feature outputs
        hm = output['hm'].sigmoid_()
        wh = output['wh']
        id_feature = output['id']
        id_feature = F.normalize(id_feature, dim=1)
        reg = output['reg'] if self.opt.reg_offset else None
        # decode boxes from the heatmap, width/height and offsets
        dets, inds = mot_decode(hm, wh, reg=reg, cat_spec_wh=self.opt.cat_spec_wh, K=self.opt.K)
        # pick out the ReID feature at each detected object center
        id_feature = _tranpose_and_gather_feat(id_feature, inds)
        id_feature = id_feature.squeeze(0)
        id_feature = id_feature.cpu().numpy()

    dets = self.post_process(dets, meta)
    dets = self.merge_outputs([dets])[1]

    # drop detections whose score is not above self.opt.conf_thres
    remain_inds = dets[:, 4] > self.opt.conf_thres
    dets = dets[remain_inds]
    id_feature = id_feature[remain_inds]
    print("id_feature shape : {}".format(id_feature.shape))

    # vis (debug only; uncomment to draw the detections)
    # for i in range(0, dets.shape[0]):
    #     bbox = dets[i][0:4]
    #     cv2.rectangle(img0, (int(bbox[0]), int(bbox[1])),
    #                   (int(bbox[2]), int(bbox[3])),
    #                   (0, 255, 0), 2)
    # cv2.imshow('dets', img0)
    # cv2.waitKey(0)

    # build the connection between each detection and its ID feature
    if len(dets) > 0:
        '''Detections'''
        detections = [STrack(STrack.tlbr_to_tlwh(tlbrs[:4]), tlbrs[4], f, 30) for
                      (tlbrs, f) in zip(dets[:, :5], id_feature)]
    else:
        detections = []

    ''' Add newly detected tracklets to tracked_stracks'''
    unconfirmed = []
    tracked_stracks = []  # type: list[STrack]
    for track in self.tracked_stracks:
        if not track.is_activated:
            unconfirmed.append(track)
        else:
            tracked_stracks.append(track)

    ''' Step 2: First association, with embedding'''
    strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
    # Predict the current location with KF
    # for strack in strack_pool:
    #     strack.predict()
    STrack.multi_predict(strack_pool)
    dists = matching.embedding_distance(strack_pool, detections)
    # dists = matching.gate_cost_matrix(self.kalman_filter, dists, strack_pool, detections)
    # motion estimation: fuse the Kalman-filter prediction into the embedding cost
    dists = matching.fuse_motion(self.kalman_filter, dists, strack_pool, detections)
    matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.7)

    for itracked, idet in matches:
        track = strack_pool[itracked]
        det = detections[idet]
        if track.state == TrackState.Tracked:
            track.update(detections[idet], self.frame_id)
            activated_stracks.append(track)
            print('Activated track: {}'.format([track for track in activated_stracks]))
            print('Activated0: {}'.format([track.track_id for track in activated_stracks]))
        else:
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    ''' Step 3: Second association, with IOU'''
    detections = [detections[i] for i in u_detection]
    r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
    dists = matching.iou_distance(r_tracked_stracks, detections)
    # matches: (track, detection) pairs that were matched
    # u_track: tracks that found no detection in the current frame
    # u_detection: detections that found no track from earlier frames
    matches, u_track, u_detection = matching.linear_assignment(dists, thresh=0.5)

    for itracked, idet in matches:
        track = r_tracked_stracks[itracked]
        det = detections[idet]
        if track.state == TrackState.Tracked:
            track.update(det, self.frame_id)
            activated_stracks.append(track)
            print('Activated1: {}'.format([track for track in activated_stracks]))
        else:
            track.re_activate(det, self.frame_id, new_id=False)
            refind_stracks.append(track)

    for it in u_track:
        track = r_tracked_stracks[it]
        if not track.state == TrackState.Lost:
            track.mark_lost()
            lost_stracks.append(track)
            print('Lost lost_stracks: {}'.format([track for track in lost_stracks]))
            print('Lost1: {}'.format([track.track_id for track in lost_stracks]))

    '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
    detections = [detections[i] for i in u_detection]
    dists = matching.iou_distance(unconfirmed, detections)
    print('u_detection : {}'.format([i for i in u_detection]))
    print('unconfirmed : {}'.format([i for i in unconfirmed]))
    matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
    for itracked, idet in matches:
        unconfirmed[itracked].update(detections[idet], self.frame_id)
        activated_stracks.append(unconfirmed[itracked])
        print('Activated2: {}'.format([track.track_id for track in activated_stracks]))
    for it in u_unconfirmed:
        track = unconfirmed[it]
        track.mark_removed()
        removed_stracks.append(track)
        print('Removed1: {}'.format([track.track_id for track in removed_stracks]))

    """ Step 4: Init new stracks"""
    for inew in u_detection:
        track = detections[inew]
        if track.score < self.det_thresh:
            continue
        track.activate(self.kalman_filter, self.frame_id)
        activated_stracks.append(track)
        print('Activated3: {}'.format([track.track_id for track in activated_stracks]))

    """ Step 5: Update state"""
    for track in self.lost_stracks:
        if self.frame_id - track.end_frame > self.max_time_lost:
            print("self.buffer_size : " + str(self.max_time_lost))
            track.mark_removed()
            removed_stracks.append(track)
            print('Removed2: {}'.format([track.track_id for track in removed_stracks]))
    # print('Ramained match {} s'.format(t4-t3))

    self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
    self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_stracks)
    self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
    self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
    print('Lost2: {}'.format([track.track_id for track in lost_stracks]))
    self.lost_stracks.extend(lost_stracks)
    self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
    print('Lost3: {}'.format([track.track_id for track in lost_stracks]))
    self.removed_stracks.extend(removed_stracks)
    self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
    # get scores of lost tracks
    output_stracks = [track for track in self.tracked_stracks if track.is_activated]
    # number of stored features (deque length) for each output track
    print("detections feature : {}".format([len(i.features) for i in output_stracks]))
    # print('================Frame {}==============='.format(self.frame_id))
    print('Activated: {}'.format([track.track_id for track in activated_stracks]))
    print('Refind: {}'.format([track.track_id for track in refind_stracks]))
    print('Lost: {}'.format([track.track_id for track in lost_stracks]))
    print('Removed: {}'.format([track.track_id for track in removed_stracks]))
    print('output: {}'.format([track.track_id for track in output_stracks]))
    logger.debug('===========Frame {}=========='.format(self.frame_id))
    logger.debug('Activated: {}'.format([track.track_id for track in activated_stracks]))
    logger.debug('Refind: {}'.format([track.track_id for track in refind_stracks]))
    logger.debug('Lost: {}'.format([track.track_id for track in lost_stracks]))
    logger.debug('Removed: {}'.format([track.track_id for track in removed_stracks]))
    return output_stracks
```
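The embedding step (Step 2) is where the stored features are actually consumed. Here is a minimal sketch of that first association: build a cosine-distance cost matrix between track features and detection features, solve the assignment, and reject pairs whose cost exceeds thresh, returning the same (matches, u_track, u_detection) triple that update() uses. Assumptions: cosine_distance_matrix and assign are hypothetical helpers, not FairMOT's matching module; which feature each side uses (for example a track's smooth_feat against a detection's curr_feat) is an assumption; the fuse_motion Kalman term is left out; and the assignment is solved by brute force over permutations, which only makes sense for a handful of tracks (a real tracker would use an efficient linear-assignment solver). Gating after solving, as done here, is also a simplification of solving with a cost limit.

```python
import math
from itertools import permutations


def cosine_distance_matrix(track_feats, det_feats):
    """cost[i][j] = 1 - cos(track i, detection j), clipped at 0 against rounding error."""
    def unit(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v]
    tracks = [unit(v) for v in track_feats]
    dets = [unit(v) for v in det_feats]
    return [[max(0.0, 1.0 - sum(a * b for a, b in zip(t, d))) for d in dets]
            for t in tracks]


def assign(cost, thresh):
    """Brute-force minimum-cost matching (tiny inputs only), then gate pairs by thresh.

    Returns (matches, unmatched_rows, unmatched_cols).
    """
    n_rows = len(cost)
    n_cols = len(cost[0]) if cost else 0
    if n_rows <= n_cols:
        candidates = ([(r, c) for r, c in enumerate(cols)]
                      for cols in permutations(range(n_cols), n_rows))
    else:
        candidates = ([(r, c) for c, r in enumerate(rows)]
                      for rows in permutations(range(n_rows), n_cols))
    best, best_pairs = math.inf, []
    for pairs in candidates:
        total = sum(cost[r][c] for r, c in pairs)
        if total < best:
            best, best_pairs = total, pairs
    matches = [(r, c) for r, c in best_pairs if cost[r][c] <= thresh]
    matched_r = {r for r, _ in matches}
    matched_c = {c for _, c in matches}
    return (matches,
            [r for r in range(n_rows) if r not in matched_r],
            [c for c in range(n_cols) if c not in matched_c])


track_feats = [[1.0, 0.0], [0.0, 1.0]]              # features of two existing tracks
det_feats = [[0.1, 1.0], [1.0, 0.05], [-1.0, 0.0]]  # features of three new detections
cost = cosine_distance_matrix(track_feats, det_feats)
matches, u_track, u_detection = assign(cost, thresh=0.7)
print(matches, u_track, u_detection)  # [(0, 1), (1, 0)] [] [2]
```

Detection 2 points the opposite way from both tracks, so it stays unmatched; in update() it would fall through to the IoU round and, failing that, could start a new track in Step 4.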
