diff --git a/app/user_manager.py b/app/user_manager.py index 62c22cde5ed..260c383b498 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -121,21 +121,38 @@ async def listuserdata(request): directory = request.rel_url.query.get('dir', '') if not directory: return web.Response(status=400) - + path = self.get_request_user_filepath(request, directory) if not path: return web.Response(status=403) - + if not os.path.exists(path): return web.Response(status=404) - + recurse = request.rel_url.query.get('recurse', '').lower() == "true" - results = glob.glob(os.path.join( - glob.escape(path), '**/*'), recursive=recurse) - results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)] - + full_info = request.rel_url.query.get('full_info', '').lower() == "true" + + # Use different patterns based on whether we're recursing or not + if recurse: + pattern = os.path.join(glob.escape(path), '**', '*') + else: + pattern = os.path.join(glob.escape(path), '*') + + results = glob.glob(pattern, recursive=recurse) + + if full_info: + results = [ + { + 'path': os.path.relpath(x, path), + 'size': os.path.getsize(x), + 'modified': os.path.getmtime(x) + } for x in results if os.path.isfile(x) + ] + else: + results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)] + split_path = request.rel_url.query.get('split', '').lower() == "true" - if split_path: + if split_path and not full_info: results = [[x] + x.split(os.sep) for x in results] return web.json_response(results) diff --git a/tests-unit/prompt_server_test/user_manager_test.py b/tests-unit/prompt_server_test/user_manager_test.py new file mode 100644 index 00000000000..c71050a2f4d --- /dev/null +++ b/tests-unit/prompt_server_test/user_manager_test.py @@ -0,0 +1,90 @@ +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager + +pytestmark = ( + pytest.mark.asyncio +) # This applies the asyncio mark to all test functions in the module + + +@pytest.fixture +def user_manager(tmp_path): + um = UserManager() + um.get_request_user_filepath = lambda req, file, **kwargs: os.path.join( + tmp_path, file + ) + return um + + +@pytest.fixture +def app(user_manager): + app = web.Application() + routes = web.RouteTableDef() + user_manager.add_routes(routes) + app.add_routes(routes) + return app + + +async def test_listuserdata_empty_directory(aiohttp_client, app, tmp_path): + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir") + assert resp.status == 404 + + +async def test_listuserdata_with_files(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir") + assert resp.status == 200 + assert await resp.json() == ["file1.txt"] + + +async def test_listuserdata_recursive(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + with open(tmp_path / "test_dir" / "subdir" / "file2.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true") + assert resp.status == 200 + assert set(await resp.json()) == {"file1.txt", os.path.join("subdir", "file2.txt")} + + +async def test_listuserdata_full_info(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir") + with open(tmp_path / "test_dir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&full_info=true") + assert resp.status == 200 + result = await resp.json() + assert len(result) == 1 + assert result[0]["path"] == "file1.txt" + assert "size" in result[0] + assert "modified" in result[0] + + +async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): + os.makedirs(tmp_path / "test_dir" / "subdir") + with open(tmp_path / "test_dir" / "subdir" / "file1.txt", "w") as f: + f.write("test content") + + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=test_dir&recurse=true&split=true") + assert resp.status == 200 + assert await resp.json() == [ + [os.path.join("subdir", "file1.txt"), "subdir", "file1.txt"] + ] + + +async def test_listuserdata_invalid_directory(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/userdata?dir=") + assert resp.status == 400