@@ -45,7 +45,7 @@ class GraphCaptureContext:
45
45
46
46
47
47
def _split_tensor_dict (
48
- tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
48
+ tensor_dict : Dict [str , Union [torch .Tensor , Any ]],
49
49
prefix : str = "" ) -> Tuple [List [Tuple [str , Any ]], List [torch .Tensor ]]:
50
50
"""Split the tensor dictionary into two parts:
51
51
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
@@ -473,11 +473,11 @@ def recv_object(self, src: int) -> Any:
473
473
474
474
def broadcast_tensor_dict (
475
475
self ,
476
- tensor_dict : Optional [Dict [Any , Union [torch .Tensor , Any ]]] = None ,
476
+ tensor_dict : Optional [Dict [str , Union [torch .Tensor , Any ]]] = None ,
477
477
src : int = 0 ,
478
478
group : Optional [ProcessGroup ] = None ,
479
479
metadata_group : Optional [ProcessGroup ] = None
480
- ) -> Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
480
+ ) -> Optional [Dict [str , Union [torch .Tensor , Any ]]]:
481
481
"""Broadcast the input tensor dictionary.
482
482
NOTE: `src` is the local rank of the source rank.
483
483
"""
@@ -558,9 +558,9 @@ def broadcast_tensor_dict(
558
558
559
559
def send_tensor_dict (
560
560
self ,
561
- tensor_dict : Dict [Any , Union [torch .Tensor , Any ]],
561
+ tensor_dict : Dict [str , Union [torch .Tensor , Any ]],
562
562
dst : Optional [int ] = None
563
- ) -> Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
563
+ ) -> Optional [Dict [str , Union [torch .Tensor , Any ]]]:
564
564
"""Send the input tensor dictionary.
565
565
NOTE: `dst` is the local rank of the destination rank.
566
566
"""
@@ -599,7 +599,7 @@ def send_tensor_dict(
599
599
def recv_tensor_dict (
600
600
self ,
601
601
src : Optional [int ] = None
602
- ) -> Optional [Dict [Any , Union [torch .Tensor , Any ]]]:
602
+ ) -> Optional [Dict [str , Union [torch .Tensor , Any ]]]:
603
603
"""Recv the input tensor dictionary.
604
604
NOTE: `src` is the local rank of the source rank.
605
605
"""
@@ -615,15 +615,15 @@ def recv_tensor_dict(
615
615
assert src < self .world_size , f"Invalid src rank ({ src } )"
616
616
617
617
recv_metadata_list = self .recv_object (src = src )
618
- tensor_dict = {}
618
+ tensor_dict : Dict [ str , Any ] = {}
619
619
for key , value in recv_metadata_list :
620
620
if isinstance (value , TensorMetadata ):
621
621
tensor = torch .empty (value .size ,
622
622
dtype = value .dtype ,
623
623
device = value .device )
624
624
if tensor .numel () == 0 :
625
625
# Skip receiving empty tensors.
626
- tensor_dict [ key ] = tensor
626
+ _update_nested_dict ( tensor_dict , key , tensor )
627
627
continue
628
628
if tensor .is_cpu :
629
629
# use metadata_group for CPU tensors
@@ -633,9 +633,9 @@ def recv_tensor_dict(
633
633
else :
634
634
# use group for GPU tensors
635
635
torch .distributed .recv (tensor , src = src , group = group )
636
- tensor_dict [ key ] = tensor
636
+ _update_nested_dict ( tensor_dict , key , tensor )
637
637
else :
638
- tensor_dict [ key ] = value
638
+ _update_nested_dict ( tensor_dict , key , value )
639
639
return tensor_dict
640
640
641
641
def barrier (self ):
0 commit comments