Tools: update extract_data
       * Only catch OSError in the process_dir method

       * Fix a couple of pylint and syntax issues
freeHackOfJeff authored and ycool committed Apr 4, 2019
1 parent 46c9d32 commit d4189ba
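For context on the first bullet, the narrowed error handling in process_dir amounts to the pattern sketched below. This is a minimal, condensed sketch rather than the file itself (the real function also logs the create action, as the full diff below shows); it relies on the fact that os.makedirs and shutil.rmtree report their failures as OSError, so the previous broader (OSError, IOError) clause is not needed for directory operations.

import os
import shutil


def process_dir(path, operation):
    """Create or remove a directory, returning True on success."""
    try:
        if operation == 'create':
            os.makedirs(path)
        elif operation == 'remove':
            shutil.rmtree(path)
        else:
            print('Error! Unsupported operation %s for directory.' % operation)
            return False
    except OSError as e:
        # os.makedirs and shutil.rmtree signal failures as OSError, so the
        # broader (OSError, IOError) clause is unnecessary here.
        print('Failed to %s directory: %s. Error: %s' % (operation, path, e))
        return False
    return True

Calling process_dir(path, 'create') twice for the same path takes the OSError branch on the second call, since os.makedirs fails when the directory already exists.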
Showing 1 changed file with 48 additions and 57 deletions.
modules/tools/vehicle_calibration/sensor_calibration/extract_data.py: 105 changes (48 additions, 57 deletions)
@@ -28,11 +28,11 @@
import re
import sys
import tarfile
import six

import cv2
import numpy as np
import cv2
import pypcd
import six

from cyber_py.record import RecordReader
from cyber.proto import record_pb2
@@ -45,10 +45,9 @@
CYBER_RECORD_HEADER_LENGTH = 2048
IMAGE_OBJ = sensor_image_pb2.Image()
POINTCLOUD_OBJ = pointcloud_pb2.PointCloud()

def process_dir(path, operation):
"""
Create or remove directory
"""
"""Create or remove directory."""
try:
if operation == 'create':
print("create folder %s" % path)
@@ -58,18 +58,16 @@ def process_dir(path, operation):
else:
print('Error! Unsupported operation %s for directory.' % operation)
return False
except (OSError, IOError) as e:
except OSError as e:
print('Failed to %s directory: %s. Error: %s' %
(operation, path, six.text_type(e)))
return False

return True

def extract_camera_data(dest_dir, msg):
"""
Extract camera file from message according to ratio
"""
#TODO: change saving logic
"""Extract camera file from message according to ratio."""
#TODO: change saving logic
cur_time_second = msg.timestamp
image = IMAGE_OBJ
image.ParseFromString(msg.message)
@@ -81,27 +78,27 @@ def extract_camera_data(dest_dir, msg):
# data = 8, actual matrix data in bytes, size is (step * rows).
# type = CV_8UC1 if image step is equal to width as gray, CV_8UC3
# if step * 3 is equal to width.
if image.encoding == "rgb8" or image.encoding == "bgr8":
if image.encoding == 'rgb8' or image.encoding == 'bgr8':
if image.step != image.width * 3:
print("Image.step %d does not equal Image.width %d * 3 for color image" %
(image.step, image.width))
print('Image.step %d does not equal Image.width %d * 3 for color image.' %
(image.step, image.width))
return False
elif image.encoding == "gray" or image.encoding == "y":
elif image.encoding == 'gray' or image.encoding == 'y':
if image.step != image.width:
print("Image.step %d does not equal Image.width %d or gray image" %
(image.step, image.width))
print('Image.step %d does not equal Image.width %d for gray image.' %
(image.step, image.width))
return False
else:
print("Unsupported image encoding type %s" %encoding)
print('Unsupported image encoding type %s.' % image.encoding)
return False

channel_num = image.step / image.width
image_mat = np.fromstring(image.data, dtype=np.uint8).reshape(
(image.height, image.width, channel_num))
(image.height, image.width, channel_num))

image_file = os.path.join(dest_dir, '{}.png'.format(cur_time_second))
#python cv2 save image in BGR oder
if image.encoding == "rgb8":
# Save image in BGR order
if image.encoding == 'rgb8':
cv2.imwrite(image_file, cv2.cvtColor(image_mat, cv2.COLOR_RGB2BGR))
else:
cv2.imwrite(image_file, image_mat)
@@ -138,7 +135,7 @@ def make_xyzit_point_cloud(xyz_i_t):

typenames = []
for t, s in zip(md['type'], md['size']):
np_type = pypcd.pcd_type_to_numpy_type[(t,s)]
np_type = pypcd.pcd_type_to_numpy_type[(t, s)]
typenames.append(np_type)

np_dtype = np.dtype(zip(md['fields'], typenames))
@@ -148,8 +145,7 @@ def make_xyzit_point_cloud(xyz_i_t):

def extract_pcd_data(dest_dir, msg):
"""
Extract PCD file
Transform protobuf PointXYZIT to standard PCL bin_compressed_file (*.pcd)
Transform protobuf PointXYZIT to standard PCL bin_compressed_file(*.pcd).
"""
cur_time_second = msg.timestamp
pointcloud = POINTCLOUD_OBJ
@@ -161,20 +157,16 @@ def extract_pcd_data(dest_dir, msg):
return True

def get_sensor_channel_list(record_file):
"""
Get the channel list of sensors for calibration
"""
"""Get the channel list of sensors for calibration."""
record_reader = RecordReader(record_file)
return set(channel_name for channel_name in record_reader.get_channellist()
if 'sensor' in channel_name)
if 'sensor' in channel_name)
# return [channel_name for channel_name in record_reader.get_channellist()
# if 'sensor' in channel_name]


def validate_record(record_file):
"""
Validate the record file
"""
"""Validate the record file."""
# Check the validity of a cyber record file according to header info.
record_reader = RecordReader(record_file)
header_msg = record_reader.get_headerstring()
@@ -213,9 +205,7 @@ def validate_record(record_file):


def extract_channel_data(output_path, msg):
"""
Process channel messages.
"""
"""Process channel messages."""
# timestamp = msg.timestamp / float(1e9)
# if abs(begin_time - timestamp) > 2:
channel_desc = msg.data_type
@@ -233,8 +223,8 @@ def validate_channel_list(channels, dictionary):
ret = True
for channel in channels:
if channel not in dictionary:
print("ERROR: channel %s does not exist in record\
sensor channels" % channel)
print('ERROR: channel %s does not exist in record \
sensor channels' % channel)
ret = False

return ret
@@ -243,29 +233,29 @@ def in_range(v, s, e):
return True if v >= s and v <= e else False

def extract_data(record_file, output_path, channel_list,
start_timestamp, end_timestamp, extraction_ratio):
start_timestamp, end_timestamp, extraction_ratio):
"""
Extract the desired channel messages if channel_list is specified.
Otherwise extract all sensor calibration messages according to
extraction ratio, 10% by default.
"""
#validate extration_ratio, and set it as an integer
# Validate extraction_ratio, and set it as an integer.
if extraction_ratio < 1.0:
raise ValueError("Extraction rate must be a number no less than 1")
raise ValueError("Extraction rate must be a number greater than 1.")
extraction_ratio = np.floor(extraction_ratio)

sensor_channels = get_sensor_channel_list(record_file)
if len(channel_list) > 0 and validate_channel_list(
channel_list, sensor_channels) is False:
print("input channel list not valid")
if len(channel_list) > 0 and validate_channel_list(channel_list,
sensor_channels) is False:
print('Input channel list is invalid.')
return False

#If channel_list is empty(no input arguments), extract all the sensor channels
# Extract all the sensor channels if channel_list is empty (no input arguments).
print(sensor_channels)
if len(channel_list) == 0:
channel_list = sensor_channels

#Declare logging variables
# Declare logging variables
process_channel_success_num = len(channel_list)
process_channel_failure_num = 0
process_msg_failure_num = 0
@@ -276,40 +266,40 @@ def extract_data(record_file, output_path, channel_list,
for channel in channel_list:
channel_success_dict[channel] = True
channel_occur_time[channel] = -1
topic_name = channel.replace("/", "_")
channel_output_path[channel] = os.path.join(output_path,topic_name)
process_dir(channel_output_path[channel], operation="create")
topic_name = channel.replace('/', '_')
channel_output_path[channel] = os.path.join(output_path, topic_name)
process_dir(channel_output_path[channel], operation='create')

record_reader = RecordReader(record_file)
for msg in record_reader.read_messages():
if msg.topic in channel_list:
# only care about messages in certain time intervals
msg_timestamp_sec = msg.timestamp/1e9
# Only care about messages in certain time intervals
msg_timestamp_sec = msg.timestamp / 1e9
if not in_range(msg_timestamp_sec, start_timestamp, end_timestamp):
continue

channel_occur_time[msg.topic] += 1
# extract the topic according to extraction_ratio
# Extract the topic according to extraction_ratio
if channel_occur_time[msg.topic] % extraction_ratio != 0:
continue

ret = extract_channel_data(channel_output_path[msg.topic], msg)
# calculate parsing statistics
# Calculate parsing statistics
if ret is False:
process_msg_failure_num += 1
if channel_success_dict[msg.topic] is True:
channel_success_dict[msg.topic] = False;
channel_success_dict[msg.topic] = False
process_channel_failure_num += 1
process_channel_success_num -= 1
print('Failed to extract data from channel: %s' % msg.topic)

#Logging statics about channel extraction
# Log statistics about channel extraction
print('Extracted sensor channel number [%d] in record file: %s' %
(len(channel_list), record_file))
print('Successfully processed [%d] channels, and [%d] failed.' %
(process_channel_success_num, process_channel_failure_num))
if process_msg_failure_num > 0:
print('Channel Extraction Failure number is: %d' % process_msg_failure_num)
print('Channel extraction failure number is: %d' % process_msg_failure_num)

return True

@@ -348,17 +338,18 @@ def main():
parser.add_argument("-c", "--channel_name", dest='channel_list', action="append",
default=[], help="list of channel_name that needs parsing.")
parser.add_argument("-s", "--start_timestamp", action="store", type=float,
default=np.finfo(np.float32).min, help="Specify the begining time to extract data information.")
default=np.finfo(np.float32).min,
help="Specify the begining time to extract data information.")
parser.add_argument("-e", "--end_timestamp", action="store", type=float,
default=np.finfo(np.float32).max, help="Specify the ending timestamp to extract data information.")
default=np.finfo(np.float32).max,
help="Specify the ending timestamp to extract data information.")
parser.add_argument("-r", "--extraction_ratio", action="store", type=int,
default=10, help="Extraction ratio: keep one message out of every N.")

args = parser.parse_args()

print("parsing the following channels:%s" % args.channel_list)
print('parsing the following channels: %s' % args.channel_list)

# TODO(Liujie): Add logger info to trace the extraction process
ret = validate_record(args.record_path)
if ret is False:
print('Failed to validate record file: %s' % args.record_path)
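As a side note on the sampling behavior described in extract_data's docstring above: for each channel, a message is kept whenever its running occurrence index is divisible by extraction_ratio, so the default ratio of 10 keeps roughly one message in ten. A minimal standalone illustration of that pattern (keep_message and the sample indices below are illustrative, not part of the file):

def keep_message(occurrence_index, extraction_ratio):
    """Keep one occurrence out of every extraction_ratio occurrences."""
    return occurrence_index % extraction_ratio == 0

# With extraction_ratio = 10, occurrence indices 0, 10, 20, ... are kept.
kept = [i for i in range(25) if keep_message(i, 10)]
print(kept)  # prints [0, 10, 20]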
