Skip to content

Commit d3e81a9

Browse files
committed
Add ability to record cassettes with refresh_tokens
1 parent d29412d commit d3e81a9

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

tests/conftest.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import socket
55
import time
66
from base64 import b64encode
7+
from functools import wraps
78
from sys import platform
89
from urllib.parse import quote_plus
910

1011
import betamax
1112
import pytest
13+
from betamax.cassette.cassette import Cassette, dispatch_hooks
1214
from betamax.serializers import JSONSerializer
1315

1416

@@ -55,7 +57,7 @@ def filter_access_token(interaction, current_cassette):
5557
x: env_default(x)
5658
for x in (
5759
"auth_code client_id client_secret password redirect_uri test_subreddit"
58-
" user_agent username"
60+
" user_agent username refresh_token"
5961
).split()
6062
}
6163

@@ -83,6 +85,28 @@ def serialize(self, cassette_data):
8385
config.define_cassette_placeholder(f"<{key.upper()}>", value)
8486

8587

88+
def add_init_hook(original_init):
89+
"""Wrap an __init__ method to also call some hooks."""
90+
91+
@wraps(original_init)
92+
def wrapper(self, *args, **kwargs):
93+
original_init(self, *args, **kwargs)
94+
dispatch_hooks("after_init", self)
95+
96+
return wrapper
97+
98+
99+
Cassette.__init__ = add_init_hook(Cassette.__init__)
100+
101+
102+
def init_hook(cassette):
103+
if cassette.is_recording():
104+
pytest.set_up_record() # dynamically defined in __init__.py
105+
106+
107+
Cassette.hooks["after_init"].append(init_hook)
108+
109+
86110
class Placeholders:
87111
def __init__(self, _dict):
88112
self.__dict__ = _dict

tests/integration/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ class IntegrationTest:
1616

1717
def setup(self):
1818
"""Setup runs before all test cases."""
19+
self._overrode_reddit_setup = True
1920
self.setup_reddit()
2021
self.setup_betamax()
2122

@@ -31,7 +32,11 @@ def setup_betamax(self):
3132
# Require tests to explicitly disable read_only mode.
3233
self.reddit.read_only = True
3334

35+
pytest.set_up_record = self.set_up_record # used in conftest.py
36+
3437
def setup_reddit(self):
38+
self._overrode_reddit_setup = False
39+
3540
self._session = requests.Session()
3641

3742
self.reddit = Reddit(
@@ -43,6 +48,17 @@ def setup_reddit(self):
4348
username=pytest.placeholders.username,
4449
)
4550

51+
def set_up_record(self):
52+
if not self._overrode_reddit_setup:
53+
if pytest.placeholders.refresh_token != "placeholder_refresh_token":
54+
self.reddit = Reddit(
55+
requestor_kwargs={"session": self._session},
56+
client_id=pytest.placeholders.client_id,
57+
client_secret=pytest.placeholders.client_secret,
58+
user_agent=pytest.placeholders.user_agent,
59+
refresh_token=pytest.placeholders.refresh_token,
60+
)
61+
4662
def use_cassette(self, cassette_name=None, **kwargs):
4763
"""Use a cassette. The cassette name is dynamically generated.
4864

0 commit comments

Comments
 (0)