From cfb5a146942ea6801ffb1037039ec7c1301276d4 Mon Sep 17 00:00:00 2001 From: Julian Martinsson Bonde Date: Mon, 12 Jun 2023 09:48:37 +0200 Subject: [PATCH] Version 1.0.4 (#87) * Bumped version to 1.0.4 * File access control (#86) * added subproject requirement to uploaded files * Adjusted file saving mechanism * Added comment for clarity * Implemented FileAccessControl based on subproject access * Improved robustness of MySQLStatementBuilder * Fixed static acces_level issue in FileAccessChecker * Mysql statement builder refactor (#88) * Removed mysqlutils and instead imported mysql-statement-builder package * Added mysql-statement-builder to requirements.txt * Reconnected mysqlsb logger with fastapi logger * Enable sorting user listings, and searching for users (#89) * Search for user functionality * Safe generic order-by for user lists * user sorting and search tests Added tests for new functionality * fixed failing cvs simulation test * Projects administration end-points for SED portal (#91) * Updated subprojects and projects data structures * Added min_length constraint to sub-project and project names * Also list project creation date and participant count * Project lists entries are now properly populated * Functionality to avoid accidental project overwrites * Fixed test fail and bumped up desim version (#90) * Fixed test that fails * desim-tool version 0.3.3 * replaced generic exception --------- Co-authored-by: EppChops Co-authored-by: Erik Berg <57296415+EppChops@users.noreply.github.com> * 83 Endpoint for deleting files (#95) * Functionality for deleting files * reset utils to previous state * 83 Update project details (#96) * Partial implementation of necessary methods * Post multiple participants at the same time * Fixed incorrect project access checker instantiation * Tests * Fixed broken test request for project update * Projects can now be updated * Improved log message --------- Co-authored-by: Oscar Bennet Co-authored-by: EppChops Co-authored-by: Erik Berg <57296415+EppChops@users.noreply.github.com> --- requirements.txt | 3 +- .../apps/core/authentication/storage.py | 2 +- sedbackend/apps/core/files/dependencies.py | 26 ++ sedbackend/apps/core/files/exceptions.py | 12 + sedbackend/apps/core/files/implementation.py | 22 ++ sedbackend/apps/core/files/models.py | 8 +- sedbackend/apps/core/files/router.py | 19 +- sedbackend/apps/core/files/storage.py | 74 ++++- sedbackend/apps/core/individuals/storage.py | 2 +- .../apps/core/measurements/implementation.py | 4 +- sedbackend/apps/core/measurements/router.py | 4 +- sedbackend/apps/core/measurements/storage.py | 2 +- sedbackend/apps/core/projects/dependencies.py | 14 +- sedbackend/apps/core/projects/exceptions.py | 4 + .../apps/core/projects/implementation.py | 81 ++++- sedbackend/apps/core/projects/models.py | 21 +- sedbackend/apps/core/projects/router.py | 30 +- sedbackend/apps/core/projects/storage.py | 278 +++++++++++++++--- sedbackend/apps/core/users/implementation.py | 51 +++- sedbackend/apps/core/users/router.py | 17 +- sedbackend/apps/core/users/storage.py | 50 +++- sedbackend/apps/cvs/design/storage.py | 2 +- sedbackend/apps/cvs/life_cycle/storage.py | 2 +- .../apps/cvs/link_design_lifecycle/storage.py | 2 +- sedbackend/apps/cvs/market_input/storage.py | 2 +- sedbackend/apps/cvs/project/storage.py | 2 +- sedbackend/apps/cvs/simulation/exceptions.py | 6 +- sedbackend/apps/cvs/simulation/storage.py | 2 +- sedbackend/apps/cvs/vcs/storage.py | 2 +- sedbackend/libs/mysqlutils/__init__.py | 3 - sedbackend/libs/mysqlutils/builder.py | 197 ------------- sedbackend/libs/mysqlutils/statements.py | 79 ----- sedbackend/libs/mysqlutils/utils.py | 19 -- sedbackend/main.py | 2 +- sedbackend/setup.py | 4 + sql/V230522_release_1_0_4.sql | 30 ++ tests/apps/core/files/test_files.py | 142 +++++++++ tests/apps/core/files/testutils.py | 8 + tests/apps/core/projects/test_projects.py | 108 ++++++- tests/apps/core/users/test_users.py | 92 +++++- tests/apps/core/users/testutils.py | 11 +- .../simulation/test_sim_multiprocessing.py | 8 +- tests/apps/cvs/simulation/test_simulation.py | 2 +- tests/apps/cvs/simulation/testutils.py | 2 +- tests/apps/cvs/testutils.py | 9 +- 45 files changed, 1045 insertions(+), 415 deletions(-) create mode 100644 sedbackend/apps/core/files/dependencies.py delete mode 100644 sedbackend/libs/mysqlutils/__init__.py delete mode 100644 sedbackend/libs/mysqlutils/builder.py delete mode 100644 sedbackend/libs/mysqlutils/statements.py delete mode 100644 sedbackend/libs/mysqlutils/utils.py create mode 100644 sql/V230522_release_1_0_4.sql create mode 100644 tests/apps/core/files/test_files.py create mode 100644 tests/apps/core/files/testutils.py diff --git a/requirements.txt b/requirements.txt index 70b5087d..35a83637 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ bcrypt==4.0.1 -desim-tool==0.3.1 +desim-tool==0.3.3 fastapi==0.95.1 mvmlib==0.5.9 mysql-connector-python==8.0.33 @@ -11,6 +11,7 @@ python-multipart==0.0.6 starlette==0.26.1 uvicorn==0.21.1 openpyxl==3.1.2 +mysql-statement-builder==0.* pytest==7.3.1 httpx==0.24.0 \ No newline at end of file diff --git a/sedbackend/apps/core/authentication/storage.py b/sedbackend/apps/core/authentication/storage.py index 4da6d0ac..57888662 100644 --- a/sedbackend/apps/core/authentication/storage.py +++ b/sedbackend/apps/core/authentication/storage.py @@ -6,7 +6,7 @@ from sedbackend.apps.core.authentication.models import UserAuth, SSOResolutionData from sedbackend.apps.core.users.exceptions import UserNotFoundException from sedbackend.apps.core.authentication.exceptions import InvalidNonceException, FaultyNonceOperation -from sedbackend.libs.mysqlutils.builder import MySQLStatementBuilder, FetchType +from mysqlsb.builder import MySQLStatementBuilder, FetchType from mysql.connector.pooling import PooledMySQLConnection diff --git a/sedbackend/apps/core/files/dependencies.py b/sedbackend/apps/core/files/dependencies.py new file mode 100644 index 00000000..2642c5d7 --- /dev/null +++ b/sedbackend/apps/core/files/dependencies.py @@ -0,0 +1,26 @@ +from typing import List + +from fastapi import Request +from fastapi.logger import logger + +from sedbackend.apps.core.projects.dependencies import SubProjectAccessChecker +from sedbackend.apps.core.projects.models import AccessLevel +from sedbackend.apps.core.projects.implementation import impl_get_subproject_by_id +from sedbackend.apps.core.files.implementation import impl_get_file_mapped_subproject_id + + +class FileAccessChecker: + def __init__(self, allowed_levels: List[AccessLevel]): + self.access_levels = allowed_levels + + def __call__(self, file_id: int, request: Request): + logger.debug(f'Is user with id {request.state.user_id} ' + f'allowed to access file with id {file_id}?') + user_id = request.state.user_id + + # Get subproject ID + subproject_id = impl_get_file_mapped_subproject_id(file_id) + + # Run subproject access check + subproject = impl_get_subproject_by_id(subproject_id) + return SubProjectAccessChecker.check_user_subproject_access(subproject, self.access_levels, user_id) diff --git a/sedbackend/apps/core/files/exceptions.py b/sedbackend/apps/core/files/exceptions.py index 0aa2f6a1..ef305bc1 100644 --- a/sedbackend/apps/core/files/exceptions.py +++ b/sedbackend/apps/core/files/exceptions.py @@ -8,3 +8,15 @@ class FileNotFoundException(Exception): class FileParsingException(Exception): pass + + +class SubprojectMappingNotFound(Exception): + pass + + +class FileNotDeletedException(Exception): + pass + + +class PathMismatchException(Exception): + pass diff --git a/sedbackend/apps/core/files/implementation.py b/sedbackend/apps/core/files/implementation.py index 62a18aa5..ba1f2f0f 100644 --- a/sedbackend/apps/core/files/implementation.py +++ b/sedbackend/apps/core/files/implementation.py @@ -39,6 +39,16 @@ def impl_delete_file(file_id: int, current_user_id: int) -> bool: status_code=status.HTTP_403_FORBIDDEN, detail=f"User does not have access to a file with id = {file_id}" ) + except exc.FileNotDeletedException: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"File could not be deleted" + ) + except exc.PathMismatchException: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f'Path to file does not match internal path' + ) def impl_get_file_path(file_id: int, current_user_id: int) -> models.StoredFilePath: @@ -102,3 +112,15 @@ def impl_get_file(file_id: int, current_user_id: int): status_code=status.HTTP_403_FORBIDDEN, detail="User does not have access to requested file." ) + + +def impl_get_file_mapped_subproject_id(file_id): + try: + with get_connection() as con: + subproject_id = storage.db_get_file_mapped_subproject_id(con, file_id) + return subproject_id + except exc.SubprojectMappingNotFound: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"No subproject mapping found for file with id = {file_id}" + ) diff --git a/sedbackend/apps/core/files/models.py b/sedbackend/apps/core/files/models.py index 4c2cd9fb..88536d2c 100644 --- a/sedbackend/apps/core/files/models.py +++ b/sedbackend/apps/core/files/models.py @@ -1,7 +1,6 @@ from typing import Any from datetime import datetime import os -from tempfile import SpooledTemporaryFile from pydantic import BaseModel from fastapi.datastructures import UploadFile @@ -12,15 +11,17 @@ class StoredFilePost(BaseModel): owner_id: int extension: str file_object: Any + subproject_id: int @staticmethod - def import_fastapi_file(file: UploadFile, current_user_id: int): + def import_fastapi_file(file: UploadFile, current_user_id: int, subproject_id: int): filename = file.filename extension = os.path.splitext(file.filename)[1] return StoredFilePost(filename=filename, extension=extension, owner_id=current_user_id, - file_object=file.file) + file_object=file.file, + subproject_id=subproject_id) class StoredFileEntry(BaseModel): @@ -30,6 +31,7 @@ class StoredFileEntry(BaseModel): insert_timestamp: datetime owner_id: int extension: str + subproject_id: int class StoredFile(BaseModel): diff --git a/sedbackend/apps/core/files/router.py b/sedbackend/apps/core/files/router.py index 68910132..7c4479d9 100644 --- a/sedbackend/apps/core/files/router.py +++ b/sedbackend/apps/core/files/router.py @@ -2,7 +2,9 @@ from fastapi.responses import FileResponse import sedbackend.apps.core.files.implementation as impl +from sedbackend.apps.core.files.dependencies import FileAccessChecker from sedbackend.apps.core.authentication.utils import get_current_active_user +from sedbackend.apps.core.projects.models import AccessLevel from sedbackend.apps.core.users.models import User @@ -11,7 +13,9 @@ @router.get("/{file_id}/download", summary="Download file", - response_class=FileResponse) + response_class=FileResponse, + dependencies=[Depends(FileAccessChecker(AccessLevel.list_can_read()))] + ) async def get_file(file_id: int, current_user: User = Depends(get_current_active_user)): """ Download an uploaded file @@ -22,3 +26,16 @@ async def get_file(file_id: int, current_user: User = Depends(get_current_active filename=stored_file_path.filename ) return resp + + +@router.delete("/{file_id}/delete", + summary="Delete file", + response_model=bool, + dependencies=[Depends(FileAccessChecker(AccessLevel.list_are_admins()))]) +async def delete_file(file_id: int, current_user: User = Depends(get_current_active_user)): + """ + Delete a file. + Only accessible to admins and the owner of the file. + """ + return impl.impl_delete_file(file_id, current_user.id) + \ No newline at end of file diff --git a/sedbackend/apps/core/files/storage.py b/sedbackend/apps/core/files/storage.py index 75e3e873..3e44d957 100644 --- a/sedbackend/apps/core/files/storage.py +++ b/sedbackend/apps/core/files/storage.py @@ -3,14 +3,19 @@ import os from mysql.connector.pooling import PooledMySQLConnection +from fastapi.logger import logger +import os import sedbackend.apps.core.files.models as models import sedbackend.apps.core.files.exceptions as exc -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, exclude_cols, FetchType +import sedbackend.apps.core.files.implementation as impl +from mysqlsb import MySQLStatementBuilder, exclude_cols, FetchType FILES_RELATIVE_UPLOAD_DIR = f'{os.path.abspath(os.sep)}sed_lab/uploaded_files/' FILES_TABLE = 'files' +FILES_TO_SUBPROJECTS_MAP_TABLE = 'files_subprojects_map' FILES_COLUMNS = ['id', 'temp', 'uuid', 'filename', 'insert_timestamp', 'directory', 'owner_id', 'extension'] +FILES_TO_SUBPROJECTS_MAP_COLUMNS = ['id', 'file_id', 'subproject_id'] def db_save_file(con: PooledMySQLConnection, file: models.StoredFilePost) -> models.StoredFileEntry: @@ -28,30 +33,66 @@ def db_save_file(con: PooledMySQLConnection, file: models.StoredFilePost) -> mod file_id = insert_stmnt.last_insert_id + # Store mapping between file id and subproject id in database + insert_mapping_stmnt = MySQLStatementBuilder(con) + insert_mapping_stmnt.insert(FILES_TO_SUBPROJECTS_MAP_TABLE, ['file_id', 'subproject_id'])\ + .set_values([file_id, file.subproject_id])\ + .execute() + return db_get_file_entry(con, file_id, file.owner_id) def db_delete_file(con: PooledMySQLConnection, file_id: int, current_user_id: int) -> bool: + stored_file_path = impl.impl_get_file_path(file_id, current_user_id) + + if os.path.commonpath([FILES_RELATIVE_UPLOAD_DIR]) != os.path.commonpath([FILES_RELATIVE_UPLOAD_DIR, os.path.abspath(stored_file_path.path)]): + raise exc.PathMismatchException + + try: + os.remove(stored_file_path.path) + delete_stmnt = MySQLStatementBuilder(con) + delete_stmnt.delete(FILES_TABLE) \ + .where('id=?', [file_id]) \ + .execute(fetch_type=FetchType.FETCH_NONE) + + except Exception: + raise exc.FileNotDeletedException + return True def db_get_file_entry(con: PooledMySQLConnection, file_id: int, current_user_id: int) -> models.StoredFileEntry: - select_stmnt = MySQLStatementBuilder(con) - res = select_stmnt.select(FILES_TABLE, exclude_cols(FILES_COLUMNS, ['uuid', 'directory']))\ - .where('id = ?', [file_id])\ - .execute(dictionary=True, fetch_type=FetchType.FETCH_ONE) + res_dict = None + with con.cursor(prepared=True) as cursor: + # This expression uses two tables (files and files_to_subprojects_map) + query = f"SELECT {', '.join(['f.id', 'f.temp', 'f.uuid', 'f.filename', 'f.insert_timestamp', 'f.directory', 'f.owner_id', 'f.extension'])}, fsm.`subproject_id` " \ + f"FROM `{FILES_TABLE}` f " \ + f"INNER JOIN {FILES_TO_SUBPROJECTS_MAP_TABLE} fsm ON (f.id = fsm.file_id) " \ + f"WHERE f.`id` = ?" + values = [file_id] - if res is None: - raise exc.FileNotFoundException + # Log for sanity-check + logger.debug(f"db_get_file_entry query: '{query}' with values: {values}") + + # Execute query + cursor.execute(query, values) + + # Handle results + results = cursor.fetchone() - stored_file = models.StoredFileEntry(**res) + if results is None: + raise exc.FileNotFoundException + + res_dict = dict(zip(cursor.column_names, results)) + + stored_file = models.StoredFileEntry(**res_dict) return stored_file def db_get_file_path(con: PooledMySQLConnection, file_id: int, current_user_id: int) -> models.StoredFilePath: select_stmnt = MySQLStatementBuilder(con) res = select_stmnt\ - .select(FILES_TABLE, ['filename', 'uuid', 'directory', 'extension'])\ + .select(FILES_TABLE, ['filename', 'uuid', 'directory', 'owner_id', 'extension'])\ .where('id=?', [file_id])\ .execute(dictionary=True, fetch_type=FetchType.FETCH_ONE) @@ -59,7 +100,8 @@ def db_get_file_path(con: PooledMySQLConnection, file_id: int, current_user_id: raise exc.FileNotFoundException('File not found in DB') path = res['directory'] + res['uuid'] - stored_path = models.StoredFilePath(id=file_id, filename=res['filename'], path=path, extension=res['extension']) + stored_path = models.StoredFilePath( + id=file_id, filename=res['filename'], path=path, owner_id=res['owner_id'], extension=res['extension']) return stored_path @@ -71,3 +113,15 @@ def db_put_file_temp(con: PooledMySQLConnection, file_id: int, temp: bool, curre def db_put_filename(con: PooledMySQLConnection, file_id: int, filename_new: str, current_user_id: int) \ -> models.StoredFileEntry: pass + + +def db_get_file_mapped_subproject_id(con: PooledMySQLConnection, file_id) -> int: + select_stmnt = MySQLStatementBuilder(con) + res = select_stmnt.select(FILES_TO_SUBPROJECTS_MAP_TABLE, ['subproject_id'])\ + .where('file_id=?', [file_id])\ + .execute(dictionary=True, fetch_type=FetchType.FETCH_ONE) + + if res is None: + raise exc.SubprojectMappingNotFound('Mapping could not be found.') + + return res['subproject_id'] diff --git a/sedbackend/apps/core/individuals/storage.py b/sedbackend/apps/core/individuals/storage.py index 3c51f183..57d5f81b 100644 --- a/sedbackend/apps/core/individuals/storage.py +++ b/sedbackend/apps/core/individuals/storage.py @@ -5,7 +5,7 @@ import sedbackend.apps.core.individuals.models as models import sedbackend.apps.core.individuals.exceptions as ex -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType, exclude_cols +from mysqlsb import MySQLStatementBuilder, FetchType, exclude_cols INDIVIDUALS_TABLE = 'individuals' INDIVIDUALS_COLUMNS = ['id', 'name', 'is_archetype'] diff --git a/sedbackend/apps/core/measurements/implementation.py b/sedbackend/apps/core/measurements/implementation.py index 281d5227..d6ec4dc5 100644 --- a/sedbackend/apps/core/measurements/implementation.py +++ b/sedbackend/apps/core/measurements/implementation.py @@ -119,9 +119,9 @@ def impl_post_measurement_result(measurement_id: int, mr: models.MeasurementResu return res -def impl_post_upload_set(file, current_user_id: int, csv_delimiter: Optional[str] = None) -> List: +def impl_post_upload_set(file, current_user_id: int, subproject_id: int, csv_delimiter: Optional[str] = None) -> List: try: - stored_file_post = models_files.StoredFilePost.import_fastapi_file(file, current_user_id) + stored_file_post = models_files.StoredFilePost.import_fastapi_file(file, current_user_id, subproject_id) with get_connection() as con: file_entry = storage_files.db_save_file(con, stored_file_post) file_path = storage_files.db_get_file_path(con, file_entry.id, current_user_id) diff --git a/sedbackend/apps/core/measurements/router.py b/sedbackend/apps/core/measurements/router.py index 99b42c32..cfdf6b47 100644 --- a/sedbackend/apps/core/measurements/router.py +++ b/sedbackend/apps/core/measurements/router.py @@ -30,9 +30,9 @@ async def get_measurement_sets(subproject_id: Optional[int] = None): response_model=List[str], description="Upload a measurement set using a CSV or Excel file. Leaving csv_delimiter as None will " "result in the value being inferred automatically.") -async def post_upload_set(file: UploadFile = File(...), current_user: User = Depends(get_current_active_user), +async def post_upload_set(subproject_id: int, file: UploadFile = File(...), current_user: User = Depends(get_current_active_user), csv_delimiter: Optional[str] = None): - return impl.impl_post_upload_set(file, current_user.id, csv_delimiter=csv_delimiter) + return impl.impl_post_upload_set(file, current_user.id, subproject_id, csv_delimiter=csv_delimiter) @router.get("/sets/{measurement_set_id}", diff --git a/sedbackend/apps/core/measurements/storage.py b/sedbackend/apps/core/measurements/storage.py index 014dfcbe..2b208afd 100644 --- a/sedbackend/apps/core/measurements/storage.py +++ b/sedbackend/apps/core/measurements/storage.py @@ -3,7 +3,7 @@ from fastapi.logger import logger -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType, exclude_cols +from mysqlsb import MySQLStatementBuilder, FetchType, exclude_cols import sedbackend.apps.core.measurements.models as models import sedbackend.apps.core.measurements.exceptions as exc diff --git a/sedbackend/apps/core/projects/dependencies.py b/sedbackend/apps/core/projects/dependencies.py index f6936dfd..8043ee77 100644 --- a/sedbackend/apps/core/projects/dependencies.py +++ b/sedbackend/apps/core/projects/dependencies.py @@ -3,7 +3,7 @@ from fastapi import HTTPException, Request, status from fastapi.logger import logger -from sedbackend.apps.core.projects.models import AccessLevel +from sedbackend.apps.core.projects.models import AccessLevel, SubProject from sedbackend.apps.core.projects.implementation import impl_get_project, impl_get_subproject_native @@ -53,18 +53,23 @@ def __call__(self, native_project_id: int, request: Request): # Get subproject subproject = impl_get_subproject_native(self.application_sid, native_project_id) + return SubProjectAccessChecker.check_user_subproject_access(subproject, self.access_levels, user_id) + + @staticmethod + def check_user_subproject_access(subproject: SubProject, access_levels: List[AccessLevel], user_id: int): if subproject.project_id is not None: # Get project project = impl_get_project(subproject.project_id) # <-- This can throw # Check user access level in that project access = project.participants_access[user_id] - if access in self.access_levels: + if access in access_levels: logger.debug(f"Yes, user {user_id} has access level {access}") return True else: # Fallback solution: Check if user is the owner/creator of the subproject. - if request.state.user_id == subproject.owner_id: - logger.debug("User is owner of subproject.") + if user_id == subproject.owner_id: + logger.debug(f"User with id {user_id} is the owner of subproject with id {subproject.id} " + f"(owner_id = {subproject.owner_id}).") return True logger.debug(f"No, user {user_id} does not have the minimum required access level") @@ -73,3 +78,4 @@ def __call__(self, native_project_id: int, request: Request): detail="User does not have the necessary access level", ) + diff --git a/sedbackend/apps/core/projects/exceptions.py b/sedbackend/apps/core/projects/exceptions.py index a45876ce..78899fb9 100644 --- a/sedbackend/apps/core/projects/exceptions.py +++ b/sedbackend/apps/core/projects/exceptions.py @@ -36,3 +36,7 @@ class SubProjectDuplicateException(Exception): class ParticipantInconsistencyException(Exception): pass + + +class ConflictingProjectAssociationException(Exception): + pass diff --git a/sedbackend/apps/core/projects/implementation.py b/sedbackend/apps/core/projects/implementation.py index fe170af5..84c679c2 100644 --- a/sedbackend/apps/core/projects/implementation.py +++ b/sedbackend/apps/core/projects/implementation.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Union from fastapi import HTTPException, status @@ -11,9 +11,9 @@ import sedbackend.apps.core.projects.exceptions as exc -def impl_get_projects(segment_length: int = None, index: int = None): +def impl_get_projects(user_id: int, segment_length: int = 0, index: int = 0): with get_connection() as con: - return storage.db_get_projects(con, segment_length, index) + return storage.db_get_projects(con, user_id, segment_length, index) def impl_get_user_projects(user_id: int, segment_length: int = 0, index: int = 0) -> List[models.ProjectListing]: @@ -52,6 +52,11 @@ def impl_post_project(project: models.ProjectPost, owner_id: int) -> models.Proj status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) ) + except exc.ConflictingProjectAssociationException as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e) + ) def impl_delete_project(project_id: int) -> bool: @@ -67,6 +72,26 @@ def impl_delete_project(project_id: int) -> bool: ) +def impl_update_project(project_id: int, project_updated: models.ProjectEdit) -> models.Project: + # Validate input + if project_id != project_updated.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Conflicting project IDs (payload vs URL)" + ) + + try: + with get_connection() as con: + res = storage.db_update_project(con, project_updated) + con.commit() + return res + except exc.ProjectNotFoundException: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Project not found" + ) + + def impl_post_participant(project_id: int, user_id: int, access_level: models.AccessLevel) -> bool: try: with get_connection() as con: @@ -80,6 +105,19 @@ def impl_post_participant(project_id: int, user_id: int, access_level: models.Ac ) +def impl_post_participants(project_id: int, participants_access_dict: dict[int, models.AccessLevel]) -> bool: + try: + with get_connection() as con: + res = storage.db_add_participants(con, project_id, participants_access_dict) + con.commit() + return res + except exc.ProjectNotFoundException: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Project not found", + ) + + def impl_delete_participant(project_id: int, user_id: int) -> bool: try: with get_connection() as con: @@ -158,16 +196,27 @@ def impl_get_subproject_native(application_sid: str, native_project_id: int) -> except exc.SubProjectNotFoundException: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Sub project not found" + detail="Sub-project not found." ) except ApplicationNotFoundException: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="No such application" + detail="No such application." ) -def impl_delete_subproject(project_id: int, subproject_id: int) -> bool: +def impl_get_subproject_by_id(subproject_id: int) -> models.SubProject: + try: + with get_connection() as con: + return storage.db_get_subproject_with_id(con, subproject_id) + except exc.SubProjectNotFoundException: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Sub-project not found." + ) + + +def impl_delete_subproject(project_id: Union[int, None], subproject_id: int) -> bool: try: with get_connection() as con: res = storage.db_delete_subproject(con, project_id, subproject_id) @@ -196,3 +245,23 @@ def impl_delete_subproject_native(application_id: str, native_project_id: int): status_code=status.HTTP_404_NOT_FOUND, detail="No such subproject" ) + + +def impl_get_user_subprojects_with_application_sid(current_user_id: int, user_id: int, application_id: str, + no_project_association: bool = False): + # This may look redundant, but it is there to prevent devs from accidentally giving access to any user. + if current_user_id != user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User does not have access to this information" + ) + + try: + with get_connection() as con: + return storage.db_get_user_subprojects_with_application_sid(con, user_id, application_id, + no_project_association=no_project_association) + except ApplicationNotFoundException: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Application with ID = {application_id} is not available." + ) diff --git a/sedbackend/apps/core/projects/models.py b/sedbackend/apps/core/projects/models.py index 8d34d082..2e01049d 100644 --- a/sedbackend/apps/core/projects/models.py +++ b/sedbackend/apps/core/projects/models.py @@ -1,8 +1,9 @@ from __future__ import annotations # Obsolete in Python 3.10 from typing import Optional, List, Dict from enum import IntEnum, unique +from datetime import datetime -from pydantic import BaseModel +from pydantic import BaseModel, constr from sedbackend.apps.core.users.models import User @@ -32,23 +33,39 @@ class ProjectListing(BaseModel): id: int name: str access_level: AccessLevel = AccessLevel.NONE + participants: int = 0 + datetime_created: datetime class ProjectPost(BaseModel): - name: str + name: constr(min_length=5) participants: List[int] participants_access: Dict[int, AccessLevel] + subprojects: Optional[List[int]] = [] # List of sub-project IDs (not native) + + +class ProjectEdit(BaseModel): + id: int + name: Optional[constr(min_length=5)] = None + participants_to_add: Optional[Dict[int, AccessLevel]] = {} + participants_to_remove: Optional[List[int]] = [] + subprojects_to_add: Optional[List[int]] = [] + subprojects_to_remove: Optional[List[int]] = [] class SubProjectPost(BaseModel): + name: Optional[constr(min_length=5)] = 'Unnamed sub-project' application_sid: str native_project_id: int class SubProject(SubProjectPost): id: int + name: str = None owner_id: int project_id: Optional[int] + native_project_id: int + datetime_created: datetime class Project(BaseModel): diff --git a/sedbackend/apps/core/projects/router.py b/sedbackend/apps/core/projects/router.py index 9b77db62..16c0a2b6 100644 --- a/sedbackend/apps/core/projects/router.py +++ b/sedbackend/apps/core/projects/router.py @@ -13,12 +13,14 @@ @router.get("", - summary="Lists all projects", + summary="Lists all accessible projects", description="Lists all projects in alphabetical order", response_model=List[models.ProjectListing]) -async def get_projects(segment_length: int = None, index: int = None, current_user: User = Depends(get_current_active_user)): +async def get_projects(segment_length: Optional[int] = 0, index: Optional[int] = 0, + current_user: User = Depends(get_current_active_user)): """ Lists all projects that the current user has access to + :param current_user: :param segment_length: :param index: :return: @@ -31,14 +33,16 @@ async def get_projects(segment_length: int = None, index: int = None, current_us description="Lists all projects that exist, and is only available to those who have the authority.", response_model=List[models.ProjectListing], dependencies=[Security(verify_scopes, scopes=['admin'])]) -async def get_all_projects(segment_length: Optional[int] = None, index: Optional[int] = None): +async def get_all_projects(segment_length: Optional[int] = 0, index: Optional[int] = 0, + current_user: User = Depends(get_current_active_user)): """ Lists all projects that exists, and is only available to those who have the authority. + :param current_user: :param segment_length: :param index: :return: """ - return impl.impl_get_projects(segment_length, index) + return impl.impl_get_projects(current_user.id, segment_length, index) @router.get("/{project_id}", @@ -69,6 +73,15 @@ async def delete_project(project_id: int): return impl.impl_delete_project(project_id) +@router.put("/{project_id}", + summary="Edit project", + description="Edit project", + response_model=models.Project, + dependencies=[Depends(ProjectAccessChecker(models.AccessLevel.list_are_admins()))]) +async def update_project(project_id: int, project_updated: models.ProjectEdit): + return impl.impl_update_project(project_id, project_updated) + + @router.post("/{project_id}/participants", summary="Add participant to project", description="Add a participant to a project", @@ -135,3 +148,12 @@ async def delete_subproject(project_id: int, subproject_id: int): response_model=models.SubProject) async def get_app_native_project(app_id, native_project_id): return impl.impl_get_subproject_native(app_id, native_project_id) + + +@router.get("/apps/{app_id}/native-subprojects", + summary="List application specific native subprojects available to the user", + response_model=List[models.SubProject]) +async def get_user_subprojects_with_application_sid(app_id: str, current_user: User = Depends(get_current_active_user), + no_project_association: Optional[bool] = False): + return impl.impl_get_user_subprojects_with_application_sid(current_user.id, current_user.id, app_id, + no_project_association = no_project_association) diff --git a/sedbackend/apps/core/projects/storage.py b/sedbackend/apps/core/projects/storage.py index be4b712f..fddeea1b 100644 --- a/sedbackend/apps/core/projects/storage.py +++ b/sedbackend/apps/core/projects/storage.py @@ -1,9 +1,10 @@ -from typing import List, Optional +from typing import List, Optional, Dict from fastapi.logger import logger -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType +from mysqlsb import MySQLStatementBuilder, FetchType from mysql.connector.pooling import PooledMySQLConnection +from sedbackend.apps.core.exceptions import NoChangeException from sedbackend.apps.core.applications.state import get_application import sedbackend.apps.core.projects.models as models import sedbackend.apps.core.projects.exceptions as exc @@ -12,52 +13,97 @@ PROJECTS_TABLE = 'projects' PROJECTS_COLUMNS = ['id', 'name'] SUBPROJECTS_TABLE = 'projects_subprojects' -SUBPROJECT_COLUMNS = ['id', 'application_sid', 'project_id', 'native_project_id', 'owner_id'] +SUBPROJECT_COLUMNS = ['id', 'name', 'application_sid', 'project_id', 'native_project_id', 'owner_id', + 'datetime_created'] PROJECTS_PARTICIPANTS_TABLE = 'projects_participants' PROJECTS_PARTICIPANTS_COLUMNS = ['id', 'user_id', 'project_id', 'access_level'] -def db_get_projects(connection, segment_length: int = None, index: int = None) -> List[models.ProjectListing]: - mysql_statement = MySQLStatementBuilder(connection) - stmnt = mysql_statement \ - .select('projects', PROJECTS_COLUMNS) +def db_get_projects(connection, user_id: int, segment_length: int = 0, index: int = 0) -> List[models.ProjectListing]: - if segment_length is not None: - stmnt = stmnt.limit(segment_length) - if index is not None: - stmnt = stmnt.offset(segment_length * index) + if index < 0: + index = 0 - rs = stmnt.execute(fetch_type=FetchType.FETCH_ALL, dictionary=True) - projects = [] - for res in rs: - projects.append(models.ProjectListing(**res)) + if segment_length < 0: + segment_length = 0 + + with connection.cursor(prepared=True) as cursor: + select_stmnt = 'SELECT projects.name, projects.id as pid, ' \ + 'projects.datetime_created, ' \ + '(SELECT count(*) as participant_count FROM projects_participants WHERE project_id = pid), ' \ + '(SELECT access_level FROM projects_participants WHERE project_id = pid AND user_id = %s) ' \ + 'FROM projects ' \ + f'ORDER BY `projects`.`datetime_created` ASC ' \ + f'LIMIT {segment_length} OFFSET {segment_length * index} ' + + values = [user_id] + logger.debug(f'db_get_projects: {select_stmnt} with values {values}') + cursor.execute(select_stmnt, values) + rs = cursor.fetchall() + + project_list = [] + for res in rs: + res_dict = dict(zip(['name', 'pid', 'datetime_created', 'participant_count', 'access_level'], res)) + + access_level = res_dict["access_level"] + if access_level is None: + access_level = 0 + + pl = models.ProjectListing(id=res_dict['pid'], name=res_dict['name'], + access_level=models.AccessLevel(access_level), + participants=res_dict["participant_count"], + datetime_created=res_dict['datetime_created']) + project_list.append(pl) - return projects + return project_list -def db_get_user_projects(connection, user_id: int, segment_length: int = 0, index: int = 0) -> List[ - models.ProjectListing]: - participating_sql = MySQLStatementBuilder(connection) +def db_get_user_projects(connection, user_id: int, segment_length: int = 0, index: int = 0) \ + -> List[models.ProjectListing]: if index < 0: index = 0 - sql = participating_sql\ - .select(PROJECTS_TABLE, ['projects_participants.access_level', 'projects.name', 'projects.id']) \ - .inner_join(PROJECTS_PARTICIPANTS_TABLE, 'projects_participants.project_id = projects.id')\ - .where('projects_participants.user_id = %s', [user_id]) - if segment_length > 0: - # Segment if segment length is specified - sql = sql.limit(segment_length).offset(segment_length * index) + if segment_length < 0: + segment_length = 0 + + with connection.cursor(prepared=True) as cursor: + select_stmnt = 'SELECT projects_participants.access_level, projects.name, projects.id as pid, ' \ + 'projects.datetime_created, ' \ + '(SELECT count(*) as participant_count FROM projects_participants WHERE project_id = pid) ' \ + 'FROM projects ' \ + 'INNER JOIN projects_participants ON projects_participants.project_id = projects.id ' \ + f'WHERE projects_participants.user_id = %s ' \ + f'ORDER BY `projects`.`datetime_created` ASC ' + if segment_length != 0: + select_stmnt += f'LIMIT {segment_length} OFFSET {segment_length * index} ' \ + + + values = [user_id] + logger.debug(f'db_get_user_projects: {select_stmnt} with values {values}') + cursor.execute(select_stmnt, values) + rs = cursor.fetchall() + + project_list = [] + for res in rs: + res_dict = dict(zip(['access_level', 'name', 'pid', 'datetime_created', 'participant_count'], res)) + pl = models.ProjectListing(id=res_dict['pid'], name=res_dict['name'], + access_level=models.AccessLevel(res_dict["access_level"]), + participants=res_dict["participant_count"], + datetime_created=res_dict['datetime_created']) + project_list.append(pl) + + return project_list - rs = sql.execute(fetch_type=FetchType.FETCH_ALL, dictionary=True) - project_list = [] - for result in rs: - pl = models.ProjectListing(name=result['name'], access_level=result['access_level'], id=result['id']) - project_list.append(pl) +def db_get_participant_count (connection, project_id) -> int: + select_stmnt = MySQLStatementBuilder(connection) + res = select_stmnt\ + .count(PROJECTS_PARTICIPANTS_TABLE)\ + .where('project_id = %s', [project_id])\ + .execute(dictionary=True) - return project_list + return res["count"] def db_get_project(connection, project_id) -> models.Project: @@ -95,12 +141,10 @@ def db_get_project(connection, project_id) -> models.Project: def db_post_project(connection, project: models.ProjectPost, owner_id: int) -> models.Project: - logger.debug('Adding new project:') - logger.debug(project) - # Set owner if it is not already set if owner_id not in project.participants: project.participants.append(owner_id) + project.participants_access[owner_id] = models.AccessLevel.OWNER project_sql = MySQLStatementBuilder(connection) @@ -116,9 +160,48 @@ def db_post_project(connection, project: models.ProjectPost, owner_id: int) -> m db_add_participant(connection, project_id, participant_id, access_level, check_project_exists=False) + if len(project.subprojects) > 0: + db_update_subprojects_project_association(connection, project.subprojects, project_id) + return db_get_project(connection, project_id) +def db_clear_subproject_project_association(connection: PooledMySQLConnection, subproject_id_list: List[int], current_project_id: int): + logger.debug(f"Clearing subproject association with project with ID = {current_project_id} " + f"for subprojects with IDs {str(subproject_id_list)})") + + update_stmnt = MySQLStatementBuilder(connection) + update_stmnt.update(SUBPROJECTS_TABLE, "project_id = NULL", [])\ + .where(f'id IN {MySQLStatementBuilder.placeholder_array(len(subproject_id_list))}', subproject_id_list)\ + .execute() + + +def db_update_subprojects_project_association(connection: PooledMySQLConnection, subproject_id_list: List[int], + project_id: int, overwrite: bool = False): + logger.debug(f'Associating sub-projects with IDs {subproject_id_list} to project with ID {project_id}') + + if overwrite is False: + # Assert that these subprojects are not already members of other projects + select_stmnt = MySQLStatementBuilder(connection) + rs = select_stmnt.select(SUBPROJECTS_TABLE, ['project_id', 'id', 'name'])\ + .where(f'id IN {MySQLStatementBuilder.placeholder_array(len(subproject_id_list))}', subproject_id_list)\ + .execute(fetch_type=FetchType.FETCH_ALL, dictionary=True) + + for res in rs: + if res['project_id'] is not None: + raise exc.ConflictingProjectAssociationException(f'Subproject "{res["name"]}" (id: {res["id"]}) ' + f'is already associated with another project, ' + f'and overwrite has been disabled.') + + update_stmnt = MySQLStatementBuilder(connection) + update_stmnt\ + .update(SUBPROJECTS_TABLE, "project_id = %s", [project_id])\ + .where(f'id IN {MySQLStatementBuilder.placeholder_array(len(subproject_id_list))}', subproject_id_list)\ + .execute(fetch_type=FetchType.FETCH_NONE) + + return + + def db_delete_project(connection, project_id: int) -> bool: logger.debug(f"Removing project: {project_id}") @@ -145,6 +228,23 @@ def db_add_participant(connection, project_id, user_id, access_level, check_proj return True +def db_add_participants(connection: PooledMySQLConnection, project_id: int, + user_id_access_map: Dict[int, models.AccessLevel], check_project_exists=True) -> bool: + if check_project_exists: + db_get_project_exists(connection, project_id) # Raises exception if project does not exist + + insert_stmnt = MySQLStatementBuilder(connection) + insert_stmnt.insert(PROJECTS_PARTICIPANTS_TABLE, ['user_id', 'project_id', 'access_level']) + + insert_values = [] + for user_id, access_level in user_id_access_map.items(): + insert_values.append([user_id, project_id, access_level]) + + insert_stmnt.set_values(insert_values).execute() + + return True + + def db_delete_participant(connection, project_id, user_id, check_project_exists=True) -> bool: if check_project_exists: db_get_project_exists(connection, project_id) # Raises exception if project does not exist @@ -161,6 +261,28 @@ def db_delete_participant(connection, project_id, user_id, check_project_exists= return True +def db_delete_participants(connection: PooledMySQLConnection, project_id: int, user_ids: List[int], + check_project_exists=True) -> bool: + + logger.debug(f"Removing participants with ids = {user_ids} from project with id = {project_id}") + + if check_project_exists: + db_get_project_exists(connection, project_id) + + del_stmnt = MySQLStatementBuilder(connection) + values = [project_id] + values.extend(user_ids) + res, row_count = del_stmnt\ + .delete(PROJECTS_PARTICIPANTS_TABLE)\ + .where(f'project_id = %s AND user_id IN {MySQLStatementBuilder.placeholder_array(len(user_ids))}', + values).execute(return_affected_rows=True) + + if row_count != len(user_ids): + raise NoChangeException('Not all participants could be found') + + return True + + def db_put_name(connection, project_id, name) -> bool: project = db_get_project(connection, project_id) # Raises exception if project does not exist @@ -195,8 +317,8 @@ def db_post_subproject(connection, subproject: models.SubProjectPost, current_us except exc.SubProjectNotFoundException: insert_stmnt = MySQLStatementBuilder(connection) insert_stmnt\ - .insert(SUBPROJECTS_TABLE, ['application_sid', 'project_id', 'native_project_id', 'owner_id'])\ - .set_values([subproject.application_sid, project_id, subproject.native_project_id, current_user_id])\ + .insert(SUBPROJECTS_TABLE, ['name', 'application_sid', 'project_id', 'native_project_id', 'owner_id'])\ + .set_values([subproject.name, subproject.application_sid, project_id, subproject.native_project_id, current_user_id])\ .execute() return db_get_subproject_native(connection, subproject.application_sid, subproject.native_project_id) @@ -225,7 +347,6 @@ def db_get_subprojects(connection: PooledMySQLConnection, project_id: int) \ def db_get_subproject(connection, project_id, subproject_id) -> models.SubProject: - db_get_project_exists(connection, project_id) # Raises exception if project does not exist select_stmnt = MySQLStatementBuilder(connection) @@ -242,6 +363,21 @@ def db_get_subproject(connection, project_id, subproject_id) -> models.SubProjec return sub_project +def db_get_subproject_with_id(connection, subproject_id) -> models.SubProject: + select_stmnt = MySQLStatementBuilder(connection) + res = select_stmnt\ + .select(SUBPROJECTS_TABLE, SUBPROJECT_COLUMNS)\ + .where("id = %s", [subproject_id])\ + .execute(fetch_type=FetchType.FETCH_ONE, dictionary=True) + + if res is None: + raise exc.SubProjectNotFoundException + + sub_project = models.SubProject(**res) + + return sub_project + + def db_get_subproject_native(connection, application_sid, native_project_id) -> models.SubProject: get_application(application_sid) # Raises exception of application does not exist @@ -260,12 +396,15 @@ def db_get_subproject_native(connection, application_sid, native_project_id) -> def db_delete_subproject(connection, project_id, subproject_id) -> bool: - db_get_subproject(connection, project_id, subproject_id) # Raises exception if project does not exist - delete_stmnt = MySQLStatementBuilder(connection) - res, row_count = delete_stmnt.delete(SUBPROJECTS_TABLE)\ - .where("project_id = %s AND id = %s", [project_id, subproject_id])\ - .execute(return_affected_rows=True) + delete_stmnt.delete(SUBPROJECTS_TABLE) + + if project_id is not None: + delete_stmnt.where("project_id = %s AND id = %s", [project_id, subproject_id]) + else: + delete_stmnt.where("id = %s AND project_id IS NULL", [subproject_id]) + + res, row_count = delete_stmnt.execute(return_affected_rows=True) if row_count == 0: raise exc.SubProjectNotDeletedException @@ -273,7 +412,13 @@ def db_delete_subproject(connection, project_id, subproject_id) -> bool: return True -def db_get_user_subprojects_with_application_sid(con, user_id, application_sid) -> List[models.SubProject]: +def db_get_user_subprojects_with_application_sid(con, user_id, application_sid, + no_project_association: Optional[bool] = False) \ + -> List[models.SubProject]: + + # Validate that the application is listed + get_application(application_sid) + # Get projects in which this user is a participant project_list = db_get_user_projects(con, user_id) project_id_list = [] @@ -284,11 +429,16 @@ def db_get_user_subprojects_with_application_sid(con, user_id, application_sid) return [] # Figure out which of those projects have an attached subproject with specified application SID. - where_values = project_id_list.copy() - where_values.append(application_sid) + where_values = [application_sid] + where_values.extend(project_id_list.copy()) + where_values.append(user_id) + where_stmnt = f"application_sid = %s AND " \ + f"((project_id IN {MySQLStatementBuilder.placeholder_array(len(project_id_list))}) OR " \ + f"(project_id is null AND owner_id = %s))" + stmnt = MySQLStatementBuilder(con) rs = stmnt.select(SUBPROJECTS_TABLE, SUBPROJECT_COLUMNS)\ - .where(f"project_id IN {MySQLStatementBuilder.placeholder_array(len(project_id_list))} AND application_sid = %s", + .where(where_stmnt, where_values)\ .execute(fetch_type=FetchType.FETCH_ALL, dictionary=True) @@ -297,9 +447,43 @@ def db_get_user_subprojects_with_application_sid(con, user_id, application_sid) for res in rs: subproject_list.append(models.SubProject(**res)) + if no_project_association: + subproject_list = list(filter(lambda p: p.project_id is None, subproject_list)) + return subproject_list +def db_update_project(con: PooledMySQLConnection, project_updated: models.ProjectEdit) -> models.Project: + + # Check if project exists + project_original = db_get_project(con, project_updated.id) + + # Change name if requested + if project_original.name != project_updated.name: + update_project_stmnt = MySQLStatementBuilder(con) + res, row_count = update_project_stmnt\ + .update(PROJECTS_TABLE, "name = %s", [project_updated.name])\ + .where("id = %s", [project_updated.id])\ + .execute(fetch_type=FetchType.FETCH_NONE, return_affected_rows=True) + if row_count != 1: + raise NoChangeException + + # Add participants if requested + db_add_participants(con, project_updated.id, project_updated.participants_to_add, check_project_exists=False) + + # Remove participants if requested + db_delete_participants(con, project_updated.id, project_updated.participants_to_remove, check_project_exists=False) + + # Add subprojects if requested + db_update_subprojects_project_association(con, project_updated.subprojects_to_add, project_updated.id, overwrite=False) + + # Remove subprojects if requested + db_clear_subproject_project_association(con, project_updated.subprojects_to_remove, project_updated.id) + + # Return project + return db_get_project(con, project_updated.id) + + def db_get_project_exists(connection: PooledMySQLConnection, project_id: int) -> bool: """ Convenience function for asserting that a project exists. Faster than getting and building the entire project. diff --git a/sedbackend/apps/core/users/implementation.py b/sedbackend/apps/core/users/implementation.py index 09d3a36d..f13deca7 100644 --- a/sedbackend/apps/core/users/implementation.py +++ b/sedbackend/apps/core/users/implementation.py @@ -1,4 +1,5 @@ from typing import List +import re from fastapi import HTTPException, status, File from fastapi.logger import logger @@ -26,10 +27,17 @@ def impl_get_users_me(current_user: models.User) -> models.User: ) -def impl_get_users(segment_length: int, index: int) -> List[models.User]: - with get_connection() as con: - user_list = storage.db_get_user_list(con, segment_length, index) - return user_list +def impl_get_users(segment_length: int, index: int, order_by='username', order_direction='asc') -> List[models.User]: + try: + with get_connection() as con: + user_list = storage.db_get_user_list(con, segment_length, index, order_by=order_by, + order_direction=order_direction) + return user_list + except ValueError as err: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(err) + ) def impl_get_users_with_id(user_ids: List[int]) -> List[models.User]: @@ -167,6 +175,36 @@ def impl_update_user_details(current_user: models.User, return True +def impl_search_users(username: str, full_name: str, limit: int, order_by: str = 'username', + order_direction: str = 'asc') -> List[models.User]: + if len(username) < 3 and len(full_name) < 3: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one of the search terms needs to be more than three characters" + ) + + max_limit = 500 + if limit > max_limit: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Highest allowed limit is {max_limit}" + ) + + # Escape special characters that might break the search term + username = re.escape(username) + full_name = re.escape(full_name) + + try: + with get_connection() as con: + return storage.db_search_users(con, username, full_name, limit=limit, order_by=order_by, + order_direction=order_direction) + except ValueError as err: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(err) + ) + + def check_if_current_user_or_admin(current_user, user_id): if current_user.id == user_id: return True @@ -189,3 +227,8 @@ def check_if_current_user_or_admin(current_user, user_id): ) return True + + + + + diff --git a/sedbackend/apps/core/users/router.py b/sedbackend/apps/core/users/router.py index 540bb19d..3979926e 100644 --- a/sedbackend/apps/core/users/router.py +++ b/sedbackend/apps/core/users/router.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from fastapi import APIRouter, Depends, Security, File, Request @@ -14,8 +14,9 @@ summary="Lists all users", description="Produces a list of users in alphabetical order", response_model=List[models.User]) -async def get_users(segment_length: int, index: int): - return impl.impl_get_users(segment_length, index) +async def get_users(segment_length: int, index: int, order_by: Optional[str] = 'username', + order_direction: Optional[str] = 'asc'): + return impl.impl_get_users(segment_length, index, order_by=order_by, order_direction=order_direction) @router.post("", @@ -43,6 +44,15 @@ async def get_users_me(current_user: models.User = Depends(get_current_active_us return impl.impl_get_users_me(current_user) +@router.get("/search", + summary="Search for users", + response_model=List[models.User]) +async def get_search_users(username: Optional[str] = "", full_name: Optional[str] = "", + limit: Optional[int] = 10, order_by: Optional[str] = 'username', + order_direction: str = 'asc'): + return impl.impl_search_users(username, full_name, limit, order_by=order_by, order_direction=order_direction) + + @router.get("/{user_id}", summary="Get user with ID", response_model=models.User) @@ -76,3 +86,4 @@ async def update_user_password(user_id: int, async def update_user_details(user_id: int, update_email_request: models.UpdateDetailsRequest, current_user: models.User = Depends(get_current_active_user)): return impl.impl_update_user_details(current_user, user_id, update_email_request) + diff --git a/sedbackend/apps/core/users/storage.py b/sedbackend/apps/core/users/storage.py index 409f39c2..3398b553 100644 --- a/sedbackend/apps/core/users/storage.py +++ b/sedbackend/apps/core/users/storage.py @@ -1,11 +1,14 @@ from typing import List + +import mysqlsb.exceptions from fastapi.logger import logger from mysql.connector.pooling import PooledMySQLConnection import sedbackend.apps.core.users.exceptions as exc import sedbackend.apps.core.users.models as models from sedbackend.apps.core.authentication.utils import get_password_hash -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType +from mysqlsb import MySQLStatementBuilder, FetchType, Sort +from mysqlsb.utils import validate_order_request from mysql.connector.errors import Error as SQLError USERS_COLUMNS_SAFE = ['id', 'username', 'email', 'full_name', 'scopes', 'disabled'] # Safe, as it does not contain passwords @@ -46,7 +49,8 @@ def db_get_user_safe_with_id(connection: PooledMySQLConnection, user_id: int) -> return user -def db_get_user_list(connection: PooledMySQLConnection, segment_length: int, index: int) -> List[models.User]: +def db_get_user_list(connection: PooledMySQLConnection, segment_length: int, index: int, + order_by: str = 'username', order_direction: str = 'asc') -> List[models.User]: try: int(segment_length) int(index) @@ -57,9 +61,17 @@ def db_get_user_list(connection: PooledMySQLConnection, segment_length: int, ind except ValueError: raise TypeError + try: + # Order by is not a prepared statement, so we need to validate it for security + (order_by, direction) = validate_order_request(order_by, USERS_COLUMNS_SAFE, order_direction) + except mysqlsb.exceptions.OrderValueException as err: + raise ValueError(str(err)) + + # Build and run statement mysql_statement = MySQLStatementBuilder(connection) rs = mysql_statement\ .select(USERS_TABLE, USERS_COLUMNS_SAFE)\ + .order_by([order_by], order=direction)\ .limit(segment_length)\ .offset(segment_length * index)\ .execute(fetch_type=FetchType.FETCH_ALL, dictionary=True) @@ -168,3 +180,37 @@ def db_update_user_details(connection: PooledMySQLConnection, user_id: int, raise exc.UserNotFoundException return True + + +def db_search_users(connection: PooledMySQLConnection, username_search_str: str, full_name_search_str: str, + limit: int = 100, order_by: str = 'username', order_direction: str = 'asc') -> List[models.User]: + + users = [] + + try: + # Order by is not a prepared statement, so we need to validate it for security + (order_by, direction) = validate_order_request(order_by, USERS_COLUMNS_SAFE, order_direction) + except mysqlsb.exceptions.OrderValueException as err: + raise ValueError(str(err)) + + username_search_stmnt = "(`username` rlike ?)" + full_name_search_stmnt = "(`full_name` rlike ?)" + + if len(username_search_str) == 0: + username_search_str = "." + if len(full_name_search_str) == 0: + full_name_search_stmnt = '(`full_name` rlike ? OR `full_name` IS NULL)' + full_name_search_str = "." + + stmnt = MySQLStatementBuilder(connection) + rs = stmnt.select('users', USERS_COLUMNS_SAFE)\ + .where(f'{username_search_stmnt} AND {full_name_search_stmnt}', [username_search_str, full_name_search_str])\ + .order_by([order_by], order=direction) \ + .limit(limit)\ + .execute(fetch_type=FetchType.FETCH_ALL, dictionary=True) + + for res in rs: + user = models.User(**res) + users.append(user) + + return users diff --git a/sedbackend/apps/cvs/design/storage.py b/sedbackend/apps/cvs/design/storage.py index 878382aa..88efb340 100644 --- a/sedbackend/apps/cvs/design/storage.py +++ b/sedbackend/apps/cvs/design/storage.py @@ -6,7 +6,7 @@ from sedbackend.apps.cvs.vcs.models import ValueDriver from sedbackend.apps.cvs.vcs import storage as vcs_storage from sedbackend.apps.cvs.vcs.storage import CVS_VALUE_DRIVER_COLUMNS, CVS_VALUE_DRIVER_TABLE, populate_value_driver -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType, Sort +from mysqlsb import MySQLStatementBuilder, FetchType, Sort from sedbackend.apps.cvs.design import models, exceptions DESIGN_GROUPS_TABLE = 'cvs_design_groups' diff --git a/sedbackend/apps/cvs/life_cycle/storage.py b/sedbackend/apps/cvs/life_cycle/storage.py index 33612e01..f09d6e78 100644 --- a/sedbackend/apps/cvs/life_cycle/storage.py +++ b/sedbackend/apps/cvs/life_cycle/storage.py @@ -1,7 +1,7 @@ from fastapi.logger import logger from mysql.connector.pooling import PooledMySQLConnection -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType, Sort +from mysqlsb import MySQLStatementBuilder, FetchType, Sort from sedbackend.apps.cvs.life_cycle import exceptions, models from sedbackend.apps.cvs.vcs import storage as vcs_storage, exceptions as vcs_exceptions from mysql.connector import Error diff --git a/sedbackend/apps/cvs/link_design_lifecycle/storage.py b/sedbackend/apps/cvs/link_design_lifecycle/storage.py index 10baaf3c..f8fa467f 100644 --- a/sedbackend/apps/cvs/link_design_lifecycle/storage.py +++ b/sedbackend/apps/cvs/link_design_lifecycle/storage.py @@ -8,7 +8,7 @@ from sedbackend.apps.cvs.project.implementation import get_cvs_project from sedbackend.apps.cvs.vcs.implementation import get_vcs from sedbackend.apps.cvs.link_design_lifecycle import models, exceptions -from sedbackend.libs.mysqlutils.builder import FetchType, MySQLStatementBuilder +from mysqlsb import FetchType, MySQLStatementBuilder from sedbackend.apps.cvs.market_input import implementation as market_impl from sedbackend.apps.cvs.design import implementation as design_impl diff --git a/sedbackend/apps/cvs/market_input/storage.py b/sedbackend/apps/cvs/market_input/storage.py index 9f25012d..4ddfb802 100644 --- a/sedbackend/apps/cvs/market_input/storage.py +++ b/sedbackend/apps/cvs/market_input/storage.py @@ -3,7 +3,7 @@ from fastapi.logger import logger from mysql.connector.pooling import PooledMySQLConnection -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, FetchType, Sort +from mysqlsb import MySQLStatementBuilder, FetchType, Sort from sedbackend.apps.cvs.market_input import models, exceptions from sedbackend.apps.cvs.vcs import storage as vcs_storage, implementation as vcs_impl from sedbackend.apps.cvs.project import exceptions as project_exceptions diff --git a/sedbackend/apps/cvs/project/storage.py b/sedbackend/apps/cvs/project/storage.py index 2e115441..0f4ff9f1 100644 --- a/sedbackend/apps/cvs/project/storage.py +++ b/sedbackend/apps/cvs/project/storage.py @@ -5,7 +5,7 @@ from sedbackend.apps.core.users.storage import db_get_user_safe_with_id from sedbackend.apps.cvs.project import models as models, exceptions as exceptions from sedbackend.libs.datastructures.pagination import ListChunk -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, Sort, FetchType +from mysqlsb import MySQLStatementBuilder, Sort, FetchType import sedbackend.apps.core.projects.models as proj_models import sedbackend.apps.core.projects.storage as proj_storage diff --git a/sedbackend/apps/cvs/simulation/exceptions.py b/sedbackend/apps/cvs/simulation/exceptions.py index 321021e5..0b526be0 100644 --- a/sedbackend/apps/cvs/simulation/exceptions.py +++ b/sedbackend/apps/cvs/simulation/exceptions.py @@ -51,4 +51,8 @@ class FlowProcessNotFoundException(Exception): pass class SimSettingsNotFoundException(Exception): - pass \ No newline at end of file + pass + + +class NoTechnicalProcessException(Exception): + pass diff --git a/sedbackend/apps/cvs/simulation/storage.py b/sedbackend/apps/cvs/simulation/storage.py index 03846fb2..aacafd17 100644 --- a/sedbackend/apps/cvs/simulation/storage.py +++ b/sedbackend/apps/cvs/simulation/storage.py @@ -14,7 +14,7 @@ from typing import List from sedbackend.apps.cvs.design.implementation import get_design -from sedbackend.libs.mysqlutils.builder import FetchType, MySQLStatementBuilder +from mysqlsb import FetchType, MySQLStatementBuilder from sedbackend.libs.formula_parser.parser import NumericStringParser from sedbackend.libs.formula_parser import expressions as expr diff --git a/sedbackend/apps/cvs/vcs/storage.py b/sedbackend/apps/cvs/vcs/storage.py index 6ee28fd3..bcb7816f 100644 --- a/sedbackend/apps/cvs/vcs/storage.py +++ b/sedbackend/apps/cvs/vcs/storage.py @@ -9,7 +9,7 @@ from sedbackend.apps.cvs.life_cycle import storage as life_cycle_storage, models as life_cycle_models from sedbackend.libs.datastructures.pagination import ListChunk from sedbackend.apps.core.users import exceptions as user_exceptions -from sedbackend.libs.mysqlutils import MySQLStatementBuilder, Sort, FetchType +from mysqlsb import MySQLStatementBuilder, Sort, FetchType DEBUG_ERROR_HANDLING = True # Set to false in production diff --git a/sedbackend/libs/mysqlutils/__init__.py b/sedbackend/libs/mysqlutils/__init__.py deleted file mode 100644 index 093a91c4..00000000 --- a/sedbackend/libs/mysqlutils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .builder import MySQLStatementBuilder, FetchType -from .statements import Sort -from .utils import exclude_cols diff --git a/sedbackend/libs/mysqlutils/builder.py b/sedbackend/libs/mysqlutils/builder.py deleted file mode 100644 index 658ea5fc..00000000 --- a/sedbackend/libs/mysqlutils/builder.py +++ /dev/null @@ -1,197 +0,0 @@ -from .statements import * -from typing import Any, List, Optional, Tuple -from fastapi.logger import logger -from enum import Enum - - -class FetchType(Enum): - """ - Used to determine how many rows should be fetched by a MySQL statement/query - """ - - FETCH_ONE = "Fetch_One" - FETCH_ALL = "Fetch_All" - FETCH_NONE = "Fetch_None" - - -class MySQLStatementBuilder: - """ - Assists in building simple MySQL queries and statements. Does not need to be closed. - It automatically closes the MySQL cursor. - """ - - def __init__(self, connection): - self.con = connection - self.query = "" - self.values = [] - self.last_insert_id = None - self.default_fetchtype = FetchType.FETCH_NONE - - def insert(self, table: str, columns: List[str]): - """ - Create a prepared insert statement - - :param table: - :param columns: - :return: - """ - self.query += create_insert_statement(table, columns) - return self - - def set_values(self, values: List[str]): - self.query += create_prepared_values_statement(len(values)) - self.values.extend(values) - return self - - def select(self, table: str, columns: List[str]): - """ - Create a select statement - - :param table: - :param columns: - :return: - """ - - self.query += create_select_statement(table, columns) - return self - - def count(self, table: str): - self.query += create_count_statement(table) - self.default_fetchtype = FetchType.FETCH_ONE - return self - - def update(self, table: str, set_statement, values): - self.query += create_update_statement(table, set_statement) - self.values.extend(values) - return self - - def delete(self, table: str): - self.query += create_delete_statement(table) - return self - - def order_by(self, columns: List[str], order: Sort = None): - self.query += create_order_by_statement(columns, order) - return self - - def offset(self, offset_count: int): - self.query += create_offset_statement(offset_count) - return self - - def limit(self, limit_count: int): - self.query += create_limit_statement(limit_count) - return self - - def inner_join(self, target_table, join_statement): - self.query += create_inner_join_statement(target_table, join_statement) - return self - - def where(self, condition, condition_values: List[Any]): - """ - Create prepared WHERE statement - :param condition: Should be a prepared condition. Use %s or ? to represent variables - :param condition_values: List of condition variables (switches out the %s and ? prepared placeholders) - :return: - """ - - self.query += create_prepared_where_statement(condition) - self.values.extend(condition_values) - return self - - @staticmethod - def placeholder_array(number_of_elements): - """ - Creates an array with N elements, where each element is "%s" - :param number_of_elements: - :return: - """ - placeholder_array = ['%s'] * number_of_elements # Make an array with N '%s' elements - return f'({",".join(placeholder_array)})' # Return that as a SQL array in string format - - def execute(self, - fetch_type: Optional[FetchType] = None, - dictionary: bool = False, - return_affected_rows: bool = False, - no_logs: bool = False): - """ - Executes constructed MySQL query. Does not need to be closed (closes automatically). - - :param no_logs: If performing sensitive operations, then logs should not be saved. Setting this to True will ensure the operation is not recorded in detail. - :param dictionary: boolean. Default is False. Converts response to dictionaries - :param fetch_type: FetchType.FETCH_NONE by default - :param return_affected_rows: When deleting rows, the amount of rows deleted may be returned if this is true - :return: None by default, but can be changed by setting keyword param "fetch_type" - """ - if fetch_type is None and self.default_fetchtype is not None: - fetch_type = self.default_fetchtype - - if fetch_type is None: - fetch_type = FetchType.FETCH_NONE - - if no_logs is False: - logger.debug(f'Executing query "{self.query}" with values "{self.values}". fetch_type={fetch_type}') - - with self.con.cursor(prepared=True) as cursor: - cursor.execute(self.query, self.values) - self.last_insert_id = cursor.lastrowid - - # Determine what the query should return - if fetch_type is FetchType.FETCH_ONE: - res = cursor.fetchone() - - # This is awful. But, since we can't combine prepared cursors with buffered cursors this is necessary - if res is not None: - while cursor.fetchone() is not None: - pass - - elif fetch_type is FetchType.FETCH_ALL: - res = cursor.fetchall() - elif fetch_type is FetchType.FETCH_NONE: - res = None - else: - res = None - - # Convert result to dictionary (or, array of dictionaries) if requested. Skip if there isn't a result - if dictionary is True and res is not None: - - # Format response depending on fetch type - if fetch_type in [FetchType.FETCH_ALL]: - dict_array = [] - - for row in res: - dict_array.append(dict(zip(cursor.column_names, row))) - - res = dict_array - - elif fetch_type is FetchType.FETCH_ONE: - res = dict(zip(cursor.column_names, res)) - - # Finally, return results - if return_affected_rows is True: - return res, cursor.rowcount - else: - return res - - def execute_procedure(self, procedure: str, args: List) -> List[List[Any]]: - """ - Execute a stored procedure. May return multiple result sets depending on the procedure. - :param procedure: Name of stored procedure - :param args: List of arguments - :return: List of result sets - """ - logger.debug(f'executing stored procedure "{procedure}" with arguments {args}') - - with self.con.cursor(dictionary=True) as cursor: - cursor.callproc(procedure, args=args) - - result_sets = [] - - for recordset in cursor.stored_results(): - column_names = recordset.column_names - res_list = [] - for row in recordset: - row_dict = dict(zip(column_names, row)) - res_list.append(row_dict) - - result_sets.append(res_list) - - return result_sets diff --git a/sedbackend/libs/mysqlutils/statements.py b/sedbackend/libs/mysqlutils/statements.py deleted file mode 100644 index 2b58eea6..00000000 --- a/sedbackend/libs/mysqlutils/statements.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import List -from enum import Enum - - -class Sort(Enum): - ASCENDING = 'ASC' - DESCENDING = 'DESC' - - -def create_insert_statement(table: str, columns: List[str], backticks=True): - - if backticks: - insert_cols_str = ', '.join(wrap_in_backticks(columns)) # `col1`, `col2`, `col3`, .. - else: - insert_cols_str = ', '.join(columns) # col1, col2, col3, .. - - query = f"INSERT INTO {table} ({insert_cols_str}) " - return query - - -def create_select_statement(table, columns: List[str]): - return f"SELECT {','.join(wrap_in_backticks(columns))} FROM {table} " # SELECT col1, col2 FROM table - - -def create_count_statement(table): - return f"SELECT COUNT(*) as count FROM {table} " - - -def create_delete_statement(table: str): - return f"DELETE FROM {table} " - - -def create_update_statement(table: str, set_statement: str): - return f"UPDATE {table} SET {set_statement} " - - -def create_prepared_values_statement(count: int): - placeholder_array = ['%s'] * count - placeholder_str = ', '.join(placeholder_array) - return f"VALUES ({placeholder_str}) " - - -def create_prepared_where_statement(condition): - return f"WHERE {condition} " - - -def create_limit_statement(n): - return f"LIMIT {n} " - - -def create_offset_statement(n): - return f"OFFSET {n} " - - -def create_order_by_statement(columns: List[str], order: Sort = None): - cols_str = ', '.join(wrap_in_backticks(columns)) - if order: - return f"ORDER BY {cols_str} {order.value} " - else: - return f"ORDER BY {cols_str} " - - -def create_inner_join_statement(target_table, join_statement): - return f"INNER JOIN {target_table} ON {join_statement} " - - -def wrap_in_backticks(array: List[str]): - """ - Wraps each element in back-ticks. This is useful for escaping reserved key-words, - and future-proofing column/table names. - :param array: Array of table/column names - :return: - """ - new_array = [] - for element in array: - element = element.replace('.', '`.`') - new_array.append("`{}`".format(element)) - - return new_array diff --git a/sedbackend/libs/mysqlutils/utils.py b/sedbackend/libs/mysqlutils/utils.py deleted file mode 100644 index 289d9ca7..00000000 --- a/sedbackend/libs/mysqlutils/utils.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import List - - -def exclude_cols(column_list: List[str], exclude_list: List[str]): - """ - Takes a list of strings, and excludes all entries in the exlclude list. - Returns a copy of the list, but without the excluded entries. - Does not change the inserted list. - :return: - """ - column_list_copy = column_list[:] - - for exclude_col in exclude_list: - if exclude_col in column_list: - column_list_copy.remove(exclude_col) - else: - raise ValueError("Excluded column could not be found in column list.") - - return column_list_copy diff --git a/sedbackend/main.py b/sedbackend/main.py index 97f151af..86193b76 100644 --- a/sedbackend/main.py +++ b/sedbackend/main.py @@ -15,7 +15,7 @@ app = FastAPI( title="SED lab API", description="The SED lab API contains all HTTP operations available within the SED lab application.", - version="1.0.3", + version="1.0.4", ) app.include_router(api.router, prefix="/api") diff --git a/sedbackend/setup.py b/sedbackend/setup.py index c649fdb8..af94fc31 100644 --- a/sedbackend/setup.py +++ b/sedbackend/setup.py @@ -3,10 +3,14 @@ from logging.handlers import TimedRotatingFileHandler import tempfile +import mysqlsb from fastapi import Request from fastapi.logger import logger from starlette.responses import Response +# Set database logger +mysqlsb.Configuration.logger = logger + def config_default_logging(): """ diff --git a/sql/V230522_release_1_0_4.sql b/sql/V230522_release_1_0_4.sql new file mode 100644 index 00000000..27c92cf5 --- /dev/null +++ b/sql/V230522_release_1_0_4.sql @@ -0,0 +1,30 @@ +# Additions and modifications to the database as of sed-backend version 1.0.4 + +# Map files to subprojects +CREATE TABLE IF NOT EXISTS `seddb`.`files_subprojects_map` ( + `id` INT UNSIGNED NOT NULL AUTO_INCREMENT, + `file_id` INT UNSIGNED NOT NULL, + `subproject_id` INT UNSIGNED NOT NULL, + PRIMARY KEY (`id`), + UNIQUE INDEX `id_UNIQUE` (`id` ASC) VISIBLE, + CONSTRAINT `remove_subproject_to_file_map_on_file_removal` + FOREIGN KEY (`file_id`) + REFERENCES `seddb`.`files` (`id`) + ON DELETE CASCADE + ON UPDATE NO ACTION, + CONSTRAINT `remove_subproject_to_file_map_on_subproject_removal` + FOREIGN KEY (`subproject_id`) + REFERENCES `seddb`.`projects_subprojects` (`id`) + ON DELETE CASCADE + ON UPDATE NO ACTION + ); + +# Add name and date fields to subprojects to simplify identification and searchability +ALTER TABLE `seddb`.`projects_subprojects` + ADD COLUMN `name` VARCHAR(255) NOT NULL DEFAULT 'Unnamed sub-project' AFTER `id`, + ADD COLUMN `datetime_created` DATETIME(3) NOT NULL DEFAULT NOW(3) AFTER `owner_id`; + + +# Add date field to projects to simplify identification and searchability +ALTER TABLE `seddb`.`projects` + ADD COLUMN `datetime_created` DATETIME(3) NOT NULL DEFAULT NOW(3) AFTER `name`; diff --git a/tests/apps/core/files/test_files.py b/tests/apps/core/files/test_files.py new file mode 100644 index 00000000..dff0a9bc --- /dev/null +++ b/tests/apps/core/files/test_files.py @@ -0,0 +1,142 @@ +import tempfile +import tests.apps.core.projects.testutils as tu_proj +import tests.apps.core.users.testutils as tu_users +import tests.apps.core.files.testutils as tu + +import sedbackend.apps.core.files.implementation as impl +import sedbackend.apps.core.files.models as models +import sedbackend.apps.core.users.implementation as impl_users +from sedbackend.apps.core.projects.models import AccessLevel + + +def test_get_file(client, std_headers, std_user): + #Setup + current_user = impl_users.impl_get_user_with_username(std_user.username) + project = tu_proj.seed_random_project(current_user.id) + subp = tu_proj.seed_random_subproject(current_user.id, project.id) + + + tmp_file = tempfile.SpooledTemporaryFile() + tmp_file.write(b"Hello World!") + + post_file = models.StoredFilePost( + filename="hello", + owner_id=current_user.id, + extension=".txt", + file_object=tmp_file, + subproject_id=subp.id + ) + saved_file = impl.impl_save_file(post_file) + + #Act + res = client.get(f"/api/core/files/{saved_file.id}/download", + headers=std_headers) + + #Assert + assert res.status_code == 200 + + #Cleanup + tu.delete_files([saved_file], [current_user]) + tu_proj.delete_subprojects([subp]) + tu_proj.delete_projects([project]) + + +def test_delete_file_admin(client, admin_headers, admin_user): + #Setup + std_user = tu_users.seed_random_user(admin=False, disabled=False) + adm_user = impl_users.impl_get_user_with_username(admin_user.username) + project = tu_proj.seed_random_project(std_user.id, {adm_user.id: AccessLevel.ADMIN}) + subp = tu_proj.seed_random_subproject(std_user.id, project.id) + + + tmp_file = tempfile.SpooledTemporaryFile() + tmp_file.write(b"Hello World!") + + post_file = models.StoredFilePost( + filename="hello", + owner_id=std_user.id, + extension=".txt", + file_object=tmp_file, + subproject_id=subp.id + ) + saved_file = impl.impl_save_file(post_file) + + #Act + res = client.delete(f"/api/core/files/{saved_file.id}/delete", + headers=admin_headers) + + #Assert + assert res.status_code == 200 + + #Cleanup + tu_proj.delete_subprojects([subp]) + tu_proj.delete_projects([project]) + tu_users.delete_users([std_user]) + + + + +def test_delete_file_standard(client, std_headers, std_user): + #Setup + file_owner = tu_users.seed_random_user(admin=False, disabled=False) + current_user = impl_users.impl_get_user_with_username(std_user.username) + project = tu_proj.seed_random_project(file_owner.id, {current_user.id: AccessLevel.READONLY}) + subp = tu_proj.seed_random_subproject(file_owner.id, project.id) + + + tmp_file = tempfile.SpooledTemporaryFile() + tmp_file.write(b"Hello World!") + + post_file = models.StoredFilePost( + filename="hello", + owner_id=file_owner.id, + extension=".txt", + file_object=tmp_file, + subproject_id=subp.id + ) + saved_file = impl.impl_save_file(post_file) + + #Act + res = client.delete(f"/api/core/files/{saved_file.id}/delete", + headers=std_headers) + + #Assert + assert res.status_code == 403 #403 forbidden, should not be able to access resource + + #Cleanup + tu_proj.delete_subprojects([subp]) + tu_proj.delete_projects([project]) + tu_users.delete_users([file_owner]) + + + +def test_delete_file_owner(client, std_headers, std_user): + #Setup + current_user = impl_users.impl_get_user_with_username(std_user.username) + project = tu_proj.seed_random_project(current_user.id) + subp = tu_proj.seed_random_subproject(current_user.id, project.id) + + + tmp_file = tempfile.SpooledTemporaryFile() + tmp_file.write(b"Hello World!") + + post_file = models.StoredFilePost( + filename="hello", + owner_id=current_user.id, + extension=".txt", + file_object=tmp_file, + subproject_id=subp.id + ) + saved_file = impl.impl_save_file(post_file) + + #Act + res = client.delete(f"/api/core/files/{saved_file.id}/delete", + headers=std_headers) + + #Assert + assert res.status_code == 200 + + #Cleanup + tu_proj.delete_subprojects([subp]) + tu_proj.delete_projects([project]) + \ No newline at end of file diff --git a/tests/apps/core/files/testutils.py b/tests/apps/core/files/testutils.py new file mode 100644 index 00000000..49dd78f6 --- /dev/null +++ b/tests/apps/core/files/testutils.py @@ -0,0 +1,8 @@ +from typing import List +import sedbackend.apps.core.files.implementation as impl +import sedbackend.apps.core.files.models as models +import sedbackend.apps.core.users.models as user_models + +def delete_files(files: List[models.StoredFileEntry], users: List[user_models.User]): + for i,file in enumerate(files): + impl.impl_delete_file(file.id, users[i].id) \ No newline at end of file diff --git a/tests/apps/core/projects/test_projects.py b/tests/apps/core/projects/test_projects.py index 50294640..9e52ecba 100644 --- a/tests/apps/core/projects/test_projects.py +++ b/tests/apps/core/projects/test_projects.py @@ -44,7 +44,7 @@ def test_get_all_projects_as_admin(client, admin_headers, admin_user): max_projects = 30 current_user = impl_users.impl_get_user_with_username(admin_user.username) seeded_projects = tu_projects.seed_random_projects(current_user.id, amount=r.randint(5, max_projects)) - amount_of_projects = len(impl.impl_get_projects()) + amount_of_projects = len(impl.impl_get_projects(current_user.id)) # Act res = client.get('/api/core/projects/all', headers=admin_headers) # Assert @@ -310,3 +310,109 @@ def test_change_name_as_non_admin(client, std_headers, std_user): # Cleanup impl.impl_delete_project(p.id) impl_users.impl_delete_user_from_db(owner_user.id) + + +def test_add_participants_as_admin(client, std_headers, std_user): + # Setup + current_user = impl_users.impl_get_user_with_username(std_user.username) + participant_1 = tu_users.seed_random_user(admin=False, disabled=False) + participant_2 = tu_users.seed_random_user(admin=False, disabled=False) + participant_3 = tu_users.seed_random_user(admin=False, disabled=False) + + p = tu_projects.seed_random_project(current_user.id) + + participants_access_dict = { + participant_1.id: models.AccessLevel.ADMIN, + participant_2.id: models.AccessLevel.EDITOR, + participant_3.id: models.AccessLevel.READONLY + } + + # Act + impl.impl_post_participants(p.id, participants_access_dict) + p_updated = impl.impl_get_project(p.id) + + # Assert + for participant in p_updated.participants: + + # Check owner + if participant.id == current_user.id: + assert p_updated.participants_access[participant.id] == models.AccessLevel.OWNER + continue + + # Check other participants + assert participant.id in participants_access_dict.keys() + assert p_updated.participants_access[participant.id] == participants_access_dict[participant.id] + + # Cleanup + impl.impl_delete_project(p.id) + impl_users.impl_delete_user_from_db(participant_1.id) + impl_users.impl_delete_user_from_db(participant_2.id) + impl_users.impl_delete_user_from_db(participant_3.id) + + +def test_update_project(client, std_headers, std_user): + # Setup + current_user = impl_users.impl_get_user_with_username(std_user.username) + old_participant_1 = tu_users.seed_random_user(False, False) + old_participant_2 = tu_users.seed_random_user(False, False) + + new_participant_1 = tu_users.seed_random_user(False, False) + new_participant_2 = tu_users.seed_random_user(False, False) + new_participant_3 = tu_users.seed_random_user(False, False) + + project = tu_projects.seed_random_project(current_user.id, participants={ + old_participant_1.id: models.AccessLevel.EDITOR, + old_participant_2.id: models.AccessLevel.READONLY + }) + + # Add three subprojects to the project + old_subproject_1 = tu_projects.seed_random_subproject(current_user.id, project.id) + old_subproject_2 = tu_projects.seed_random_subproject(current_user.id, project.id) + old_subproject_3 = tu_projects.seed_random_subproject(current_user.id, project.id) + + new_subproject_1 = tu_projects.seed_random_subproject(new_participant_3.id, None) + new_subproject_2 = tu_projects.seed_random_subproject(current_user.id, None) + + new_name = tu.random_str(5, 50) + + # Act + res_before = client.get(f'/api/core/projects/{project.id}', headers=std_headers) + p_before_json = res_before.json() + + res_after = client.put(f'/api/core/projects/{project.id}', headers=std_headers, json={ + "id": project.id, + "name": new_name, + "participants_to_add": { + new_participant_1.id: models.AccessLevel.ADMIN.value, + new_participant_2.id: models.AccessLevel.EDITOR.value, + new_participant_3.id: models.AccessLevel.READONLY.value + }, + "participants_to_remove": [old_participant_1.id, old_participant_2.id], + "subprojects_to_add": [new_subproject_1.id, new_subproject_2.id], + "subprojects_to_remove": [old_subproject_1.id, old_subproject_2.id, old_subproject_3.id] + }) + p_after_json = res_after.json() + + # Assert - before + assert p_before_json["id"] == project.id + assert len(p_before_json["participants"]) == 3 + assert len(p_before_json["subprojects"]) == 3 + assert p_before_json["name"] == project.name + # Assert - after + assert res_after.status_code == 200 + assert p_after_json["id"] == project.id + assert p_after_json["name"] == new_name + assert len(p_after_json["participants"]) == 4 + assert len(p_after_json["subprojects"]) == 2 + + # Cleanup + tu_users.delete_users([new_participant_1, new_participant_2, new_participant_3, old_participant_1, old_participant_2]) + + old_subproject_1.project_id = None + old_subproject_2.project_id = None + old_subproject_3.project_id = None + new_subproject_1.project_id = project.id + new_subproject_2.project_id = project.id + + tu_projects.delete_subprojects([old_subproject_1, old_subproject_2, old_subproject_3, new_subproject_1, new_subproject_2]) + tu_projects.delete_projects([project]) diff --git a/tests/apps/core/users/test_users.py b/tests/apps/core/users/test_users.py index 2419f6a7..d2d5966d 100644 --- a/tests/apps/core/users/test_users.py +++ b/tests/apps/core/users/test_users.py @@ -117,11 +117,101 @@ def test_get_user_me(client, std_headers, std_user): def test_get_users_unauthenticated(client): # Act - res = client.get("/api/core/users/users") + res = client.get("/api/core/users") # Assert assert res.status_code == 401 +def test_get_users(client, std_headers): + # Setup + users = [tu_users.seed_random_user(), tu_users.seed_random_user(), tu_users.seed_random_user()] + + # Act + res = client.get("/api/core/users?segment_length=100&index=0&order_by=id&order_direction=desc", headers=std_headers) + + # Assert + assert res.status_code == 200 + assert len(res.json()) > 0 + ids = [] + for u in res.json(): + ids.append(u["id"]) + for u in users: + assert u.id in ids + + # Clean + tu_users.delete_users(users) + + +def test_get_users_forbidden(client, std_headers): + # Act + res_1 = client.get(f'/api/core/users?segment_length=100&index=0&order_by=username&order_direction=desc', headers=std_headers) + res_2 = client.get(f'/api/core/users?segment_length=100&index=0&order_by=password&order_direction=desc', headers=std_headers) + res_3 = client.get(f'/api/core/users?segment_length=100&index=0&order_by=password&order_direction=blah', headers=std_headers) + res_4 = client.get(f'/api/core/users?segment_length=100&index=0&order_by=username&order_direction=blah', headers=std_headers) + + # Assert + assert res_1.status_code == 200 + assert res_2.status_code == 400 + assert res_3.status_code == 400 + assert res_4.status_code == 400 + + +def test_search_users(client, std_headers): + # Setup + prefix_1 = 'search_test_' + tu.random_str(3, 3) + prefix_2 = prefix_1 + '2' + tu.random_str(3, 3) + prefix_3 = prefix_2 + '3' + tu.random_str(3, 3) + u1 = tu_users.seed_random_user(name_prefix=prefix_1) + u2 = tu_users.seed_random_user(name_prefix=prefix_2) + u3 = tu_users.seed_random_user(name_prefix=prefix_3) + + # Act + res_1 = client.get(f'/api/core/users/search?username={prefix_1}&order_by=username&order_direction=desc', headers=std_headers) + res_2 = client.get(f'/api/core/users/search?username={prefix_2}', headers=std_headers) + res_3 = client.get(f'/api/core/users/search?username={prefix_3}&order_by=id&order_direction=asc', headers=std_headers) + + # Assert + assert res_1.status_code == 200 and res_2.status_code == 200 and res_3.status_code == 200 + assert len(res_1.json()) == 3 + assert len(res_2.json()) == 2 + assert len(res_3.json()) == 1 + + # Clean + tu_users.delete_users([u1, u2, u3]) + + +def test_search_users_unauthenticated(client): + # Setup + prefix = tu.random_str(8, 8) + user = tu_users.seed_random_user(name_prefix=prefix) + + # Act + res = client.get(f'/api/core/users/search?username={prefix}') + + # Assert + assert res.status_code == 401 + + # Clean + tu_users.delete_users([user]) + + +def test_search_users_forbidden(client, std_headers): + # Setup + prefix = tu.random_str(8, 8) + user = tu_users.seed_random_user(name_prefix=prefix) + + # Act + res_1 = client.get(f'/api/core/users/search?username={prefix}&order_by=password', headers=std_headers) + res_2 = client.get(f'/api/core/users/search?username={prefix}&order_direction=blah', headers=std_headers) + + # Assert + assert res_1.status_code == 400 # Should be forbidden + assert res_2.status_code == 400 # Should be forbidden + + # Clean + tu_users.delete_users([user]) + + def test_post_user_unauthenticated(client): # Act res = client.post("/api/core/users", diff --git a/tests/apps/core/users/testutils.py b/tests/apps/core/users/testutils.py index 9bdbcdca..5ee36816 100644 --- a/tests/apps/core/users/testutils.py +++ b/tests/apps/core/users/testutils.py @@ -5,8 +5,11 @@ import tests.testutils as tu -def random_user_post(admin=False, disabled=False) -> models.UserPost: - random_name = 'test_user_' + tu.random_str(5, 15) +RANDOM_NAME_PREFIX = 'test_user_' + + +def random_user_post(admin=False, disabled=False, name_prefix=RANDOM_NAME_PREFIX) -> models.UserPost: + random_name = name_prefix + tu.random_str(5, 15) random_email = random_name + "@sed-mock-email.com" random_password = tu.random_str(5, 20) @@ -25,8 +28,8 @@ def random_user_post(admin=False, disabled=False) -> models.UserPost: return user -def seed_random_user(admin=False, disabled=False) -> models.User: - user_post = random_user_post(admin=admin, disabled=disabled) +def seed_random_user(admin=False, disabled=False, name_prefix=RANDOM_NAME_PREFIX) -> models.User: + user_post = random_user_post(admin=admin, disabled=disabled, name_prefix=name_prefix) user = impl.impl_post_user(user_post) return user diff --git a/tests/apps/cvs/simulation/test_sim_multiprocessing.py b/tests/apps/cvs/simulation/test_sim_multiprocessing.py index e37822f0..abbda075 100644 --- a/tests/apps/cvs/simulation/test_sim_multiprocessing.py +++ b/tests/apps/cvs/simulation/test_sim_multiprocessing.py @@ -2,6 +2,7 @@ import tests.apps.cvs.testutils as tu import testutils as sim_tu import sedbackend.apps.core.users.implementation as impl_users +import sedbackend.apps.cvs.simulation.exceptions as sim_exceptions def test_run_single_monte_carlo_sim(client, std_headers, std_user): #Setup @@ -212,13 +213,15 @@ def test_run_mc_sim_both_flows(client, std_headers, std_user): tu.delete_vd_from_user(current_user.id) -''' + def test_run_mc_sim_rate_invalid_order(client, std_headers, std_user): #Setup current_user = impl_users.impl_get_user_with_username(std_user.username) project, vcs, design_group, design, settings = sim_tu.setup_single_simulation(current_user.id) - tu.edit_rate_order_formulas(project.id, vcs.id, design_group.id) + first_tech_process = tu.edit_rate_order_formulas(project.id, vcs.id, design_group.id) + if first_tech_process is None: + raise sim_exceptions.NoTechnicalProcessException settings.monte_carlo = False #Act @@ -240,4 +243,3 @@ def test_run_mc_sim_rate_invalid_order(client, std_headers, std_user): tu.delete_VCS_with_ids(project.id, [vcs.id]) tu.delete_project_by_id(project.id, current_user.id) tu.delete_vd_from_user(current_user.id) -''' \ No newline at end of file diff --git a/tests/apps/cvs/simulation/test_simulation.py b/tests/apps/cvs/simulation/test_simulation.py index 43d288ef..d7555718 100644 --- a/tests/apps/cvs/simulation/test_simulation.py +++ b/tests/apps/cvs/simulation/test_simulation.py @@ -295,7 +295,7 @@ def test_run_sim_invalid_proj(client, std_headers, std_user): #Assert assert res.status_code == 404 - assert res.json() == {'detail': 'Sub project not found'} + assert res.json() == {'detail': 'Sub-project not found.'} #Should probably assert some other stuff about the output to ensure that it is correct. diff --git a/tests/apps/cvs/simulation/testutils.py b/tests/apps/cvs/simulation/testutils.py index 6f4f1f27..5d8dfae5 100644 --- a/tests/apps/cvs/simulation/testutils.py +++ b/tests/apps/cvs/simulation/testutils.py @@ -11,7 +11,7 @@ def setup_single_simulation(user_id) -> Tuple[CVSProject, VCS, DesignGroup, List project = tu.seed_random_project(user_id) vcs = tu.seed_random_vcs(project.id) design_group = tu.seed_random_design_group(project.id) - tu.seed_random_formulas(project.id, vcs.id, design_group.id, user_id, 10) #Also creates the vcs rows + tu.seed_random_formulas(project.id, vcs.id, design_group.id, user_id, 15) #Also creates the vcs rows design = tu.seed_random_designs(project.id, design_group.id, 1) settings = tu.seed_simulation_settings(project.id, [vcs.id], [design[0].id]) diff --git a/tests/apps/cvs/testutils.py b/tests/apps/cvs/testutils.py index 76b3c7eb..a6962dcc 100644 --- a/tests/apps/cvs/testutils.py +++ b/tests/apps/cvs/testutils.py @@ -153,7 +153,10 @@ def random_table_row( subprocess = random_subprocess(project_id, vcs_id) subprocess_id = subprocess.id else: - iso_process_id = random.randint(1, 25) + if random.randint(1, 5) == 1: #Give 1/5 chance to produce non-tech process + iso_process_id = random.randint(1, 14) + else: + iso_process_id = random.randint(15, 25) if stakeholder is None: stakeholder = tu.random_str(5, 50) @@ -556,10 +559,10 @@ def edit_rate_order_formulas(project_id: int, vcs_id: int, design_group_id: int) rows.reverse() # reverse back to find first technical process for row in rows: if row.iso_process is not None: - if row.iso_process.category == 'Technical processes': + if row.iso_process.category == 'Technical processes' and row.id != last_id: return row else: - if row.subprocess.parent_process.category == 'Technical processes': + if row.subprocess.parent_process.category == 'Technical processes' and row.id != last_id: return row