@@ -86,36 +86,35 @@ class BatchedBrownianTree:
8686 """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
8787
8888 def __init__ (self , x , t0 , t1 , seed = None , ** kwargs ):
89- self .cpu_tree = True
90- if "cpu" in kwargs :
91- self .cpu_tree = kwargs .pop ("cpu" )
89+ self .cpu_tree = kwargs .pop ("cpu" , True )
9290 t0 , t1 , self .sign = self .sort (t0 , t1 )
93- w0 = kwargs .get ('w0' , torch .zeros_like (x ))
91+ w0 = kwargs .pop ('w0' , None )
92+ if w0 is None :
93+ w0 = torch .zeros_like (x )
94+ self .batched = False
9495 if seed is None :
95- seed = torch .randint (0 , 2 ** 63 - 1 , []).item ()
96- self .batched = True
97- try :
98- assert len (seed ) == x .shape [0 ]
96+ seed = (torch .randint (0 , 2 ** 63 - 1 , ()).item (),)
97+ elif isinstance (seed , (tuple , list )):
98+ if len (seed ) != x .shape [0 ]:
99+ raise ValueError ("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size." )
100+ self .batched = True
99101 w0 = w0 [0 ]
100- except TypeError :
101- seed = [seed ]
102- self .batched = False
103- if self .cpu_tree :
104- self .trees = [torchsde .BrownianTree (t0 .cpu (), w0 .cpu (), t1 .cpu (), entropy = s , ** kwargs ) for s in seed ]
105102 else :
106- self .trees = [torchsde .BrownianTree (t0 , w0 , t1 , entropy = s , ** kwargs ) for s in seed ]
103+ seed = (seed ,)
104+ if self .cpu_tree :
105+ t0 , w0 , t1 = t0 .detach ().cpu (), w0 .detach ().cpu (), t1 .detach ().cpu ()
106+ self .trees = tuple (torchsde .BrownianTree (t0 , w0 , t1 , entropy = s , ** kwargs ) for s in seed )
107107
108108 @staticmethod
109109 def sort (a , b ):
110110 return (a , b , 1 ) if a < b else (b , a , - 1 )
111111
112112 def __call__ (self , t0 , t1 ):
113113 t0 , t1 , sign = self .sort (t0 , t1 )
114+ device , dtype = t0 .device , t0 .dtype
114115 if self .cpu_tree :
115- w = torch .stack ([tree (t0 .cpu ().float (), t1 .cpu ().float ()).to (t0 .dtype ).to (t0 .device ) for tree in self .trees ]) * (self .sign * sign )
116- else :
117- w = torch .stack ([tree (t0 , t1 ) for tree in self .trees ]) * (self .sign * sign )
118-
116+ t0 , t1 = t0 .detach ().cpu ().float (), t1 .detach ().cpu ().float ()
117+ w = torch .stack ([tree (t0 , t1 ) for tree in self .trees ]).to (device = device , dtype = dtype ) * (self .sign * sign )
119118 return w if self .batched else w [0 ]
120119
121120
0 commit comments