Skip to content

Commit 4a20585

Browse files
author
Dan Jia
committed
clean up plotting scripts
1 parent 19d0177 commit 4a20585

File tree

4 files changed

+378
-13
lines changed

4 files changed

+378
-13
lines changed

dr_spaam/bin/analyze_pseudo_labels.py renamed to dr_spaam/bin/plotting/analyze_pseudo_labels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import matplotlib.pyplot as plt
99
from matplotlib.gridspec import GridSpec
1010

11-
from dr_spaam.dataset.get_dataloader import get_dataloader
11+
from dr_spaam.dataset import get_dataloader
1212
import dr_spaam.utils.jrdb_transforms as jt
1313
import dr_spaam.utils.precision_recall as pru
1414
import dr_spaam.utils.utils as u
@@ -254,7 +254,7 @@ def _write_file_make_dir(f_name, f_str):
254254

255255

256256
def generate_pseudo_labels():
257-
with open("./base_dr_spaam_jrdb_cfg.yaml", "r") as f:
257+
with open("./cfgs/base_dr_spaam_jrdb_cfg.yaml", "r") as f:
258258
cfg = yaml.safe_load(f)
259259
cfg["dataset"]["pseudo_label"] = True
260260
cfg["dataset"]["pl_correction_level"] = 0
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pickle
2+
import numpy as np
3+
4+
p1 = "/home/jia/git/awesome_repos/2D_lidar_person_detection/dr_spaam/logs/20210520_224024_drow_jrdb_EVAL/output/val/e000000/evaluation/all/result_r05.pkl"
5+
p2 = "/home/jia/git/awesome_repos/2D_lidar_person_detection/dr_spaam/logs/20210520_231344_drow_jrdb_EVAL/output/val/e000000/evaluation/all/result_r05.pkl"
6+
7+
for p in (p1, p2):
8+
with open(p, "rb") as f:
9+
res = pickle.load(f)
10+
11+
eer = res["eer"]
12+
arg = np.argmin(np.abs(res["precisions"] - eer))
13+
print(res["thresholds"][arg], " ", res["precisions"][arg], " ", res["recalls"][arg])

dr_spaam/bin/get_pseudo_label_videos.py renamed to dr_spaam/bin/plotting/get_pseudo_label_videos.py

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import torch
99

10-
from dr_spaam.dataset.get_dataloader import get_dataloader
10+
from dr_spaam.dataset import get_dataloader
1111
import dr_spaam.utils.jrdb_transforms as jt
1212
import dr_spaam.utils.utils as u
1313

@@ -18,8 +18,9 @@
1818
# _Y_LIM = (-10, 4)
1919
_Y_LIM = (-7, 7)
2020

21-
_PLOTTING_INTERVAL = 2
21+
_PLOTTING_INTERVAL = 20
2222
_MAX_COUNT = 1e9
23+
# _MAX_COUNT = 1e1
2324

2425
# _COLOR_CLOSE_HSV = (1.0, 0.59, 0.75)
2526
_COLOR_CLOSE_HSV = (0.0, 1.0, 1.0)
@@ -125,7 +126,7 @@ def _plot_frame_im(batch_dict, ib):
125126
plt.close(fig)
126127

127128

128-
def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg):
129+
def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, pred_cls_p, pred_reg_p):
129130
frame_id = f"{batch_dict['frame_id'][ib]:06d}"
130131
sequence = batch_dict["sequence"][ib]
131132

@@ -159,17 +160,44 @@ def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg):
159160
)
160161
ax.add_artist(c)
161162

