Skip to content

Commit

Permalink
Improvements to groundtruth management
Browse files Browse the repository at this point in the history
  • Loading branch information
luigifreda committed Dec 26, 2024
1 parent 97a8f06 commit 1b564eb
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 29 deletions.
77 changes: 51 additions & 26 deletions io/ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,22 +439,34 @@ def associate(first_list, second_list, offset=0, max_difference=0.025*(10**9)):


class EurocGroundTruth(GroundTruth):
kReadTumConversion = False
def __init__(self, path, name, associations=None, start_frame_id=0, type = GroundTruthType.EUROC):
super().__init__(path, name, associations, start_frame_id, type)
self.scale = kScaleEuroc
self.filename = path + '/' + name + '/mav0/state_groundtruth_estimate0/data.tum' # NOTE: Use the script io/generate_euroc_groundtruths_as_tum.sh to generate these groundtruth files

if EurocGroundTruth.kReadTumConversion:
# NOTE: Use the script io/generate_euroc_groundtruths_as_tum.sh to generate these groundtruth files
self.filename = path + '/' + name + '/mav0/state_groundtruth_estimate0/data.tum'
else:
# Use the original Euroc groundtruth file
self.filename = path + '/' + name + '/mav0/state_groundtruth_estimate0/data.csv'

base_path = os.path.dirname(self.filename)
print('base_path: ', base_path)

if not os.path.isfile(self.filename):
error_message = f'ERROR: Groundtruth file not found: {self.filename}. Use the script io/generate_euroc_groundtruths_as_tum.sh to generate these groundtruth files!'
if EurocGroundTruth.kReadTumConversion:
error_message = f'ERROR: Groundtruth file not found: {self.filename}. Use the script io/generate_euroc_groundtruths_as_tum.sh to generate these groundtruth files!'
else:
error_message = f'ERROR: Groundtruth file not found: {self.filename}. Please, check how you deployed the files and if the code is consistent with this!'
Printer.red(error_message)
sys.exit(error_message)

with open(self.filename) as f:
self.data = f.readlines()
self.data = [line.strip().split() for line in self.data]
if EurocGroundTruth.kReadTumConversion:
with open(self.filename) as f:
self.data = f.readlines()
self.data = [line.strip().split() for line in self.data]
else:
self.data = self.read_gt_data_state(self.filename)

