Skip to content

Commit

Permalink
Add memory visualization tool (#24)
Browse files Browse the repository at this point in the history
* feat: add mem visualization tool

* fix: remove unused comment

* feat: change figure style

* feat: add legend
  • Loading branch information
lausannel committed Sep 23, 2024
1 parent 27eccca commit 4bd6eb3
Showing 1 changed file with 341 additions and 0 deletions.
341 changes: 341 additions & 0 deletions tools/plot_mem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
import argparse
import re

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np

BUFFER_ASSIGNMENT_SUFFIX = "buffer-assignment.txt"
BUFFER_LIVE_RANGE_KEYWORD = "BufferLiveRange:"
TENSOR_PATTERN = r"value: <\d+ ([^@>]+)@0> \(size\=(\d+),offset=(\d+)\):"
ALLOCATION_PATTERN = r"allocation (\d+): size (\d+)"
LIVE_RANGE_PATTERN = r"\s*(\S+):(\d+)-(\d+)\s*\n?"


def parse_args() -> argparse.Namespace:
'''
Parses command-line arguments and check validity.
Returns:
args (argparse.Namespace): Parsed command line arguments
'''
parser = argparse.ArgumentParser(description='''
Plot tensor lifecycle and buffer usage, when given a buffer assignment file from XLA.
After dumping the buffer assignment file with "--xla_dump_hlo_as_text" flag, you can use this script to visualize the buffer usage during the execution.
The script will plot the lifecycle of each tensor, and the buffer usage over time.
The peak memory usage will be annotated on the plot.
For example:\n
python plot-mem.py --input=module_0000.SyncTensorsGraph.105381.sm_8.0_gpu_after_optimizations-buffer-assignment.txt
The script will analyze the buffer assignment file and generate a plot named "tensor_lifecycle.png".
''')
parser.add_argument(
'--input', type=str, help='The file to be parsed', required=True)
args = parser.parse_args()
args.input = args.input.strip()

if not args.input.endswith(BUFFER_ASSIGNMENT_SUFFIX):
raise ValueError(
f"The input file should end with '{BUFFER_ASSIGNMENT_SUFFIX}'.")

return args


def read_file(file_name: str) -> str:
'''
Read the file and return the content.
Args:
file_name (str): Path to the file to be read
Returns:
str: Content of the file
'''
with open(file_name, 'r', encoding='utf-8') as file:
return file.read()


def get_all_tensors(file_content: str) -> dict[str, dict]:
'''
Extract information about all tensors from the file content. The pattern is:
value: <89591 custom-call.87.0{2} @0> (size=33554432,offset=0): bf16[16384,8,128]{2,1,0}
Args:
file_content (str): Content of the file
Returns:
dict[str, dict]: Dictionary of tensors, where the key is the tensor name and the value is a dictionary containing the size, offset, and allocation_id
'''
pattern = TENSOR_PATTERN # (\w+\[\d+(?:,\d+)*\])"
allocation_pattern = ALLOCATION_PATTERN
results = {}
offset_lb = 0
offset_ub = 0
allocation_id = 0

lines = file_content.splitlines()
for line in lines:
allocation_matches = re.search(allocation_pattern, line)
if allocation_matches:
allocation_id, size = allocation_matches.groups()
offset_lb = offset_ub
offset_ub = offset_lb + int(size)
matches = re.findall(pattern, line)
for match in matches:
tensor_name, size, offset = match
results[tensor_name.strip()] = {
'size': size,
'offset': int(offset) + offset_lb,
'allocation_id': allocation_id
}
print(f"[INFO] Total {len(results)} tensors are found.")
return results


def get_live_range(file_content: str, tensors: dict) -> tuple[dict, set]:
'''
Retrieves the live ranges for the tensors from the file content. Live range pattern is:
BufferLiveRange:
reduce-scatter.8199.1{}:2037-2091
reduce-scatter.8708.1{}:2145-2198
Args:
file_content (str): Content of the file
tensors (dict): Dictionary of tensors, where the key is the tensor name and the value is a dictionary containing the size, offset, and allocation_id
Returns:
dict: Dictionary of tensors with live range information added
diff (set): Set of tensors that are not matched with live range pattern
'''
keyword = BUFFER_LIVE_RANGE_KEYWORD
keyword_pos = file_content.find(keyword)
if keyword_pos == -1:
raise ValueError(f"{keyword} not found.")

file_content = file_content[keyword_pos + len(keyword):]
pattern = LIVE_RANGE_PATTERN
matches = re.findall(pattern, file_content)
i = 0
matched_name = set()
for match in matches:
name, live_range_start, live_range_end = match
if name not in tensors:
name = name[:-2]
if name not in tensors:
continue
if name in tensors:
tensors[name]['start'] = live_range_start
tensors[name]['end'] = live_range_end
matched_name.add(name)
i += 1
print(f"[INFO] Total {i} tensors' live range are found.")
diff = set(tensors.keys()) - matched_name
return tensors, diff


def get_allocation_group_by_size(allocations: dict) -> dict:
'''
Groups allocations by their sizes.
Args:
allocations (dict): Dictionary of allocations, where the key is the allocation ID and the value is a dictionary containing the size and start time
Returns:
dict: Dictionary of allocations grouped by their sizes, the key is the index and the value is a dictionary containing the size, start time and count
'''
allocations_group_by_size = {}
index = 0
sorted_allocations = sorted(
allocations.items(), key=lambda item: int(item[0]))

prev_size = None

for allocation_id, allocation_info in sorted_allocations:
size = int(allocation_info['size'])
start = int(allocation_info['start'])

if size != prev_size:
index += 1
allocations_group_by_size[index] = {
'size': size,
'start': start,
'count': 1
}
else:
allocations_group_by_size[index]['count'] += 1

prev_size = size

return allocations_group_by_size


def plot_tensor_lifecycle(tensors: dict, allocations: dict):
'''
Plots the lifecycle of tensors and buffer usage over time.
Args:
tensors (dict): Dictionary containing tensor information.
allocations (dict): Dictionary containing allocation information.
'''
fig, ax = plt.subplots(figsize=(12, 6))
max_time = 0
max_buffer_size = 0
buffer_usage = {}

for tensor_name, tensor_info in tensors.items():
try:
size = int(tensor_info['size'])
offset = int(tensor_info['offset'])
start = int(tensor_info['start'])
end = int(tensor_info['end'])
rect = patches.Rectangle((start, offset),
end - start,
size,
fill=True,
facecolor='skyblue',
edgecolor='skyblue',
alpha=0.7)
ax.add_patch(rect)

for t in range(start, end + 1):
if t not in buffer_usage:
buffer_usage[t] = 0
buffer_usage[t] += size

max_time = max(max_time, end)
max_buffer_size = max(max_buffer_size, offset + size)
except:
continue

allocations_group_by_size = get_allocation_group_by_size(allocations)
for index, allocation_info in sorted(
allocations_group_by_size.items(), key=lambda x: x[0]):
size = allocation_info['size']
start = allocation_info['start']
count = allocation_info['count']
# only show allocations larger than 512MB
if size * count < 512 * 1024 * 1024:
continue
rect = patches.Rectangle((0, start),
max_time,
size * count,
fill=False,
edgecolor='red',
linewidth=1.5,
alpha=1.0)
ax.add_patch(rect)
ax.text(
1.02,
start + size * count / 2,
f"Buffer Count: {count}\nSize: {size} bytes ({size/1024/1024/1024:.2f} GB)\nTotal: {size * count} bytes ({size * count/1024/1024/1024:.2f} GB)",
fontsize=6,
ha='left',
va='center',
transform=ax.get_yaxis_transform(),
)

peak_time = max(buffer_usage, key=buffer_usage.get)
peak_usage = buffer_usage[peak_time]

times = sorted(buffer_usage.keys())
usages = [buffer_usage[t] for t in times]
ax.plot(
times,
usages,
color='Slateblue',
linewidth=2,
alpha=0.6,
label='Actual Total Tensor Mem Usage')
ax.plot([], [],
color='skyblue',
linewidth=2,
alpha=0.6,
label='Tensor Lifetime and Mem Usage')
ax.plot([], [],
color='red',
linewidth=2,
alpha=0.6,
label='Buffer Allocation')

ax.annotate(
f"Time: {peak_time}\nActual Usage: {peak_usage} bytes ({peak_usage/1024/1024/1024:.2f} GB)",
xy=(peak_time, peak_usage),
xytext=(peak_time, peak_usage + max(usages) * 0.1),
arrowprops=dict(facecolor='black', shrink=0.05),
fontsize=8,
ha='center',
va='bottom')

ax.set_xlim(0, max_time)
ax.set_ylim(0, max(max_buffer_size, peak_usage * 1.4))
ax.set_xlabel('Time')
ax.set_ylabel('Buffer Size (bytes)')
ax.set_title('Tensor Lifecycle and Buffer Usage', pad=10)

y_ticks = ax.get_yticks()
ax.set_yticks(y_ticks)
ax.set_yticklabels([f'{y /1024 / 1024 / 1024 :.2f} GB' for y in y_ticks])

ax.axhline(y=max_buffer_size, color='red', linestyle='--', alpha=0.7)

ax.annotate(
f'Allocated Size: {max_buffer_size/1024/1024/1024:.2f} GB',
xy=(max_time, max_buffer_size),
xytext=(max_time, max_buffer_size + max_buffer_size * 0.1),
arrowprops=dict(facecolor='red', shrink=0.05),
fontsize=8,
ha='center',
va='bottom',
)

ax.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()
plt.legend(fontsize=6)
plt.savefig('tensor_lifecycle.png', dpi=1200)
print("[INFO] tensor_lifecycle.png saved successfully.")


def get_all_allocation(file_content: str) -> tuple:
'''
Extracts information about all allocations from the file content.
allocation 1: size 525338624,
Args:
file_content (str): Content of the buffer assignment file.
Returns:
tuple: Dictionaries containing allocation information and total allocation size.
'''
pattern = ALLOCATION_PATTERN

matches = re.findall(pattern, file_content)
results = {}
start = 0
for match in matches:
allocation_id, size = match
results[allocation_id] = {'size': size, 'start': start}
start += int(size)

total = 0
for key, value in results.items():
total += int(value['size'])
return results, total


if __name__ == '__main__':
args = parse_args()
content = read_file(args.input)

tensors = get_all_tensors(content)
tensors, diff = get_live_range(content, tensors)
allocations, total = get_all_allocation(content)
print(
f"[INFO] Total allocation size: {total} bytes ({total/1024/1024/1024:.2f} GB)"
)

if diff:
print(f"[WARNING] {len(diff)} tensors' live range are not found.")

plot_tensor_lifecycle(tensors, allocations)

0 comments on commit 4bd6eb3

Please sign in to comment.