@@ -394,35 +394,39 @@ def _aggregate_named(self, func, *args, **kwargs):
394394 def transform (self , func , * args , ** kwargs ):
395395 func = self ._get_cython_func (func ) or func
396396
397- if isinstance (func , str ):
398- if not (func in base .transform_kernel_whitelist ):
399- msg = "'{func}' is not a valid function name for transform(name)"
400- raise ValueError (msg .format (func = func ))
401- if func in base .cythonized_kernels :
402- # cythonized transform or canned "agg+broadcast"
403- return getattr (self , func )(* args , ** kwargs )
404- else :
405- # If func is a reduction, we need to broadcast the
406- # result to the whole group. Compute func result
407- # and deal with possible broadcasting below.
408- return self ._transform_fast (
409- lambda : getattr (self , func )(* args , ** kwargs ), func
410- )
397+ if not isinstance (func , str ):
398+ return self ._transform_general (func , * args , ** kwargs )
399+
400+ elif func not in base .transform_kernel_whitelist :
401+ msg = f"'{ func } ' is not a valid function name for transform(name)"
402+ raise ValueError (msg )
403+ elif func in base .cythonized_kernels :
404+ # cythonized transform or canned "agg+broadcast"
405+ return getattr (self , func )(* args , ** kwargs )
411406
412- # reg transform
407+ # If func is a reduction, we need to broadcast the
408+ # result to the whole group. Compute func result
409+ # and deal with possible broadcasting below.
410+ result = getattr (self , func )(* args , ** kwargs )
411+ return self ._transform_fast (result , func )
412+
413+ def _transform_general (self , func , * args , ** kwargs ):
414+ """
415+ Transform with a non-str `func`.
416+ """
413417 klass = self ._selected_obj .__class__
418+
414419 results = []
415- wrapper = lambda x : func (x , * args , ** kwargs )
416420 for name , group in self :
417421 object .__setattr__ (group , "name" , name )
418- res = wrapper (group )
422+ res = func (group , * args , ** kwargs )
419423
420424 if isinstance (res , (ABCDataFrame , ABCSeries )):
421425 res = res ._values
422426
423427 indexer = self ._get_index (name )
424- s = klass (res , indexer )
425- results .append (s )
428+ ser = klass (res , indexer )
429+ results .append (ser )
426430
427431 # check for empty "results" to avoid concat ValueError
428432 if results :
@@ -433,7 +437,7 @@ def transform(self, func, *args, **kwargs):
433437 result = Series ()
434438
435439 # we will only try to coerce the result type if
436- # we have a numeric dtype, as these are *always* udfs
440+ # we have a numeric dtype, as these are *always* user-defined funcs
437441 # the cython take a different path (and casting)
438442 dtype = self ._selected_obj .dtype
439443 if is_numeric_dtype (dtype ):
@@ -443,17 +447,14 @@ def transform(self, func, *args, **kwargs):
443447 result .index = self ._selected_obj .index
444448 return result
445449
446- def _transform_fast (self , func , func_nm ) -> Series :
450+ def _transform_fast (self , result , func_nm : str ) -> Series :
447451 """
448452 fast version of transform, only applicable to
449453 builtin/cythonizable functions
450454 """
451- if isinstance (func , str ):
452- func = getattr (self , func )
453-
454455 ids , _ , ngroup = self .grouper .group_info
455456 cast = self ._transform_should_cast (func_nm )
456- out = algorithms .take_1d (func () ._values , ids )
457+ out = algorithms .take_1d (result ._values , ids )
457458 if cast :
458459 out = self ._try_cast (out , self .obj )
459460 return Series (out , index = self .obj .index , name = self .obj .name )
@@ -1333,21 +1334,21 @@ def transform(self, func, *args, **kwargs):
13331334 # optimized transforms
13341335 func = self ._get_cython_func (func ) or func
13351336
1336- if isinstance (func , str ):
1337- if not (func in base .transform_kernel_whitelist ):
1338- msg = "'{func}' is not a valid function name for transform(name)"
1339- raise ValueError (msg .format (func = func ))
1340- if func in base .cythonized_kernels :
1341- # cythonized transformation or canned "reduction+broadcast"
1342- return getattr (self , func )(* args , ** kwargs )
1343- else :
1344- # If func is a reduction, we need to broadcast the
1345- # result to the whole group. Compute func result
1346- # and deal with possible broadcasting below.
1347- result = getattr (self , func )(* args , ** kwargs )
1348- else :
1337+ if not isinstance (func , str ):
13491338 return self ._transform_general (func , * args , ** kwargs )
13501339
1340+ elif func not in base .transform_kernel_whitelist :
1341+ msg = f"'{ func } ' is not a valid function name for transform(name)"
1342+ raise ValueError (msg )
1343+ elif func in base .cythonized_kernels :
1344+ # cythonized transformation or canned "reduction+broadcast"
1345+ return getattr (self , func )(* args , ** kwargs )
1346+
1347+ # If func is a reduction, we need to broadcast the
1348+ # result to the whole group. Compute func result
1349+ # and deal with possible broadcasting below.
1350+ result = getattr (self , func )(* args , ** kwargs )
1351+
13511352 # a reduction transform
13521353 if not isinstance (result , DataFrame ):
13531354 return self ._transform_general (func , * args , ** kwargs )
@@ -1358,16 +1359,18 @@ def transform(self, func, *args, **kwargs):
13581359 if not result .columns .equals (obj .columns ):
13591360 return self ._transform_general (func , * args , ** kwargs )
13601361
1361- return self ._transform_fast (result , obj , func )
1362+ return self ._transform_fast (result , func )
13621363
1363- def _transform_fast (self , result : DataFrame , obj : DataFrame , func_nm ) -> DataFrame :
1364+ def _transform_fast (self , result : DataFrame , func_nm : str ) -> DataFrame :
13641365 """
13651366 Fast transform path for aggregations
13661367 """
13671368 # if there were groups with no observations (Categorical only?)
13681369 # try casting data to original dtype
13691370 cast = self ._transform_should_cast (func_nm )
13701371
1372+ obj = self ._obj_with_exclusions
1373+
13711374 # for each col, reshape to to size of original frame
13721375 # by take operation
13731376 ids , _ , ngroup = self .grouper .group_info
0 commit comments