7
7
8
8
import torch
9
9
10
- from dr_spaam .dataset . get_dataloader import get_dataloader
10
+ from dr_spaam .dataset import get_dataloader
11
11
import dr_spaam .utils .jrdb_transforms as jt
12
12
import dr_spaam .utils .utils as u
13
13
18
18
# _Y_LIM = (-10, 4)
19
19
_Y_LIM = (- 7 , 7 )
20
20
21
- _PLOTTING_INTERVAL = 2
21
+ _PLOTTING_INTERVAL = 20
22
22
_MAX_COUNT = 1e9
23
+ # _MAX_COUNT = 1e1
23
24
24
25
# _COLOR_CLOSE_HSV = (1.0, 0.59, 0.75)
25
26
_COLOR_CLOSE_HSV = (0.0 , 1.0 , 1.0 )
@@ -125,7 +126,7 @@ def _plot_frame_im(batch_dict, ib):
125
126
plt .close (fig )
126
127
127
128
128
- def _plot_frame_pts (batch_dict , ib , pred_cls , pred_reg ):
129
+ def _plot_frame_pts (batch_dict , ib , pred_cls , pred_reg , pred_cls_p , pred_reg_p ):
129
130
frame_id = f"{ batch_dict ['frame_id' ][ib ]:06d} "
130
131
sequence = batch_dict ["sequence" ][ib ]
131
132
@@ -159,17 +160,44 @@ def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg):
159
160
)
160
161
ax .add_artist (c )
161
162
162
- # inference result or pseudo-labels
163
+ # plot detections
163
164
if pred_cls is not None and pred_reg is not None :
164
165
dets_xy , dets_cls , _ = u .nms_predicted_center (
165
166
scan_r , scan_phi , pred_cls [ib ].reshape (- 1 ), pred_reg [ib ]
166
167
)
167
- dets_xy = dets_xy [dets_cls > 0.15 ]
168
+ dets_xy = dets_xy [dets_cls >= 0.9438938 ] # at EER
168
169
if len (dets_xy ) > 0 :
169
170
for x , y in dets_xy :
170
- c = plt .Circle ((x , y ), radius = 0.4 , color = (0.0 , 0.56 , 0.56 ), fill = False )
171
+ c = plt .Circle ((x , y ), radius = 0.4 , color = (0 , 0.56 , 0.56 ), fill = False )
171
172
ax .add_artist (c )
172
- fig_file = os .path .join (_SAVE_DIR , f"figs/{ sequence } /scan_det_{ frame_id } .png" )
173
+ fig_file = os .path .join (
174
+ _SAVE_DIR , f"figs/{ sequence } /scan_det_{ frame_id } .png"
175
+ )
176
+
177
+ # plot in addition detections from a pre-trained
178
+ if pred_cls_p is not None and pred_reg_p is not None :
179
+ dets_xy , dets_cls , _ = u .nms_predicted_center (
180
+ scan_r , scan_phi , pred_cls_p [ib ].reshape (- 1 ), pred_reg_p [ib ]
181
+ )
182
+ dets_xy = dets_xy [dets_cls > 0.29919282 ] # at EER
183
+ if len (dets_xy ) > 0 :
184
+ for x , y in dets_xy :
185
+ c = plt .Circle ((x , y ), radius = 0.4 , color = "green" , fill = False )
186
+ ax .add_artist (c )
187
+ # plot pre-trained detections only
188
+ elif pred_cls_p is not None and pred_reg_p is not None :
189
+ dets_xy , dets_cls , _ = u .nms_predicted_center (
190
+ scan_r , scan_phi , pred_cls_p [ib ].reshape (- 1 ), pred_reg_p [ib ]
191
+ )
192
+ dets_xy = dets_xy [dets_cls > 0.29919282 ] # at EER
193
+ if len (dets_xy ) > 0 :
194
+ for x , y in dets_xy :
195
+ c = plt .Circle ((x , y ), radius = 0.4 , color = "green" , fill = False )
196
+ ax .add_artist (c )
197
+ fig_file = os .path .join (
198
+ _SAVE_DIR , f"figs/{ sequence } /scan_pretrain_{ frame_id } .png"
199
+ )
200
+ # plot pseudo-labels only
173
201
else :
174
202
pl_xy = batch_dict ["pseudo_label_loc_xy" ][ib ]
175
203
if len (pl_xy ) > 0 :
@@ -185,7 +213,7 @@ def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg):
185
213
186
214
187
215
def plot_pseudo_label_for_all_frames ():
188
- with open ("./base_drow_jrdb_cfg.yaml" , "r" ) as f :
216
+ with open ("./cfgs/ base_drow_jrdb_cfg.yaml" , "r" ) as f :
189
217
cfg = yaml .safe_load (f )
190
218
cfg ["dataset" ]["pseudo_label" ] = True
191
219
cfg ["dataset" ]["pl_correction_level" ] = 0
@@ -203,7 +231,12 @@ def plot_pseudo_label_for_all_frames():
203
231
model .eval ()
204
232
205
233
logger = Logger (cfg ["pipeline" ]["Logger" ])
206
- logger .load_ckpt ("./ckpts/ckpt_phce_drow_e40.pth" , model )
234
+ logger .load_ckpt ("./ckpts/ckpt_jrdb_pl_drow3_phce_e40.pth" , model )
235
+
236
+ model_pretrain = get_model (cfg ["model" ])
237
+ model_pretrain .cuda ()
238
+ model_pretrain .eval ()
239
+ logger .load_ckpt ("./ckpts/ckpt_drow_drow3_e40.pth" , model_pretrain )
207
240
208
241
# generate pseudo labels for all sample
209
242
for count , batch_dict in enumerate (tqdm (test_loader )):
@@ -216,11 +249,19 @@ def plot_pseudo_label_for_all_frames():
216
249
pred_cls = torch .sigmoid (pred_cls ).data .cpu ().numpy ()
217
250
pred_reg = pred_reg .data .cpu ().numpy ()
218
251
252
+ pred_cls_p , pred_reg_p = model_pretrain (net_input )
253
+ pred_cls_p = torch .sigmoid (pred_cls_p ).data .cpu ().numpy ()
254
+ pred_reg_p = pred_reg_p .data .cpu ().numpy ()
255
+
219
256
if count % _PLOTTING_INTERVAL == 0 :
220
257
for ib in range (len (batch_dict ["input" ])):
221
258
_plot_frame_im (batch_dict , ib )
222
- _plot_frame_pts (batch_dict , ib , None , None )
223
- _plot_frame_pts (batch_dict , ib , pred_cls , pred_reg )
259
+ # _plot_frame_pts(batch_dict, ib, None, None, None, None)
260
+ # _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, None, None)
261
+ # # _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, pred_cls_p, pred_reg_p)
262
+
263
+ _plot_frame_pts (batch_dict , ib , pred_cls , pred_reg , None , None )
264
+ _plot_frame_pts (batch_dict , ib , None , None , pred_cls_p , pred_reg_p )
224
265
225
266
226
267
def plot_color_bar ():
0 commit comments