     _split_tensordict,
     _td_fields,
     _unravel_key_to_tuple,
-    _zip_strict,
+    _zip_strict, _to_escape_compile,
     cache,
     convert_ellipsis_to_idx,
     DeviceType,
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
 
         flat_size = []
         start = 0
+        sorting_index = 0
 
         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                     start,
                     stop,
                     pad,
+                    flat_size[-1],
+                    sorting_index,
                 )
+            sorting_index = sorting_index + 1
             start = stop
 
         def assign(
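
Editorial note, not part of the diff: the two new tuple fields record each leaf's padded byte length (flat_size[-1]) and its insertion position (sorting_index), so the flat storage can later be re-split per leaf without re-deriving sizes from start/stop. A minimal sketch, assuming a leaf metadata entry whose tail is (start, stop, pad, flat_size[-1], sorting_index) as built above; the numbers are invented:

leaves = {
    "a": (0, 16, 4, 16, 0),    # assumed tail: start, stop, pad, flat_size[-1], sorting_index
    "b": (16, 32, 0, 16, 1),
}
# Rebuild split sizes and keys in storage order, the way get_tensors_length in
# _to_consolidated_compile (further down in this diff) uses v[-2] and v[-1]:
splits = [None] * len(leaves)
keys = [None] * len(leaves)
for k, v in leaves.items():
    splits[v[-1]] = v[-2]
    keys[v[-1]] = k
print(keys, splits)  # ['a', 'b'] [16, 16]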
@@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T:
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
                 non_blocking=non_blocking,
+                compilable=is_dynamo_compiling(),
             )
 
         if non_blocking is None:
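
A hedged usage sketch (not from the diff): under torch.compile, is_dynamo_compiling() returns True, so _to_consolidated takes the compile-friendly branch added below, while the eager call path behaves as before. The keys and shapes are invented and a CUDA device is assumed:

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.randn(3), "b": {"c": torch.arange(2)}}, []).consolidate()

@torch.compile
def move(td):
    return td.to("cuda")

out = move(td)        # compiled path (compilable=True)
out = td.to("cuda")   # eager path (compilable=False)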
@@ -10498,14 +10503,42 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result
 
-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(
+        self, *, device, pin_memory, num_threads, non_blocking, compilable
+    ):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
+
         storage = self._consolidated["storage"]
-        if pin_memory:
-            storage = storage.pin_memory()
-        storage_cast = storage.to(device, non_blocking=True)
+
+        storage_cast = _to_escape_compile(storage)
+
+        if compilable:
+            result = self._to_consolidated_compile(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+        else:
+            result = self._to_consolidated_eager(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
+
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
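
The body of _to_escape_compile is not shown in this diff (only the import above); judging by the eager code it replaces, it presumably wraps the storage cast so that Dynamo does not trace it. A hypothetical stand-in, purely for illustration and not the real helper, whose signature is unknown here:

import torch

@torch.compiler.disable()  # keep the raw data movement out of the Dynamo graph
def _cast_storage_untraced(storage, device, pin_memory=False):
    # hypothetical illustration only, NOT _to_escape_compile
    if pin_memory:
        storage = storage.pin_memory()
    return storage.to(device, non_blocking=True)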
@@ -10574,18 +10607,138 @@ def copy_dict(d):
                 }
 
             result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
-        if non_blocking in (False, None):
-            if device.type == "cuda" and non_blocking is False:
-                # sending to CUDA force sync
-                cuda_device = device
-            elif storage.device.type == "cuda":
-                # sending from cuda: need sync unless intentionally not asked for
-                cuda_device = storage.device.type
-            else:
-                cuda_device = None
-            if cuda_device is not None:
-                torch.cuda.current_stream(cuda_device).synchronize()
+        return result
+
+    def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
+
+        def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_tensors_length(
+                        d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
+                    )
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [
+                    None,
+                ] * len(pos)
+                out1 = [
+                    None,
+                ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_tensors_length(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def view_as(src, dest):
+            return src.view(dest.dtype)[: dest.numel()].view(dest.shape)
 
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                storage_offsets = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_OFFSETS>" + name[-1],
+                    )
+                ]
+                kwargs["offsets"] = view_as(storage_offsets, offsets)
+                if lengths is not None:
+                    storage_lengths = slice_map[
+                        (
+                            *name[:-1],
+                            "<NJT_LENGTHS>" + name[-1],
+                        )
+                    ]
+                    kwargs["lengths"] = view_as(storage_lengths, lengths)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                storage_values = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_VALUES>" + name[-1],
+                    )
+                ]
+                return NestedTensor(
+                    view_as(storage_values, values),
+                    **kwargs,
+                )
+            return view_as(slice_map[name], x)
+
+        result = self._fast_apply(
+            set_,
+            device=torch.device(device),
+            num_threads=num_threads,
+            named=True,
+            nested_keys=True,
+        )
+        result._consolidated = _consolidated
         return result
 
     def _sync_all(self):
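
To make the storage-splitting step concrete, a small self-contained toy (editorial, not from the diff; keys, dtypes and sizes are invented): the consolidated byte buffer is cut into per-leaf slices and each slice is reinterpreted with the leaf's dtype and shape, which is what split_storage plus view_as do above.

import torch

def view_as(src, dest):
    # same reinterpretation as above: raw bytes -> dest's dtype and shape
    return src.view(dest.dtype)[: dest.numel()].view(dest.shape)

flat = torch.arange(6, dtype=torch.float32).view(torch.uint8)   # 24-byte stand-in for storage_cast
slice_map = {("a",): flat[:12], ("b",): flat[12:]}               # per-key byte slices

a_meta = torch.empty(3, dtype=torch.float32)    # carries leaf "a"'s dtype and shape
print(view_as(slice_map[("a",)], a_meta))       # tensor([0., 1., 2.])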