Skip to content

Commit 89949c5

Browse files
authored
Minor fix for #40727 (#40929)
* fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent c830fc1 commit 89949c5

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/transformers/testing_utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3466,15 +3466,21 @@ def _get_test_info():
34663466
if test_frame is not None:
34673467
line_number = test_frame.lineno
34683468

3469-
# most inner (recent) to most outer () frames
3469+
# The frame of `patched` being called (the one and the only one calling `_get_test_info`)
3470+
# This is used to get the original method being patched in order to get the context.
3471+
frame_of_patched_obj = None
3472+
34703473
captured_frames = []
34713474
to_capture = False
3472-
# up to the test method being called
3475+
# From the most outer (i.e. python's `runpy.py`) frame to most inner frame (i.e. the frame of this method)
3476+
# Between `the test method being called` and `before entering `patched``.
34733477
for frame in reversed(stack_from_inspect):
34743478
if test_file in str(frame).replace(r"\\", "/"):
34753479
if "self" in frame.frame.f_locals and test_name == frame.frame.f_locals["self"]._testMethodName:
34763480
to_capture = True
3477-
elif "patched" in frame.frame.f_code.co_name:
3481+
# TODO: check simply with the name is not robust.
3482+
elif "patched" == frame.frame.f_code.co_name:
3483+
frame_of_patched_obj = frame
34783484
to_capture = False
34793485
break
34803486
if to_capture:
@@ -3486,11 +3492,17 @@ def _get_test_info():
34863492
tb_next = tb
34873493
test_traceback = tb
34883494

3495+
origin_method_being_patched = frame_of_patched_obj.frame.f_locals["orig_method"]
3496+
3497+
# An iterable of type `traceback.StackSummary` with each element of type `FrameSummary`
34893498
stack = traceback.extract_stack()
3499+
# The frame which calls `the original method being patched`
3500+
caller_frame = None
3501+
# From the most inner (i.e. recent) frame to the most outer frame
3502+
for frame in reversed(stack):
3503+
if origin_method_being_patched.__name__ in frame.line:
3504+
caller_frame = frame
34903505

3491-
# The frame that calls this patched method (it may not be the test method)
3492-
# -1: `_get_test_info`; -2: `patched_xxx`; -3: the caller to `patched_xxx`
3493-
caller_frame = stack[-3]
34943506
caller_path = os.path.relpath(caller_frame.filename)
34953507
caller_lineno = caller_frame.lineno
34963508

0 commit comments

Comments
 (0)