|
8 | 8 | from fastapi.middleware.gzip import GZipMiddleware
|
9 | 9 | from fastapi.responses import JSONResponse
|
10 | 10 |
|
11 |
| -from framework_models import get_high_level_phase |
| 11 | +from framework_models import get_high_level_phase, DASWOW_PHASES |
12 | 12 | from headergen import headergen
|
13 | 13 |
|
14 | 14 | app = FastAPI()
|
|
38 | 38 | @app.post("/get_analysis_notebook/")
|
39 | 39 | async def get_analysis(file: UploadFile = File(...)):
|
40 | 40 | """Upload a notebook file, analyze chunks of it, and add metadata."""
|
41 |
| - # Save the uploaded file to the uploads directory |
42 |
| - file_location = f"{UPLOAD_DIR}/{file.filename}" |
43 |
| - |
44 |
| - async with aiofiles.open(file_location, "wb") as f: |
45 |
| - content = await file.read() |
46 |
| - await f.write(content) |
47 |
| - |
48 |
| - # Load the notebook |
49 |
| - async with aiofiles.open(file_location, "r", encoding="utf-8") as file: |
50 |
| - notebook_content = await file.read() |
51 |
| - notebook = nbformat.reads(notebook_content, as_version=4) |
52 |
| - |
53 |
| - # Perform analysis on the uploaded notebook |
54 | 41 | try:
|
55 |
| - analysis_meta = headergen.start_headergen( |
56 |
| - file_location, OUTPUT_DIR, debug_mode=True |
57 |
| - ) |
| 42 | + # Save the uploaded file to the uploads directory |
| 43 | + file_location = f"{UPLOAD_DIR}/{file.filename}" |
| 44 | + |
| 45 | + async with aiofiles.open(file_location, "wb") as f: |
| 46 | + content = await file.read() |
| 47 | + await f.write(content) |
| 48 | + |
| 49 | + # Load the notebook |
| 50 | + async with aiofiles.open(file_location, "r", encoding="utf-8") as file: |
| 51 | + notebook_content = await file.read() |
| 52 | + notebook = nbformat.reads(notebook_content, as_version=4) |
| 53 | + |
| 54 | + # Perform analysis on the uploaded notebook |
| 55 | + try: |
| 56 | + analysis_meta = headergen.start_headergen( |
| 57 | + file_location, OUTPUT_DIR, debug_mode=True |
| 58 | + ) |
| 59 | + except Exception as e: |
| 60 | + logger.error(f"Analysis failed: {str(e)}") |
| 61 | + raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") |
| 62 | + |
| 63 | + # Prepare the analysis output in a chunked dictionary, mapping analysis to cells |
| 64 | + analysis_output = {"cell_mapping": {}} |
| 65 | + |
| 66 | + if "block_mapping" in analysis_meta: |
| 67 | + for cell_index, cell_results in analysis_meta["block_mapping"].items(): |
| 68 | + # Get high-level phases and convert set to list |
| 69 | + ml_phases = list(set([DASWOW_PHASES.get(tag, "Unknown") for tag in cell_results["dl_pipeline_tag"]])) |
| 70 | + func_list = {k:{"doc_string":v, "arguments":[]} for k,v in cell_results.get("doc_string", {}).items()} |
| 71 | + |
| 72 | + for call_args in cell_results["call_args"].values(): |
| 73 | + for call, args in call_args.items(): |
| 74 | + if call in func_list: |
| 75 | + func_list[call]["arguments"].append(args) |
| 76 | + |
| 77 | + # Add to the chunked dictionary without modifying the content |
| 78 | + analysis_output["cell_mapping"][cell_index] = { |
| 79 | + "ml_phase": ml_phases, # Ensure ml_phases is a list, not a set |
| 80 | + "functions": func_list, |
| 81 | + } |
| 82 | + |
| 83 | + # Return the chunked analysis output without overwriting notebook content |
| 84 | + return JSONResponse(content=analysis_output) |
58 | 85 | except Exception as e:
|
59 |
| - logger.error(f"Analysis failed: {str(e)}") |
60 |
| - raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") |
61 |
| - |
62 |
| - # Prepare the analysis output in a chunked dictionary, mapping analysis to cells |
63 |
| - analysis_output = {"cell_mapping": {}} |
64 |
| - |
65 |
| - if "block_mapping" in analysis_meta: |
66 |
| - for cell_index, cell_results in analysis_meta["block_mapping"].items(): |
67 |
| - # Get high-level phases and convert set to list |
68 |
| - ml_phases = list(set(cell_results["dl_pipeline_tag"])) |
69 |
| - func_list = cell_results.get("doc_string", "") |
70 |
| - |
71 |
| - # Add to the chunked dictionary without modifying the content |
72 |
| - analysis_output["cell_mapping"][cell_index] = { |
73 |
| - "ml_phase": ml_phases, # Ensure ml_phases is a list, not a set |
74 |
| - "functions": func_list, |
75 |
| - } |
76 |
| - |
77 |
| - # Return the chunked analysis output without overwriting notebook content |
78 |
| - return JSONResponse(content=analysis_output) |
| 86 | + return JSONResponse(content={"error": str(e)}, status_code=500) |
| 87 | + |
79 | 88 |
|
80 | 89 |
|
81 | 90 | if __name__ == "__main__":
|
|
0 commit comments