Skip to content

Commit

Permalink
Tool: update data extract:
Browse files Browse the repository at this point in the history
(1) encapsulate sensor message parsers as scalable classes.
(2) generate timestamps file for each sensor message
(3) re-factor the whole tool
  • Loading branch information
gchen-apollo authored and xiaoxq committed Apr 9, 2019
1 parent de14d6d commit ec3b527
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 168 deletions.
13 changes: 12 additions & 1 deletion modules/tools/sensor_calibration/data_file_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
This is a bunch of classes to manage cyber record channel FileIO.
"""
import os
import re
import sys
import struct

Expand Down Expand Up @@ -53,6 +52,18 @@ def file_object(self):
def save_to_file(self, data):
raise NotImplementedError

class TimestampFileObject(FileObject):
    """Writes the per-message timestamp index for one Apollo sensor channel.

    Each line of the output file is '<sequence_number> <timestamp>' so the
    Nth extracted sensor file can be matched back to its capture time.
    """
    def __init__(self, file_path, operation='write', file_type='txt'):
        super(TimestampFileObject, self).__init__(file_path,
                                                  operation, file_type)

    def save_to_file(self, data):
        """Write one '%06d %.6f' line per timestamp in data (a list)."""
        if not isinstance(data, list):
            raise ValueError("timestamps must be in a list")

        lines = ("%06d %.6f\n" % (seq, stamp)
                 for seq, stamp in enumerate(data))
        self._file_object.writelines(lines)

class OdometryFileObject(FileObject):
"""class to handle gnss/odometry topic"""
Expand Down
197 changes: 30 additions & 167 deletions modules/tools/sensor_calibration/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,14 @@
import sys
import shutil
import six

import numpy as np
import cv2
import pypcd

from cyber_py.record import RecordReader
from cyber.proto import record_pb2

from modules.drivers.proto import sensor_image_pb2
from modules.drivers.proto import pointcloud_pb2
from modules.localization.proto import gps_pb2
from data_file_object import *
from sensor_msg_extractor import *
#from scripts.record_bag import SMALL_TOPICS


SMALL_TOPICS = [
'/apollo/canbus/chassis',
'/apollo/canbus/chassis_detail',
Expand Down Expand Up @@ -86,10 +79,6 @@
CYBER_PATH = os.environ['CYBER_PATH']
CYBER_RECORD_HEADER_LENGTH = 2048

IMAGE_OBJ = sensor_image_pb2.Image()
POINTCLOUD_OBJ = pointcloud_pb2.PointCloud()
GPS_OBJ = gps_pb2.Gps()

def process_dir(path, operation):
"""Create or remove directory."""
try:
Expand All @@ -108,151 +97,12 @@ def process_dir(path, operation):

return True

def extract_camera_data(dest_dir, msg):
    """Extract one camera frame from a cyber record message.

    Parses the apollo.drivers.Image proto carried in msg.message and writes
    it to dest_dir as '<timestamp>.png' in BGR channel order.

    Returns:
        True on success; False when the image metadata is inconsistent or
        the encoding is unsupported.
    """
    # TODO(gchen-Apollo): change saving logic
    cur_time_second = msg.timestamp
    image = IMAGE_OBJ
    image.ParseFromString(msg.message)
    # Save image according to cyber format, defined in sensor camera proto.
    # height = 4, image height, that is, number of rows.
    # width = 5, image width, that is, number of columns.
    # encoding = 6, as string, type is 'rgb8', 'bgr8' or 'gray'.
    # step = 7, full row length in bytes.
    # data = 8, actual matrix data in bytes, size is (step * rows).
    # type = CV_8UC1 if image step is equal to width as gray, CV_8UC3
    # if step * 3 is equal to width.
    if image.encoding == 'rgb8' or image.encoding == 'bgr8':
        if image.step != image.width * 3:
            print('Image.step %d does not equal to Image.width %d * 3 for color image.'
                  % (image.step, image.width))
            return False
    elif image.encoding == 'gray' or image.encoding == 'y':
        if image.step != image.width:
            print('Image.step %d does not equal to Image.width %d for gray image.'
                  % (image.step, image.width))
            return False
    else:
        print('Unsupported image encoding type: %s.' % image.encoding)
        return False

    # Floor division: a plain '/' yields a float under Python 3 and breaks
    # the reshape below.
    channel_num = image.step // image.width
    # np.frombuffer replaces the deprecated np.fromstring (no data copy).
    image_mat = np.frombuffer(image.data, dtype=np.uint8).reshape(
        (image.height, image.width, channel_num))

    image_file = os.path.join(dest_dir, '{}.png'.format(cur_time_second))
    # Save image in BGR order, as expected by cv2.imwrite.
    if image.encoding == 'rgb8':
        cv2.imwrite(image_file, cv2.cvtColor(image_mat, cv2.COLOR_RGB2BGR))
    else:
        cv2.imwrite(image_file, image_mat)
    # Explicit success return, consistent with extract_pcd_data.
    return True

def convert_xyzit_pb_to_array(xyz_i_t, data_type):
    """Pack repeated PointXYZIT proto entries into a numpy structured array.

    Args:
        xyz_i_t: iterable of points exposing x, y, z, intensity, timestamp.
        data_type: numpy structured dtype with those five fields.
    """
    rows = [(point.x, point.y, point.z, point.intensity, point.timestamp)
            for point in xyz_i_t]
    return np.array(rows, dtype=data_type)

def make_xyzit_point_cloud(xyz_i_t):
    """ Make a pointcloud object from PointXYZIT message, as in Pointcloud.proto.
    message PointXYZIT {
      optional float x = 1 [default = nan];
      optional float y = 2 [default = nan];
      optional float z = 3 [default = nan];
      optional uint32 intensity = 4 [default = 0];
      optional uint64 timestamp = 5 [default = 0];
    }
    Returns a pypcd.PointCloud in binary_compressed layout.
    """

    # PCD header metadata: one row of height, one point per proto entry.
    md = {'version': .7,
          'fields': ['x', 'y', 'z', 'intensity', 'timestamp'],
          'count': [1, 1, 1, 1, 1],
          'width': len(xyz_i_t),
          'height': 1,
          'viewpoint': [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
          'points': len(xyz_i_t),
          'type': ['F', 'F', 'F', 'U', 'F'],
          'size': [4, 4, 4, 4, 8],
          'data': 'binary_compressed'}

    # Map each (type, size) pair from the PCD metadata to its numpy type.
    typenames = []
    for t, s in zip(md['type'], md['size']):
        np_type = pypcd.pcd_type_to_numpy_type[(t, s)]
        typenames.append(np_type)

    # list() is required: np.dtype needs a sequence, and Python 3's zip
    # returns a lazy iterator.
    np_dtype = np.dtype(list(zip(md['fields'], typenames)))
    pc_data = convert_xyzit_pb_to_array(xyz_i_t, data_type=np_dtype)
    pc = pypcd.PointCloud(md, pc_data)
    return pc

def extract_pcd_data(dest_dir, msg):
    """
    Transform protobuf PointXYZIT to standard PCL bin_compressed_file(*.pcd).
    """
    timestamp = msg.timestamp
    pointcloud = POINTCLOUD_OBJ
    pointcloud.ParseFromString(msg.message)

    cloud = make_xyzit_point_cloud(pointcloud.point)
    pcd_file = os.path.join(dest_dir, '{}.pcd'.format(timestamp))
    pypcd.save_point_cloud_bin_compressed(cloud, pcd_file)
    # TODO(gchen-Apollo): add sanity check
    return True

def extract_gps_data(dest_dir, msg, out_msgs):
    """Append one gnss/odometry pose parsed from msg to out_msgs.

    Args:
        dest_dir: output directory (unused here; poses are saved in one
            batch later, but the signature matches the other extractors).
        msg: record message whose payload is an apollo.localization.Gps proto.
        out_msgs: list accumulating (ts, point_type, qw, qx, qy, qz, x, y, z)
            tuples, one per message.

    Returns:
        The same out_msgs list with one tuple appended.

    Raises:
        ValueError: if out_msgs is not a list.
    """
    if not isinstance(out_msgs, list):
        raise ValueError("Gps/Odometry msg should be saved as a list, not %s"
                         % type(out_msgs))

    gps = GPS_OBJ
    gps.ParseFromString(msg.message)

    # all double, except point_type is int32
    ts = gps.header.timestamp_sec
    point_type = 0
    qw = gps.localization.orientation.qw
    qx = gps.localization.orientation.qx
    qy = gps.localization.orientation.qy
    qz = gps.localization.orientation.qz
    x = gps.localization.position.x
    y = gps.localization.position.y
    z = gps.localization.position.z
    # save 9 values as a tuple, for easier struct packing during storage
    out_msgs.append((ts, point_type, qw, qx, qy, qz, x, y, z))

    # NOTE(review): removed the leftover `print(gps)` debug statement that
    # dumped the whole proto to stdout for every message.

    # TODO(gchen-Apollo): build class, to encapsulate inner data structure
    return out_msgs

def get_sensor_channel_list(record_file):
    """Return the set of channels in record_file whose name contains 'sensor'."""
    reader = RecordReader(record_file)
    sensor_channels = set()
    for name in reader.get_channellist():
        if 'sensor' in name:
            sensor_channels.add(name)
    return sensor_channels


def extract_channel_data(output_path, msg, channel_msgs=None):
    """Route one record message to the extractor matching its proto type.

    Gps messages are accumulated into channel_msgs[msg.topic]; image and
    point-cloud messages are written straight to output_path.

    Returns:
        (True, channel_msgs) regardless of which branch ran.
    """
    proto_type = msg.data_type
    if proto_type == 'apollo.drivers.Image':
        extract_camera_data(output_path, msg)
    elif proto_type == 'apollo.drivers.PointCloud':
        extract_pcd_data(output_path, msg)
    elif proto_type == 'apollo.localization.Gps':
        channel_msgs[msg.topic] = extract_gps_data(
            output_path, msg, channel_msgs[msg.topic])
    else:
        # TODO(LiuJie/gchen-Apollo): Handle binary data extraction.
        print('Not implemented!')

    return True, channel_msgs

def validate_channel_list(channels, dictionary):
ret = True
for channel in channels:
Expand All @@ -266,6 +116,19 @@ def validate_channel_list(channels, dictionary):
def in_range(v, s, e):
    """Return True if v lies within the inclusive range [s, e]."""
    # Chained comparison replaces the redundant `True if ... else False`.
    return s <= v <= e

def build_parser(channel, output_path):
    """Create the sensor message parser matching a channel name's suffix.

    Image and point-cloud channels save one file per message
    (instance_saving=True); gnss/odometry messages are accumulated and
    saved as one combined file (instance_saving=False).

    Args:
        channel: cyber channel name, e.g. '/apollo/sensor/.../image'.
        output_path: directory the parser writes its files into.

    Raises:
        ValueError: when no parser supports the channel.
    """
    if channel.endswith("/image"):
        return ImageParser(output_path=output_path, instance_saving=True)
    if channel.endswith("/PointCloud2"):
        return PointCloudParser(output_path=output_path, instance_saving=True)
    if channel.endswith("/gnss/odometry"):
        return GpsParser(output_path=output_path, instance_saving=False)
    # Grammar fix of the original "Not Support this channel type" message.
    raise ValueError("Unsupported channel type: %s" % channel)

def extract_data(record_files, output_path, channels,
start_timestamp, end_timestamp, extraction_rates):
"""
Expand Down Expand Up @@ -294,16 +157,19 @@ def extract_data(record_files, output_path, channels,
channel_success = {}
channel_occur_time = {}
channel_output_path = {}
channel_messages = {}
#channel_messages = {}
channel_parsers = {}
for channel in channels:
channel_success[channel] = True
channel_occur_time[channel] = -1
topic_name = channel.replace('/', '_')
channel_output_path[channel] = os.path.join(output_path, topic_name)
process_dir(channel_output_path[channel], operation='create')
channel_parsers[channel] =\
build_parser(channel, channel_output_path[channel])

if channel in SMALL_TOPICS:
channel_messages[channel] = list()
# if channel in SMALL_TOPICS:
# channel_messages[channel] = list()

for record_file in record_files:
record_reader = RecordReader(record_file)
Expand All @@ -319,8 +185,7 @@ def extract_data(record_files, output_path, channels,
if channel_occur_time[msg.topic] % extraction_rates[msg.topic] != 0:
continue

ret, _ = extract_channel_data(
channel_output_path[msg.topic], msg, channel_messages)
ret = channel_parsers[msg.topic].parse_sensor_message(msg)

# Calculate parsing statistics
if not ret:
Expand All @@ -334,8 +199,8 @@ def extract_data(record_files, output_path, channels,

# traverse the dict, if any channel topic stored as a list
# then save the list as a summary file, mostly binary file
for channel, messages in channel_messages.items():
save_msg_list_to_file(channel_output_path[channel], channel, messages)
for channel, parser in channel_parsers.items():
save_combined_messages_info(parser, channel)

# Logging statics about channel extraction
print('Extracted sensor channel number [%d] from record files: %s'
Expand All @@ -347,15 +212,13 @@ def extract_data(record_files, output_path, channels,

return True

def save_msg_list_to_file(out_path, channel, messages):
    """Persist accumulated small-topic messages as one binary file.

    Only gnss/odometry channels are supported; their tuples are packed
    into '<out_path>/messages.bin' via OdometryFileObject.

    Raises:
        ValueError: for any channel without a saving implementation.
    """
    if 'odometry' in channel:
        # generate file objects for small topics I/O
        file_path = os.path.join(out_path, 'messages.bin')
        odometry_file_obj = OdometryFileObject(file_path)
        # NOTE(review): removed the stray `print(len(messages))` debug line.
        odometry_file_obj.save_to_file(messages)
    else:
        raise ValueError("saving function for {} not implemented".format(channel))
def save_combined_messages_info(parser, channel):
    """Flush a channel parser's buffered messages and timestamps to disk.

    Args:
        parser: sensor message parser owning the buffered data.
        channel: channel name, used only for error reporting.

    Raises:
        ValueError: when either save step reports failure.
    """
    if not parser.save_messages_to_file():
        raise ValueError("cannot save combined messages into single file for : %s " % channel)

    # Fixed "tiemstamp" typo in the original error message.
    if not parser.save_timestamps_to_file():
        raise ValueError("cannot save timestamp info for %s " % channel)

def generate_compressed_file(input_path, input_name,
output_path, compressed_file='sensor_data'):
Expand Down
Loading

0 comments on commit ec3b527

Please sign in to comment.