Skip to content

Commit

Permalink
Extend lidarseg plot to plot histograms for both number of points an…
Browse files Browse the repository at this point in the history
…d number of scan-wise instances per class (nutonomy#658)

* Update method to render histograms for lidarseg and panoptic

* Extend lidarseg plot to plot histograms for both number of points and number of scan-wise instances per class

* Remove unused comment about MathTextSciFormatter
  • Loading branch information
whyekit-motional authored Sep 14, 2021
1 parent 84ee1f3 commit cf2f128
Showing 1 changed file with 124 additions and 84 deletions.
208 changes: 124 additions & 84 deletions python-sdk/nuscenes/lidarseg/class_histogram.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import time
from typing import List, Tuple
from typing import Dict

import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, ScalarFormatter
import matplotlib.ticker as mticker
import matplotlib.transforms as mtrans
import numpy as np

from nuscenes import NuScenes
from nuscenes.panoptic.panoptic_utils import get_frame_panoptic_instances, get_panoptic_instances_stats
from nuscenes.utils.color_map import get_colormap
from nuscenes.utils.data_io import load_bin_file


def truncate_class_name(class_name: str) -> str:
Expand Down Expand Up @@ -56,40 +57,57 @@ def truncate_class_name(class_name: str) -> str:
return string_mapper[class_name]


def render_lidarseg_histogram(nusc: NuScenes,
sort_by: str = 'count_desc',
chart_title: str = None,
x_label: str = None,
y_label: str = "Lidar points (logarithmic)",
y_log_scale: bool = True,
verbose: bool = True,
font_size: int = 20,
save_as_img_name: str = None) -> None:
def render_histogram(nusc: NuScenes,
sort_by: str = 'count_desc',
verbose: bool = True,
font_size: int = 20,
save_as_img_name: str = None) -> None:
"""
Render a histogram for the given nuScenes split.
Render two histograms for the given nuScenes split. The top histogram depicts the number of scan-wise instances
for each class, while the bottom histogram depicts the number of points for each class.
:param nusc: A nuScenes object.
:param sort_by: How to sort the classes:
:param sort_by: How to sort the classes to display in the plot (note that the x-axis, where the class names will be
displayed on, is shared by the two histograms):
- count_desc: Sort the classes by the number of points belonging to each class, in descending order.
- count_asc: Sort the classes by the number of points belonging to each class, in ascending order.
- name: Sort the classes by alphabetical order.
- index: Sort the classes by their indices.
:param chart_title: Title to display on the histogram.
:param x_label: Title to display on the x-axis of the histogram.
:param y_label: Title to display on the y-axis of the histogram.
:param y_log_scale: Whether to use log scale on the y-axis.
:param verbose: Whether to display plot in a window after rendering.
:param font_size: Size of the font to use for the histogram.
:param save_as_img_name: Path (including image name and extension) to save the histogram as.
:param verbose: Whether to display the plot in a window after rendering.
:param font_size: Size of the font to use for the plot.
:param save_as_img_name: Path (including image name and extension) to save the plot as.
"""

print('Calculating stats for nuScenes-lidarseg...')
start_time = time.time()

# Get the statistics for the given nuScenes split.
class_names, counts = get_lidarseg_stats(nusc, sort_by=sort_by)

print('Calculated stats for {} point clouds in {:.1f} seconds.\n====='.format(
len(nusc.lidarseg), time.time() - start_time))
lidarseg_num_points_per_class = get_lidarseg_num_points_per_class(nusc, sort_by=sort_by)
panoptic_num_instances_per_class = get_panoptic_num_instances_per_class(nusc, sort_by=sort_by)

# Align the two dictionaries by adding entries for the stuff classes to panoptic_num_instances_per_class; the
# instance count for each of these stuff classes is 0.
panoptic_num_instances_per_class_tmp = dict()
for class_name in lidarseg_num_points_per_class.keys():
num_instances_for_class = panoptic_num_instances_per_class.get(class_name, 0)
panoptic_num_instances_per_class_tmp[class_name] = num_instances_for_class
panoptic_num_instances_per_class = panoptic_num_instances_per_class_tmp

# Define some settings for each histogram.
histograms_config = dict({
'panoptic': {
'y_values': list(panoptic_num_instances_per_class.values()),
'y_label': 'No. of instances',
'y_scale': 'log'
},
'lidarseg': {
'y_values': list(lidarseg_num_points_per_class.values()),
'y_label': 'No. of lidar points',
'y_scale': 'log'
}
})

# Ensure the same set of class names are used for all histograms.
assert lidarseg_num_points_per_class.keys() == panoptic_num_instances_per_class.keys(), \
'Error: There are {} classes for lidarseg, but {} classes for panoptic.'.format(
len(lidarseg_num_points_per_class.keys()), len(panoptic_num_instances_per_class.keys()))
class_names = list(lidarseg_num_points_per_class.keys())

# Create an array with the colors to use.
cmap = get_colormap()
Expand All @@ -99,57 +117,49 @@ def render_lidarseg_histogram(nusc: NuScenes,
class_names = [truncate_class_name(cn) for cn in class_names]

# Start a plot.
fig, ax = plt.subplots(figsize=(16, 9))
plt.margins(x=0.005) # Add some padding to the left and right limits of the x-axis for aesthetics.
ax.set_axisbelow(True) # Ensure that axis ticks and gridlines will be below all other ploy elements.
ax.yaxis.grid(color='white', linewidth=2) # Show horizontal gridlines.
ax.set_facecolor('#eaeaf2') # Set background of plot.
ax.spines['top'].set_visible(False) # Remove top border of plot.
ax.spines['right'].set_visible(False) # Remove right border of plot.
ax.spines['bottom'].set_visible(False) # Remove bottom border of plot.
ax.spines['left'].set_visible(False) # Remove left border of plot.

# Plot the histogram.
ax.bar(class_names, counts, color=colors)
assert len(class_names) == len(ax.get_xticks()), \
'There are {} classes, but {} are shown on the x-axis'.format(len(class_names), len(ax.get_xticks()))

# Format the x-axis.
ax.set_xlabel(x_label, fontsize=font_size)
ax.set_xticklabels(class_names, rotation=45, horizontalalignment='right',
fontweight='light', fontsize=font_size)

# Shift the class names on the x-axis slightly to the right for aesthetics.
trans = mtrans.Affine2D().translate(10, 0)
for t in ax.get_xticklabels():
t.set_transform(t.get_transform() + trans)

# Format the y-axis.
ax.set_ylabel(y_label, fontsize=font_size)
ax.set_yticklabels(counts, size=font_size)

# Transform the y-axis to log scale.
if y_log_scale:
ax.set_yscale("log")

# Display the y-axis using nice scientific notation.
formatter = ScalarFormatter(useOffset=False, useMathText=True)
ax.yaxis.set_major_formatter(
FuncFormatter(lambda x, pos: "${}$".format(formatter._formatSciNotation('%1.10e' % x))))

if chart_title:
ax.set_title(chart_title, fontsize=font_size)
fig, axes = plt.subplots(nrows=2, sharex=True, figsize=(16, 9))
for ax in axes:
ax.margins(x=0.005) # Add some padding to the left and right limits of the x-axis for aesthetics.
ax.set_axisbelow(True) # Ensure that axis ticks and gridlines will be below all other ploy elements.
ax.yaxis.grid(color='white', linewidth=2) # Show horizontal gridlines.
ax.set_facecolor('#eaeaf2') # Set background of plot.
ax.spines['top'].set_visible(False) # Remove top border of plot.
ax.spines['right'].set_visible(False) # Remove right border of plot.
ax.spines['bottom'].set_visible(False) # Remove bottom border of plot.
ax.spines['left'].set_visible(False) # Remove left border of plot.

# Plot the histograms.
for i, (histogram, config) in enumerate(histograms_config.items()):
axes[i].bar(class_names, config['y_values'], color=colors)
assert len(class_names) == len(axes[i].get_xticks()), \
'There are {} classes, but {} are shown on the x-axis'.format(len(class_names), len(axes[i].get_xticks()))

# Format the x-axis.
axes[i].set_xticklabels(class_names, rotation=45, horizontalalignment='right',
fontweight='light', fontsize=font_size)

# Shift the class names on the x-axis slightly to the right for aesthetics.
trans = mtrans.Affine2D().translate(10, 0)
for t in axes[i].get_xticklabels():
t.set_transform(t.get_transform() + trans)

# Format the y-axis.
axes[i].set_ylabel(config['y_label'], fontsize=font_size)
axes[i].set_yticklabels(config['y_values'], size=font_size)
axes[i].set_yscale(config['y_scale'])

if config['y_scale'] == 'linear':
axes[i].yaxis.set_major_formatter(mticker.FormatStrFormatter('%.1e'))

if save_as_img_name:
fig = ax.get_figure()
plt.tight_layout()
fig.savefig(save_as_img_name)

if verbose:
plt.show()


def get_lidarseg_stats(nusc: NuScenes, sort_by: str = 'count_desc') -> Tuple[List[str], List[int]]:
def get_lidarseg_num_points_per_class(nusc: NuScenes, sort_by: str = 'count_desc') -> Dict[str, int]:
"""
Get the number of points belonging to each class for the given nuScenes split.
:param nusc: A NuScenes object.
Expand All @@ -158,7 +168,8 @@ def get_lidarseg_stats(nusc: NuScenes, sort_by: str = 'count_desc') -> Tuple[Lis
- count_asc: Sort the classes by the number of points belonging to each class, in ascending order.
- name: Sort the classes by alphabetical order.
- index: Sort the classes by their indices.
:return: A list of class names and a list of the corresponding number of points for each class.
:return: A dictionary whose keys are the class names and values are the corresponding number of points for each
class.
"""

# Initialize an array of zeroes, one for each class name.
Expand All @@ -173,27 +184,56 @@ def get_lidarseg_stats(nusc: NuScenes, sort_by: str = 'count_desc') -> Tuple[Lis
for class_idx, class_count in zip(ii, indices[ii]):
lidarseg_counts[class_idx] += class_count

lidarseg_counts_dict = dict()
num_points_per_class = dict()
for i in range(len(lidarseg_counts)):
lidarseg_counts_dict[nusc.lidarseg_idx2name_mapping[i]] = lidarseg_counts[i]
num_points_per_class[nusc.lidarseg_idx2name_mapping[i]] = lidarseg_counts[i]

if sort_by == 'count_desc':
out = sorted(lidarseg_counts_dict.items(), key=lambda item: item[1], reverse=True)
num_points_per_class = dict(sorted(num_points_per_class.items(), key=lambda item: item[1], reverse=True))
elif sort_by == 'count_asc':
out = sorted(lidarseg_counts_dict.items(), key=lambda item: item[1])
num_points_per_class = dict(sorted(num_points_per_class.items(), key=lambda item: item[1]))
elif sort_by == 'name':
out = sorted(lidarseg_counts_dict.items())
num_points_per_class = dict(sorted(num_points_per_class.items()))
elif sort_by == 'index':
out = lidarseg_counts_dict.items()
num_points_per_class = dict(num_points_per_class.items())
else:
raise Exception('Error: Invalid sorting mode {}. '
'Only `count_desc`, `count_asc`, `name` or `index` are valid.'.format(sort_by))

# Get frequency counts of each class in the lidarseg dataset.
class_names = []
counts = []
for class_name, count in out:
class_names.append(class_name)
counts.append(count)
return num_points_per_class


def get_panoptic_num_instances_per_class(nusc: NuScenes, sort_by: str = 'count_desc') -> Dict[str, int]:
"""
Get the number of scan-wise instances belonging to each class for the given nuScenes split.
:param nusc: A NuScenes object.
:param sort_by: How to sort the classes:
- count_desc: Sort the classes by the number of instances belonging to each class, in descending order.
- count_asc: Sort the classes by the number of instances belonging to each class, in ascending order.
- name: Sort the classes by alphabetical order.
- index: Sort the classes by their indices.
:return: A dictionary whose keys are the class names and values are the corresponding number of scan-wise instances
for each class.
"""
sequence_wise_instances_per_class = dict()
for instance in nusc.instance:
instance_class = nusc.get('category', instance['category_token'])['name']
if instance_class not in sequence_wise_instances_per_class.keys():
sequence_wise_instances_per_class[instance_class] = 0
sequence_wise_instances_per_class[instance_class] += instance['nbr_annotations']

if sort_by == 'count_desc':
sequence_wise_instances_per_class = dict(
sorted(sequence_wise_instances_per_class.items(), key=lambda item: item[1], reverse=True))
elif sort_by == 'count_asc':
sequence_wise_instances_per_class = dict(
sorted(sequence_wise_instances_per_class.items(), key=lambda item: item[1]))
elif sort_by == 'name':
sequence_wise_instances_per_class = dict(sorted(sequence_wise_instances_per_class.items()))
elif sort_by == 'index':
sequence_wise_instances_per_class = dict(sequence_wise_instances_per_class.items())
else:
raise Exception('Error: Invalid sorting mode {}. '
'Only `count_desc`, `count_asc`, `name` or `index` are valid.'.format(sort_by))

return class_names, counts
return sequence_wise_instances_per_class

0 comments on commit cf2f128

Please sign in to comment.