@@ -426,3 +426,55 @@ def test_array_ufunc_series_defer():
426426
427427 tm .assert_series_equal (r1 , expected )
428428 tm .assert_series_equal (r2 , expected )
429+
430+
431+ def test_groupby_agg ():
432+ # Ensure that the result of agg is inferred to be decimal dtype
433+ # https://github.com/pandas-dev/pandas/issues/29141
434+
435+ data = make_data ()[:5 ]
436+ df = pd .DataFrame (
437+ {"id1" : [0 , 0 , 0 , 1 , 1 ], "id2" : [0 , 1 , 0 , 1 , 1 ], "decimals" : DecimalArray (data )}
438+ )
439+
440+ # single key, selected column
441+ expected = pd .Series (to_decimal ([data [0 ], data [3 ]]))
442+ result = df .groupby ("id1" )["decimals" ].agg (lambda x : x .iloc [0 ])
443+ tm .assert_series_equal (result , expected , check_names = False )
444+ result = df ["decimals" ].groupby (df ["id1" ]).agg (lambda x : x .iloc [0 ])
445+ tm .assert_series_equal (result , expected , check_names = False )
446+
447+ # multiple keys, selected column
448+ expected = pd .Series (
449+ to_decimal ([data [0 ], data [1 ], data [3 ]]),
450+ index = pd .MultiIndex .from_tuples ([(0 , 0 ), (0 , 1 ), (1 , 1 )]),
451+ )
452+ result = df .groupby (["id1" , "id2" ])["decimals" ].agg (lambda x : x .iloc [0 ])
453+ tm .assert_series_equal (result , expected , check_names = False )
454+ result = df ["decimals" ].groupby ([df ["id1" ], df ["id2" ]]).agg (lambda x : x .iloc [0 ])
455+ tm .assert_series_equal (result , expected , check_names = False )
456+
457+ # multiple columns
458+ expected = pd .DataFrame ({"id2" : [0 , 1 ], "decimals" : to_decimal ([data [0 ], data [3 ]])})
459+ result = df .groupby ("id1" ).agg (lambda x : x .iloc [0 ])
460+ tm .assert_frame_equal (result , expected , check_names = False )
461+
462+
463+ def test_groupby_agg_ea_method (monkeypatch ):
464+ # Ensure that the result of agg is inferred to be decimal dtype
465+ # https://github.com/pandas-dev/pandas/issues/29141
466+
467+ def DecimalArray__my_sum (self ):
468+ return np .sum (np .array (self ))
469+
470+ monkeypatch .setattr (DecimalArray , "my_sum" , DecimalArray__my_sum , raising = False )
471+
472+ data = make_data ()[:5 ]
473+ df = pd .DataFrame ({"id" : [0 , 0 , 0 , 1 , 1 ], "decimals" : DecimalArray (data )})
474+ expected = pd .Series (to_decimal ([data [0 ] + data [1 ] + data [2 ], data [3 ] + data [4 ]]))
475+
476+ result = df .groupby ("id" )["decimals" ].agg (lambda x : x .values .my_sum ())
477+ tm .assert_series_equal (result , expected , check_names = False )
478+ s = pd .Series (DecimalArray (data ))
479+ result = s .groupby (np .array ([0 , 0 , 0 , 1 , 1 ])).agg (lambda x : x .values .my_sum ())
480+ tm .assert_series_equal (result , expected , check_names = False )
0 commit comments