162-
# inference result or pseudo-labels
163+
# plot detections
163164
if pred_cls is not None and pred_reg is not None:
164165
dets_xy, dets_cls, _ = u.nms_predicted_center(
165166
scan_r, scan_phi, pred_cls[ib].reshape(-1), pred_reg[ib]
166167
)
167-
dets_xy = dets_xy[dets_cls > 0.15]
168+
dets_xy = dets_xy[dets_cls >= 0.9438938] # at EER
168169
if len(dets_xy) > 0:
169170
for x, y in dets_xy:
170-
c = plt.Circle((x, y), radius=0.4, color=(0.0, 0.56, 0.56), fill=False)
171+
c = plt.Circle((x, y), radius=0.4, color=(0, 0.56, 0.56), fill=False)
171172
ax.add_artist(c)
172-
fig_file = os.path.join(_SAVE_DIR, f"figs/{sequence}/scan_det_{frame_id}.png")
173+
fig_file = os.path.join(
174+
_SAVE_DIR, f"figs/{sequence}/scan_det_{frame_id}.png"
175+
)
176+
177+
# plot in addition detections from a pre-trained
178+
if pred_cls_p is not None and pred_reg_p is not None:
179+
dets_xy, dets_cls, _ = u.nms_predicted_center(
180+
scan_r, scan_phi, pred_cls_p[ib].reshape(-1), pred_reg_p[ib]
181+
)
182+
dets_xy = dets_xy[dets_cls > 0.29919282] # at EER
183+
if len(dets_xy) > 0:
184+
for x, y in dets_xy:
185+
c = plt.Circle((x, y), radius=0.4, color="green", fill=False)
186+
ax.add_artist(c)
187+
# plot pre-trained detections only
188+
elif pred_cls_p is not None and pred_reg_p is not None:
189+
dets_xy, dets_cls, _ = u.nms_predicted_center(
190+
scan_r, scan_phi, pred_cls_p[ib].reshape(-1), pred_reg_p[ib]
191+
)
192+
dets_xy = dets_xy[dets_cls > 0.29919282] # at EER
193+
if len(dets_xy) > 0:
194+
for x, y in dets_xy:
195+
c = plt.Circle((x, y), radius=0.4, color="green", fill=False)
196+
ax.add_artist(c)
197+
fig_file = os.path.join(
198+
_SAVE_DIR, f"figs/{sequence}/scan_pretrain_{frame_id}.png"
199+
)
200+
# plot pseudo-labels only
173201
else:
174202
pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
175203
if len(pl_xy) > 0:
@@ -185,7 +213,7 @@ def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg):
185213

186214

187215
def plot_pseudo_label_for_all_frames():
188-
with open("./base_drow_jrdb_cfg.yaml", "r") as f:
216+
with open("./cfgs/base_drow_jrdb_cfg.yaml", "r") as f:
189217
cfg = yaml.safe_load(f)
190218
cfg["dataset"]["pseudo_label"] = True
191219
cfg["dataset"]["pl_correction_level"] = 0
@@ -203,7 +231,12 @@ def plot_pseudo_label_for_all_frames():
203231
model.eval()
204232

205233
logger = Logger(cfg["pipeline"]["Logger"])
206-
logger.load_ckpt("./ckpts/ckpt_phce_drow_e40.pth", model)
234+
logger.load_ckpt("./ckpts/ckpt_jrdb_pl_drow3_phce_e40.pth", model)
235+
236+
model_pretrain = get_model(cfg["model"])
237+
model_pretrain.cuda()
238+
model_pretrain.eval()
239+
logger.load_ckpt("./ckpts/ckpt_drow_drow3_e40.pth", model_pretrain)
207240

208241
# generate pseudo labels for all sample
209242
for count, batch_dict in enumerate(tqdm(test_loader)):
@@ -216,11 +249,19 @@ def plot_pseudo_label_for_all_frames():
216249
pred_cls = torch.sigmoid(pred_cls).data.cpu().numpy()
217250
pred_reg = pred_reg.data.cpu().numpy()
218251

252+
pred_cls_p, pred_reg_p = model_pretrain(net_input)
253+
pred_cls_p = torch.sigmoid(pred_cls_p).data.cpu().numpy()
254+
pred_reg_p = pred_reg_p.data.cpu().numpy()
255+
219256
if count % _PLOTTING_INTERVAL == 0:
220257
for ib in range(len(batch_dict["input"])):
221258
_plot_frame_im(batch_dict, ib)
222-
_plot_frame_pts(batch_dict, ib, None, None)
223-
_plot_frame_pts(batch_dict, ib, pred_cls, pred_reg)
259+
# _plot_frame_pts(batch_dict, ib, None, None, None, None)
260+
# _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, None, None)
261+
# # _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, pred_cls_p, pred_reg_p)
262+
263+
_plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, None, None)
264+
_plot_frame_pts(batch_dict, ib, None, None, pred_cls_p, pred_reg_p)
224265

225266

226267
def plot_color_bar():

0 commit comments

Comments
 (0)