Merged
63 changes: 59 additions & 4 deletions src/sasctl/pzmm/write_json_files.py
@@ -1614,6 +1614,7 @@ def create_requirements_json(
cls,
model_path: Union[str, Path, None] = Path.cwd(),
output_path: Union[str, Path, None] = None,
create_requirements_txt: bool = False,
) -> Union[dict, None]:
"""
Searches the model directory for Python scripts and pickle files and
@@ -1636,14 +1637,22 @@ def create_requirements_json(
environment.

        When provided with an output_path argument, this function outputs a JSON file
-       named "requirements.json". Otherwise, a list of dicts is returned.
+       named "requirements.json". If create_requirements_txt is True, it will also
+       create a requirements.txt file. Otherwise, a list of dicts is returned.

        Note: A requirements.txt file is only created when both output_path and
        create_requirements_txt are specified.

Parameters
----------
model_path : str or pathlib.Path, optional
The path to a Python project, by default the current working directory.
output_path : str or pathlib.Path, optional
The path for the output requirements.json file. The default value is None.
create_requirements_txt : bool, optional
Whether to also create a requirements.txt file in addition to the
requirements.json file. This is useful for SAS Event Stream Processing
environments. The default value is False.

Returns
-------
@@ -1662,11 +1671,57 @@ def create_requirements_json(
package_list = list(set(list(_flatten(package_list))))
package_list = cls.remove_standard_library_packages(package_list)
package_and_version = cls.get_local_package_version(package_list)

# Identify packages with missing versions
missing_package_versions = [
item[0] for item in package_and_version if not item[1]
]
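        # For example (hypothetical data): a package_and_version of
        # [("numpy", "1.26.4"), ("my_module", None)] yields ["my_module"].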


IMPORT_TO_INSTALL_MAPPING = {
Collaborator Author: Could the addition of this import to install mapping create some issues with backwards compatibility? I noticed that in the example files it directs users to manually change the import names for packages like sklearn.

Collaborator: No, I think we're good here, since it was a non-programmatic ask. If they go hunting to change something and find it's fixed, I doubt we'll get complaints.

# Data Science & ML Core
"sklearn": "scikit-learn",
"skimage": "scikit-image",
"cv2": "opencv-python",
"PIL": "Pillow",
# Data Formats & Parsing
"yaml": "PyYAML",
"bs4": "beautifulsoup4",
"docx": "python-docx",
"pptx": "python-pptx",
# Date & Time Utilities
"dateutil": "python-dateutil",
# Database Connectors
"MySQLdb": "MySQL-python",
"psycopg2": "psycopg2-binary",
# System & Platform
"win32api": "pywin32",
"win32com": "pywin32",
# Scientific Libraries
"Bio": "biopython",
}

# Map import names to their corresponding package installation names
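        # For example, ("sklearn", "1.3.0") becomes ("scikit-learn", "1.3.0"); names
        # with no mapping entry pass through unchanged (the version is illustrative).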
        package_and_version = [
            (IMPORT_TO_INSTALL_MAPPING.get(name, name), version)
            for name, version in package_and_version
        ]

if create_requirements_txt:
requirements_txt = ""
if missing_package_versions:
requirements_txt += "# Warning- The existence and/or versions for the following packages could not be determined:\n"
requirements_txt += "# " + ", ".join(missing_package_versions) + "\n"

for package, version in package_and_version:
if version:
requirements_txt += f"{package}=={version}\n"

if output_path:
with open( # skipcq: PTC-W6004
Path(output_path) / "requirements.txt", "w"
) as file:
file.write(requirements_txt)
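            # Illustrative requirements.txt contents (hypothetical packages and
            # versions, assuming one package whose version could not be determined):
            #
            #   # Warning: The existence and/or versions for the following packages could not be determined:
            #   # my_local_module
            #   scikit-learn==1.3.0
            #   pandas==2.1.4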

# Create a list of dicts related to each package or warning
json_dicts = []
if missing_package_versions:
@@ -1800,16 +1855,16 @@ def find_imports(file_path: Union[str, Path]) -> List[str]:
file_text = file.read()
# Parse the file to get the abstract syntax tree representation
tree = ast.parse(file_text)
-       modules = []
+       modules = set()

# Walk through each node in the ast to find import calls
for node in ast.walk(tree):
# Determine parent module for `from * import *` calls
if isinstance(node, ast.ImportFrom):
-           modules.append(node.module)
+           modules.add(node.module.split(".")[0])
elif isinstance(node, ast.Import):
for name in node.names:
-               modules.append(name.name)
+               modules.add(name.name.split(".")[0])
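                # Only the top-level package name is kept: an import like
                # sklearn.tree is recorded as sklearn, matching the name used
                # to look up the installed version.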

modules = list(set(modules))
try:
2 changes: 1 addition & 1 deletion tests/unit/test_write_json_files.py
@@ -699,7 +699,7 @@ def test_create_requirements_json(change_dir):
dtc = dtc.fit(x_train, y_train)
with open(tmp_dir / "DecisionTreeClassifier.pickle", "wb") as pkl_file:
pickle.dump(dtc, pkl_file)
-   jf.create_requirements_json(tmp_dir, Path(tmp_dir))
+   jf.create_requirements_json(tmp_dir, Path(tmp_dir), True)
assert (Path(tmp_dir) / "requirements.json").exists()

json_dict = jf.create_requirements_json(tmp_dir)