@@ -45,14 +45,17 @@ class GraphCaptureContext:
45
45
46
46
47
47
def _split_tensor_dict (
48
- tensor_dict : Dict [Any , Union [torch .Tensor , Any ]]
49
- ) -> Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
48
+ tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
49
+ prefix : str = "" ) -> Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
50
50
"""Split the tensor dictionary into two parts:
51
51
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
52
52
by its metadata.
53
53
2. A list of tensors.
54
+
55
+ If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
56
+ metadata will be "key1%key2".
54
57
"""
55
- metadata_list = []
58
+ metadata_list : List [ Tuple [ str , Any ]] = []
56
59
tensor_list = []
57
60
for key , value in tensor_dict .items ():
58
61
if isinstance (value , torch .Tensor ):
@@ -62,13 +65,31 @@ def _split_tensor_dict(
62
65
# receiving side will set the device index.
63
66
device = value .device .type
64
67
metadata_list .append (
65
- (key , TensorMetadata (device , value .dtype , value .size ())))
68
+ (prefix + key , TensorMetadata (device , value .dtype ,
69
+ value .size ())))
66
70
tensor_list .append (value )
71
+ elif isinstance (value , dict ):
72
+ if len (value ) == 0 :
73
+ metadata_list .append ((prefix + key , value ))
74
+ inner_metadata_list , inner_tensor_list = _split_tensor_dict (
75
+ value , prefix + key + "%" )
76
+ metadata_list .extend (inner_metadata_list )
77
+ tensor_list .extend (inner_tensor_list )
67
78
else :
68
- metadata_list .append ((key , value ))
79
+ metadata_list .append ((prefix + key , value ))
69
80
return metadata_list , tensor_list
70
81
71
82
83
+ def _update_nested_dict (nested_dict , flattened_key , value ):
84
+ key_splits = flattened_key .split ("%" )
85
+ cur_dict = nested_dict
86
+ for k in key_splits [:- 1 ]:
87
+ if k not in cur_dict :
88
+ cur_dict [k ] = {}
89
+ cur_dict = cur_dict [k ]
90
+ cur_dict [key_splits [- 1 ]] = value
91
+
92
+
72
93
class GroupCoordinator :
73
94
"""
74
95
PyTorch ProcessGroup wrapper for a group of processes.
@@ -512,7 +533,7 @@ def broadcast_tensor_dict(
512
533
device = value .device )
513
534
if tensor .numel () == 0 :
514
535
# Skip broadcasting empty tensors.
515
- tensor_dict [ key ] = tensor
536
+ _update_nested_dict ( tensor_dict , key , tensor )
516
537
continue
517
538
if tensor .is_cpu :
518
539
# use metadata_group for CPU tensors
@@ -528,9 +549,9 @@ def broadcast_tensor_dict(
528
549
group = group ,
529
550
async_op = True )
530
551
async_handles .append (handle )
531
- tensor_dict [ key ] = tensor
552
+ _update_nested_dict ( tensor_dict , key , tensor )
532
553
else :
533
- tensor_dict [ key ] = value
554
+ _update_nested_dict ( tensor_dict , key , value )
534
555
for async_handle in async_handles :
535
556
async_handle .wait ()
536
557
return tensor_dict
0 commit comments