Commit b543a32

Improve API for stability

1 parent: 9a80de0

5 files changed: +57 -44 lines

daswow/daswow_model.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -12,6 +12,7 @@
 SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
 MODELS_PATH = os.path.join(SCRIPT_DIR, "models")
 
+download_models_from_github_release()
 
 class Preprocessing:
     # init. set dataframe to be processed
@@ -90,8 +91,6 @@ def __init__(self, nb_path, models_path=MODELS_PATH):
         cf = CellFeatures()
         self.df = cf.get_cell_features_nb(nb_path)
 
-        download_models_from_github_release()
-
         self.preprocesser = Preprocessing(self.df)
         self.model = joblib.load(f"{models_path}/rf_code_scaled.pkl")
         self.tfidf = joblib.load(f"{models_path}/tfidf_vectorizer.pkl")
```
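
Note: moving download_models_from_github_release() from the predictor's __init__ to module level means the download check runs once, at first import, rather than on every object construction. A minimal sketch of that semantics (package path as in the diff; assumes the package is importable):

```python
# Module-level statements execute once per process, on first import only.
import daswow.daswow_model  # triggers download_models_from_github_release()
import daswow.daswow_model  # already cached in sys.modules; nothing re-runs
```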

daswow/model_download.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -25,9 +25,13 @@ def download_models_from_github_release(repo_owner="secure-software-engineering"
     os.makedirs(download_path)
 
     # check if files already exist and remove from the list
+    assets_to_download = []
     for asset_name in asset_names:
-        if os.path.exists(os.path.join(download_path, asset_name)):
-            asset_names.remove(asset_name)
+        if not os.path.exists(os.path.join(download_path, asset_name)):
+            assets_to_download.append(asset_name)
+
+    if not assets_to_download:
+        return "Models already exist"
 
     # API endpoint to get release info
     url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"
@@ -41,7 +45,7 @@ def download_models_from_github_release(repo_owner="secure-software-engineering"
     for asset in release_data['assets']:
         print(asset['name']) # Add this line
 
-    for asset_name in asset_names:
+    for asset_name in assets_to_download:
        # Find the download URL of the asset
        asset_url = None
        for asset in release_data['assets']:
@@ -53,6 +57,7 @@ def download_models_from_github_release(repo_owner="secure-software-engineering"
            raise ValueError(f"Asset '{asset_name}' not found in the release.")
 
        # Download the file
+       print(f"Downloading model file: {asset['name']}")
        response = requests.get(asset_url, stream=True)
        response.raise_for_status()
```
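
Note: the substantive fix here is the loop. The old code called asset_names.remove() on the very list it was iterating; removing an item shifts the rest left while the loop index still advances, so the element after each removal is skipped and an already-cached file could still be re-downloaded. A minimal reproduction and the corrected pattern (file names taken from the diff; assume both files are already cached):

```python
import os

download_path = "models"  # assumed local cache directory
asset_names = ["rf_code_scaled.pkl", "tfidf_vectorizer.pkl"]

# Old approach: mutating the list during iteration skips elements.
names = list(asset_names)
for name in names:
    names.remove(name)  # the loop index jumps past the shifted item
print(names)  # ['tfidf_vectorizer.pkl'] -- survived the filter by accident

# New approach, as in the diff: collect the missing files into a fresh list.
assets_to_download = [
    name for name in asset_names
    if not os.path.exists(os.path.join(download_path, name))
]
```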

headergen/server.py

Lines changed: 46 additions & 37 deletions

```diff
@@ -8,7 +8,7 @@
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.responses import JSONResponse
 
-from framework_models import get_high_level_phase
+from framework_models import get_high_level_phase, DASWOW_PHASES
 from headergen import headergen
 
 app = FastAPI()
```
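
Note: DASWOW_PHASES is used below as a plain dict lookup from low-level DASWOW tags to high-level phase names, with "Unknown" as the fallback. A hypothetical sketch of its shape (the real entries live in framework_models and are not part of this commit; keys and values here are invented):

```python
# Hypothetical stand-in for framework_models.DASWOW_PHASES.
DASWOW_PHASES = {
    "load_data": "Data Collection",   # invented entry
    "train_model": "Model Training",  # invented entry
}

# Unrecognized tags degrade to "Unknown" instead of raising KeyError:
tags = ["train_model", "never_seen_tag"]
ml_phases = sorted({DASWOW_PHASES.get(tag, "Unknown") for tag in tags})
# -> ['Model Training', 'Unknown']
```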

```diff
@@ -38,44 +38,53 @@
 @app.post("/get_analysis_notebook/")
 async def get_analysis(file: UploadFile = File(...)):
     """Upload a notebook file, analyze chunks of it, and add metadata."""
-    # Save the uploaded file to the uploads directory
-    file_location = f"{UPLOAD_DIR}/{file.filename}"
-
-    async with aiofiles.open(file_location, "wb") as f:
-        content = await file.read()
-        await f.write(content)
-
-    # Load the notebook
-    async with aiofiles.open(file_location, "r", encoding="utf-8") as file:
-        notebook_content = await file.read()
-    notebook = nbformat.reads(notebook_content, as_version=4)
-
-    # Perform analysis on the uploaded notebook
     try:
-        analysis_meta = headergen.start_headergen(
-            file_location, OUTPUT_DIR, debug_mode=True
-        )
+        # Save the uploaded file to the uploads directory
+        file_location = f"{UPLOAD_DIR}/{file.filename}"
+
+        async with aiofiles.open(file_location, "wb") as f:
+            content = await file.read()
+            await f.write(content)
+
+        # Load the notebook
+        async with aiofiles.open(file_location, "r", encoding="utf-8") as file:
+            notebook_content = await file.read()
+        notebook = nbformat.reads(notebook_content, as_version=4)
+
+        # Perform analysis on the uploaded notebook
+        try:
+            analysis_meta = headergen.start_headergen(
+                file_location, OUTPUT_DIR, debug_mode=True
+            )
+        except Exception as e:
+            logger.error(f"Analysis failed: {str(e)}")
+            raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
+
+        # Prepare the analysis output in a chunked dictionary, mapping analysis to cells
+        analysis_output = {"cell_mapping": {}}
+
+        if "block_mapping" in analysis_meta:
+            for cell_index, cell_results in analysis_meta["block_mapping"].items():
+                # Get high-level phases and convert set to list
+                ml_phases = list(set([DASWOW_PHASES.get(tag, "Unknown") for tag in cell_results["dl_pipeline_tag"]]))
+                func_list = {k:{"doc_string":v, "arguments":[]} for k,v in cell_results.get("doc_string", {}).items()}
+
+                for call_args in cell_results["call_args"].values():
+                    for call, args in call_args.items():
+                        if call in func_list:
+                            func_list[call]["arguments"].append(args)
+
+                # Add to the chunked dictionary without modifying the content
+                analysis_output["cell_mapping"][cell_index] = {
+                    "ml_phase": ml_phases,  # Ensure ml_phases is a list, not a set
+                    "functions": func_list,
+                }
+
+        # Return the chunked analysis output without overwriting notebook content
+        return JSONResponse(content=analysis_output)
     except Exception as e:
-        logger.error(f"Analysis failed: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
-
-    # Prepare the analysis output in a chunked dictionary, mapping analysis to cells
-    analysis_output = {"cell_mapping": {}}
-
-    if "block_mapping" in analysis_meta:
-        for cell_index, cell_results in analysis_meta["block_mapping"].items():
-            # Get high-level phases and convert set to list
-            ml_phases = list(set(cell_results["dl_pipeline_tag"]))
-            func_list = cell_results.get("doc_string", "")
-
-            # Add to the chunked dictionary without modifying the content
-            analysis_output["cell_mapping"][cell_index] = {
-                "ml_phase": ml_phases,  # Ensure ml_phases is a list, not a set
-                "functions": func_list,
-            }
-
-    # Return the chunked analysis output without overwriting notebook content
-    return JSONResponse(content=analysis_output)
+        return JSONResponse(content={"error": str(e)}, status_code=500)
+
 
 
 if __name__ == "__main__":
```
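
Note: the whole handler body is now wrapped in a try/except that returns a structured JSON error instead of letting exceptions surface as an unhandled 500, and each cell's entry now carries both its high-level phases and the arguments observed for each documented call. A hypothetical client round-trip (host, port, and the response values are invented for illustration):

```python
# Hypothetical client call; host/port and response contents are assumptions.
import requests

with open("notebook.ipynb", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/get_analysis_notebook/",
        files={"file": ("notebook.ipynb", fh)},
    )

print(resp.json())
# Illustrative shape of the payload on success:
# {"cell_mapping": {"3": {
#     "ml_phase": ["Model Training"],
#     "functions": {"keras.Model.fit": {
#         "doc_string": "...",
#         "arguments": [{"epochs": "10"}]}}}}}
# On failure: {"error": "<message>"} with HTTP status 500.
```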

scripts/simple_example.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@
 script_dir = os.path.abspath(os.path.dirname(__file__))
 
 # Careful, the out_path folder will be removed
-file_path = f"/mnt/Projects/PhD/Research/HeaderGen/git_sources/HeaderGen_github/.scrapy/test/test.py"
+file_path = f"/mnt/Projects/PhD/Research/HeaderGen/git_sources/headergen_githib/.scrapy/notebooks/01-keras-deep-learning-to-solve-titanic.ipynb"
 out_path = f"{script_dir}/results/"
```
setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -34,7 +34,7 @@ def package_files(directory):
 
 setuptools.setup(
     name="headergen",
-    version="2.0.0",
+    version="2.0.1",
     description="HeaderGen: Automated cell header generator",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
```
