Skip to content

Commit a6db8b0

Browse files
committed
Smile
1 parent deb61fa commit a6db8b0

File tree

2 files changed

+244
-0
lines changed

2 files changed

+244
-0
lines changed

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
-e .
22

3+
# As of Apr 2025, we use only safe_load() which is available since PyYAML's first release 3.01
4+
# https://github.com/yaml/pyyaml/blob/3.01/lib/yaml/__init__.py#L71
5+
pyyaml<7
6+
37
# python-dotenv 1.0+ no longer supports Python 3.7
48
python-dotenv>=0.21,<2
59

tests/smile.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
#!/usr/bin/env python3
2+
"""
3+
MSAL Feature Test Runner
4+
Interprets testcase file(s) to create and execute test cases using MSAL.
5+
6+
Initially created by the following prompt:
7+
Write a python implementation that can read content from feature.yml, create variables whose names are defined in the "arrange" mapping's keys, and the variables' value are derived from the "arrange" mapping's value; interpret those value as if they are python snippet using MSAL library.
8+
"""
9+
import os
10+
import sys
11+
import logging
12+
from contextlib import contextmanager
13+
from typing import Dict, Any, List, Optional
14+
15+
import yaml
16+
import msal
17+
import requests
18+
19+
20+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
21+
logger = logging.getLogger(__name__)
22+
23+
class SmileTestRunner:
24+
25+
def __init__(self, testcase_url: str):
26+
self.testcase_url = testcase_url
27+
self.test_spec = None
28+
self.variables = {}
29+
30+
def load_feature(self) -> Dict[str, Any]:
31+
"""Load and validate the feature file."""
32+
try:
33+
with requests.get(self.testcase_url) as response:
34+
response.raise_for_status()
35+
self.test_spec = yaml.safe_load(response.text)
36+
37+
# Basic validation
38+
if not isinstance(self.test_spec, dict):
39+
raise ValueError("Feature file must contain a valid YAML dictionary")
40+
41+
if self.test_spec.get('type') != 'MSAL Test':
42+
raise ValueError("Feature file must have type 'MSAL Test'")
43+
44+
return self.test_spec
45+
except Exception as e:
46+
logger.error(f"Error loading feature file: {str(e)}")
47+
sys.exit(1)
48+
49+
@contextmanager
50+
def setup_environment(self):
51+
"""Set up the environment variables specified in the feature file."""
52+
original_env = os.environ.copy()
53+
54+
try:
55+
# Set environment variables
56+
if 'env' in self.test_spec and isinstance(self.test_spec['env'], dict):
57+
for key, value in self.test_spec['env'].items():
58+
os.environ[key] = str(value)
59+
logger.debug(f"Set environment variable {key}={value}")
60+
yield
61+
finally:
62+
# Restore original environment
63+
os.environ.clear()
64+
os.environ.update(original_env)
65+
66+
def arrange(self):
67+
"""Create variables based on the arrange section."""
68+
arrange_spec = self.test_spec.get('arrange', {})
69+
if not isinstance(arrange_spec, dict):
70+
raise ValueError("Arrange section must be a dictionary")
71+
for var_name, value_spec in arrange_spec.items():
72+
logger.debug(f"Creating variable '{var_name}' with {value_spec}")
73+
self.variables[var_name] = self._create_instance(value_spec)
74+
75+
def _create_instance(self, spec: Dict[str, Any]) -> Any:
76+
"""Create an instance based on the specification."""
77+
if not isinstance(spec, dict) or len(spec) != 1:
78+
raise ValueError(f"Invalid specification format: {spec}")
79+
80+
class_name, params = next(iter(spec.items()))
81+
82+
# Handle different MSAL classes
83+
if class_name == "ManagedIdentityClient":
84+
return msal.ManagedIdentityClient(http_client=requests.Session(), **params)
85+
elif class_name == "PublicClientApplication":
86+
return self._create_public_client_app(params)
87+
elif class_name == "ConfidentialClientApplication":
88+
return self._create_confidential_client_app(params)
89+
else:
90+
raise ValueError(f"Unsupported class: {class_name}")
91+
92+
def _create_public_client_app(self, params: Dict[str, Any]) -> Any:
93+
"""Create a PublicClientApplication instance."""
94+
if not params or 'client_id' not in params:
95+
raise ValueError("PublicClientApplication requires client_id")
96+
97+
client_id = params.get('client_id')
98+
authority = params.get('authority')
99+
logger.debug(f"Creating PublicClientApplication with client_id: {client_id}, authority: {authority}")
100+
101+
kwargs = {'client_id': client_id}
102+
if authority:
103+
kwargs['authority'] = authority
104+
105+
return msal.PublicClientApplication(**kwargs)
106+
107+
def _create_confidential_client_app(self, params: Dict[str, Any]) -> Any:
108+
"""Create a ConfidentialClientApplication instance."""
109+
if not params or 'client_id' not in params or 'client_credential' not in params:
110+
raise ValueError("ConfidentialClientApplication requires client_id and client_credential")
111+
112+
client_id = params.get('client_id')
113+
client_credential = params.get('client_credential')
114+
authority = params.get('authority')
115+
logger.debug(f"Creating ConfidentialClientApplication with client_id: {client_id}, authority: {authority}")
116+
117+
kwargs = {'client_id': client_id, 'client_credential': client_credential}
118+
if authority:
119+
kwargs['authority'] = authority
120+
121+
return msal.ConfidentialClientApplication(**kwargs)
122+
123+
def execute_steps(self) -> bool:
124+
"""Execute the test steps, returns whether all steps passed."""
125+
steps = self.test_spec.get('steps', [])
126+
passed = 0
127+
for i, step in enumerate(steps):
128+
logger.debug(f"Executing step {i+1}/{len(steps)}")
129+
if 'act' in step:
130+
result = self._execute_action(step['act'])
131+
if 'assert' in step:
132+
if self._validate_assertions(result, step['assert']):
133+
passed += 1
134+
logger.info(f"{passed} of {len(steps)} step(s) passed")
135+
return passed == len(steps)
136+
137+
def _execute_action(self, act_spec: Dict[str, Any]) -> Any:
138+
"""Execute an action based on the specification."""
139+
if not isinstance(act_spec, dict) or len(act_spec) != 1:
140+
raise ValueError(f"Invalid action specification: {act_spec}")
141+
142+
action_str, params = next(iter(act_spec.items()))
143+
144+
# Parse the action string (e.g., "app1.AcquireToken")
145+
parts = action_str.split('.')
146+
if len(parts) != 2:
147+
raise ValueError(f"Invalid action format: {action_str}")
148+
149+
var_name = parts[0]
150+
method_name = { # Map the method names in yml to actual method names
151+
"AcquireTokenForManagedIdentity": "acquire_token_for_client",
152+
}.get(parts[1])
153+
154+
if method_name is None:
155+
raise ValueError(f"Unsupported method: {parts[1]}")
156+
157+
if var_name not in self.variables:
158+
raise ValueError(f"Variable '{var_name}' not found")
159+
160+
instance = self.variables[var_name]
161+
if not hasattr(instance, method_name):
162+
raise ValueError(f"Method '{method_name}' not found on {var_name}")
163+
164+
method = getattr(instance, method_name)
165+
166+
# Convert parameters to kwargs
167+
kwargs = params if params else {}
168+
169+
# Execute the method with parameters
170+
logger.info(f"Calling {var_name}.{method_name} with {kwargs}")
171+
return method(**kwargs)
172+
173+
def _validate_assertions(self, result: Any, assertions: Dict[str, Any]) -> bool:
174+
"""Validate the assertions against the result."""
175+
logger.info(f"Validating assertions: {assertions}")
176+
for key, expected_value in assertions.items():
177+
if key not in result:
178+
logger.error(f"Assertion failed: '{key}' not found in result {result}")
179+
return False # Failed
180+
actual_value = result[key]
181+
if actual_value != expected_value:
182+
logger.error(f"Assertion failed: expected {key}='{expected_value}', got '{actual_value}'")
183+
return False # Failed
184+
else:
185+
logger.debug(f"Assertion passed: {key}='{actual_value}'")
186+
return True # Passed
187+
188+
def run(self) -> bool:
189+
"""Run the entire test, returns whether it passed."""
190+
self.load_feature()
191+
192+
with self.setup_environment():
193+
self.arrange()
194+
result = self.execute_steps()
195+
if result:
196+
logger.info(f"Test case {self.testcase_url} passed")
197+
else:
198+
logger.error(f"Test case {self.testcase_url} failed")
199+
return result
200+
201+
202+
def run_testcases(testcases_url: str) -> bool:
203+
try:
204+
response = requests.get(testcases_url)
205+
response.raise_for_status()
206+
passed = 0
207+
testcases = response.json().get("testcases", [])
208+
for url in testcases:
209+
try:
210+
if SmileTestRunner(url).run():
211+
passed += 1
212+
except Exception as e:
213+
logger.error(f"Test case {url} failed: {e}")
214+
(logger.info if passed == len(testcases) else logger.error)(
215+
f"Passed {passed} of {len(testcases)} test cases"
216+
)
217+
return passed == len(testcases)
218+
except requests.RequestException as e:
219+
logger.error(f"Failed to fetch test cases from {url}: {e}")
220+
return False
221+
222+
def main():
223+
import argparse
224+
parser = argparse.ArgumentParser(description="MSAL Feature Test Runner")
225+
group = parser.add_mutually_exclusive_group(required=True)
226+
group.add_argument("--testcase", help="URL for a single test case")
227+
group.add_argument("--batch", help="URL for a batch of test cases in JSON format")
228+
args = parser.parse_args()
229+
230+
if args.testcase:
231+
logger.setLevel(logging.DEBUG)
232+
success = SmileTestRunner(args.testcase).run()
233+
elif args.batch:
234+
logger.setLevel(logging.INFO)
235+
success = run_testcases(args.batch)
236+
237+
sys.exit(0 if success else 1)
238+
239+
if __name__ == "__main__":
240+
main()

0 commit comments

Comments
 (0)