if len(self.data) > 0:
self.found = True
Expand All @@ -477,24 +489,6 @@ def __init__(self, path, name, associations=None, start_frame_id=0, type = Groun
with open(associations_file, 'r') as f:
data = json.load(f)
self.association_matches = {int(k): v for k, v in data.items()}

# def read_gt_data(self, csv_file):
# data = []
# # check csv_file exists
# if not os.path.isfile(csv_file):
# Printer.red(f'Groundtruth file not found: {csv_file}')
# return []
# with open(csv_file, 'r') as f:
# reader = csv.reader(f)
# header = next(reader) # Skip header row
# for row in reader:
# timestamp_ns = int(row[0])
# x = row[1]
# y = row[2]
# z = row[3]
# timestamp_s = (timestamp_ns / 1000000000)
# data.append((timestamp_s, (x,y,z)))
# return data

def read_image_data(self, csv_file):
timestamps_and_filenames = []
Expand All @@ -504,10 +498,42 @@ def read_image_data(self, csv_file):
for row in reader:
timestamp_ns = int(row[0])
filename = row[1]
timestamp_s = (float(timestamp_ns) / 1000000000)
timestamp_s = (float(timestamp_ns) * 1e-9)
timestamps_and_filenames.append((timestamp_s, filename))
return timestamps_and_filenames

def read_gt_data_state(self, csv_file):
    """Parse a EuRoC ``state_groundtruth_estimate0/data.csv`` file.

    Each non-comment row starts with:
    ``timestamp [ns], p_x, p_y, p_z, q_w, q_x, q_y, q_z, ...``
    (trailing velocity/bias columns, if present, are ignored).

    Args:
        csv_file: path to the EuRoC ground-truth state CSV file.

    Returns:
        List of tuples in TUM order:
        ``(timestamp [s], x, y, z, qx, qy, qz, qw)``.
    """
    data = []
    with open(csv_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip comment headers AND blank lines: a blank line would
            # otherwise crash the int() conversion below.
            if not line or line.startswith('#'):
                continue
            parts = line.split(',')
            timestamp_s = float(int(parts[0])) * 1e-9  # ns -> s
            x, y, z = float(parts[1]), float(parts[2]), float(parts[3])
            # The file stores the quaternion as qw, qx, qy, qz; reorder to
            # [qx, qy, qz, qw] as expected by the TUM format.
            qw, qx, qy, qz = (float(parts[4]), float(parts[5]),
                              float(parts[6]), float(parts[7]))
            data.append((timestamp_s, x, y, z, qx, qy, qz, qw))
    return data

def read_gt_data_pose(self, csv_file):
    """Parse a EuRoC pose ground-truth CSV file (timestamp + pose only).

    Each non-comment row is expected as:
    ``timestamp [ns], p_x, p_y, p_z, q_w, q_x, q_y, q_z``.

    Args:
        csv_file: path to the EuRoC pose CSV file.

    Returns:
        List of tuples in TUM order:
        ``(timestamp [s], x, y, z, qx, qy, qz, qw)``.
    """
    data = []
    with open(csv_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip comment headers AND blank lines: a blank line would
            # otherwise crash the int() conversion below.
            if not line or line.startswith('#'):
                continue
            parts = line.split(',')
            timestamp_s = float(int(parts[0])) * 1e-9  # ns -> s
            x, y, z = float(parts[1]), float(parts[2]), float(parts[3])
            # The file stores the quaternion as qw, qx, qy, qz; reorder to
            # [qx, qy, qz, qw] as expected by the TUM format.
            qw, qx, qy, qz = (float(parts[4]), float(parts[5]),
                              float(parts[6]), float(parts[7]))
            data.append((timestamp_s, x, y, z, qx, qy, qz, qw))
    return data


@staticmethod
def associate(first_list, second_list, offset=0, max_difference=0.025*(10**9)):
"""
Expand Down Expand Up @@ -571,7 +597,6 @@ def getTimestampPositionAndAbsoluteScale(self, frame_id):
# from https://www.researchgate.net/profile/Michael-Burri/publication/291954561_The_EuRoC_micro_aerial_vehicle_datasets/links/56af0c6008ae19a38516937c/The-EuRoC-micro-aerial-vehicle-datasets.pdf
return timestamp, x,y,z, abs_scale


# return timestamp, x,y,z, qx,qy,qz,qw, scale
def getTimestampPoseAndAbsoluteScale(self, frame_id):
frame_id+=self.start_frame_id
Expand Down
19 changes: 16 additions & 3 deletions utilities/utils_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,12 @@ def __init__(self, timestamps_associations=[], estimated_t_w_i=[], gt_t_w_i=[],
# - filter_t_w_i [Nx3]
# - gt_timestamps [Nx1]
# - gt_t_w_i [Nx3]
# - align_est_associations: if True, align the estimated trajectory with the gt
# - max_align_dt: maximum time difference between filter and gt timestamps in seconds
# - find_scale allows to compute the full Sim(3) transformation in case the scale is unknown
def align_trajs_with_svd(filter_timestamps, filter_t_w_i, gt_timestamps, gt_t_w_i, align_gt=True, compute_align_error=True, find_scale=False, align_est_associations=True, verbose=False):
def align_trajs_with_svd(filter_timestamps, filter_t_w_i, gt_timestamps, gt_t_w_i, align_gt=True, \
compute_align_error=True, find_scale=False, align_est_associations=True, max_align_dt=1e-1, \
verbose=False):
est_associations = []
gt_associations = []
timestamps_associations = []
Expand All @@ -482,6 +486,8 @@ def align_trajs_with_svd(filter_timestamps, filter_t_w_i, gt_timestamps, gt_t_w_
print(f'gt_t_w_i: {gt_t_w_i.shape}')
print(f'filter_timestamps: {filter_timestamps}')
print(f'gt_timestamps: {gt_timestamps}')

max_dt = 0

for i in range(len(filter_t_w_i)):
timestamp = filter_timestamps[i]
Expand All @@ -498,14 +504,18 @@ def align_trajs_with_svd(filter_timestamps, filter_t_w_i, gt_timestamps, gt_t_w_

dt = timestamp - gt_timestamps[j]
dt_gt = gt_timestamps[j + 1] - gt_timestamps[j]

abs_dt = abs(dt)

assert dt >= 0, f"dt {dt}"
assert dt_gt > 0, f"dt_gt {dt_gt}"

# Skip if the interval between gt is larger than 100ms
# if dt_gt > 1.1e8:
# continue
if abs_dt > max_align_dt:
continue

max_dt = max(max_dt, abs_dt)

ratio = dt / dt_gt

assert 0 <= ratio < 1
Expand All @@ -516,6 +526,9 @@ def align_trajs_with_svd(filter_timestamps, filter_t_w_i, gt_timestamps, gt_t_w_
timestamps_associations.append(timestamp)
gt_associations.append(gt)
est_associations.append(filter_t_w_i[i])

if verbose:
print(f'max align dt: {max_dt}')

num_samples = len(est_associations)
if verbose:
Expand Down

0 comments on commit 1b564eb

Please sign in to comment.