Skip to content

Commit b595285

Browse files
added sort to kalman
1 parent d574142 commit b595285

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

src/object_spatial_tools_ros/robot_kf_undirected_object_tracker.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,17 @@ class SingleKFUndirectedObjectTracker(object):
2323
R_diag - diagonale of R matrix [Rx, Ry]
2424
k_decay - coefficien of speed reduction
2525
color - (r, g, b) each normlized on 1
26+
score - initial score for object
2627
'''
27-
def __init__(self, x_start, t_start, Q_diag, R_diag, k_decay, color):
28+
def __init__(self, x_start, t_start, Q_diag, R_diag, k_decay, color, score):
2829

2930
self.x = np.array([x_start[0], x_start[1], 0.0, 0.0])
3031
self.last_t = t_start
3132
self.last_upd_t = t_start
3233
self.Q = np.diag(Q_diag)
3334
self.R = np.diag(R_diag)
3435
self.k_decay = k_decay
36+
self.score = score
3537

3638
self.color = color
3739

@@ -68,12 +70,15 @@ def predict(self, t):
6870

6971
self.predict_steps += 1
7072

73+
self.score *= self.k_decay
74+
7175

7276
'''
7377
z - measured x, y values
7478
t - time stamp for update, seconds
79+
score - score of new measurment
7580
'''
76-
def update(self, z, t):
81+
def update(self, z, t, score):
7782
self.last_upd_t = t
7883

7984
y = z - np.matmul(self.H, self.x)
@@ -89,6 +94,8 @@ def update(self, z, t):
8994
self.track.append(self.x.copy())
9095

9196
self.predict_steps = 0
97+
98+
self.score = score
9299

93100
class RobotKFUndirectedObjectTracker(object):
94101

@@ -125,6 +132,8 @@ def __init__(self):
125132
self.min_score = rospy.get_param('~min_score', 0)
126133
self.min_score_soft = rospy.get_param('~min_score_soft', self.min_score)
127134

135+
self.sort_by_score = rospy.get_param("~sort_by_score", False)
136+
128137
self.colors = []
129138
for c in plt.rcParams['axes.prop_cycle'].by_key()['color']:
130139
#print(type(mpl.colors.to_rgb(c)))
@@ -157,7 +166,7 @@ def process(self, event):
157166

158167
#rospy.logwarn(f"{name} {i} {kf.x} {kf.P}")
159168
for index in sorted(remove_index, reverse=True):
160-
del kfs[index]
169+
del kfs[index]
161170

162171
self.to_marker_array()
163172
self.to_tf()
@@ -166,7 +175,15 @@ def process(self, event):
166175

167176
def to_tf(self):
168177
now = rospy.Time.now()
169-
for name, kfs in self.objects_to_KFs.items():
178+
179+
if self.sort_by_score:
180+
mass = sorted(list(self.objects_to_KFs.items()), key=lambda x: x[1].score)
181+
else:
182+
mass = list(self.objects_to_KFs.items())
183+
184+
#for name, kfs in self.objects_to_KFs.items():
185+
for name, kfs in mass:
186+
170187
for i, kf in enumerate(kfs):
171188

172189
t = TransformStamped()
@@ -188,7 +205,13 @@ def to_tracked_object_array(self):
188205
msg_array = TrackedObjectArray()
189206
msg_array.header.stamp = rospy.Time.now()
190207
msg_array.header.frame_id = self.target_frame
191-
for name, kfs in self.objects_to_KFs.items():
208+
209+
if self.sort_by_score:
210+
mass = sorted(list(self.objects_to_KFs.items()), key=lambda x: x[1].score)
211+
else:
212+
mass = list(self.objects_to_KFs.items())
213+
for name, kfs in mass:
214+
#for name, kfs in self.objects_to_KFs.items():
192215
for i, kf in enumerate(kfs):
193216
msg = TrackedObject()
194217
msg.child_frame_id = self.tf_pub_prefix+name+f'_{i}'
@@ -220,7 +243,12 @@ def to_tracked_object_array(self):
220243
def to_marker_array(self):
221244
now = rospy.Time.now()
222245
marker_array = MarkerArray()
223-
for name, kfs in self.objects_to_KFs.items():
246+
if self.sort_by_score:
247+
mass = sorted(list(self.objects_to_KFs.items()), key=lambda x: x[1].score)
248+
else:
249+
mass = list(self.objects_to_KFs.items())
250+
for name, kfs in mass:
251+
#for name, kfs in self.objects_to_KFs.items():
224252
i = -1
225253
for i, kf in enumerate(kfs):
226254
# TEXT

0 commit comments

Comments
 (0)