diff --git a/src/bk-user/bkuser/apis/web/data_source/views.py b/src/bk-user/bkuser/apis/web/data_source/views.py index ec868ad28..08234f001 100644 --- a/src/bk-user/bkuser/apis/web/data_source/views.py +++ b/src/bk-user/bkuser/apis/web/data_source/views.py @@ -465,7 +465,6 @@ def post(self, request, *args, **kwargs): operator=request.user.username, overwrite=data["overwrite"], incremental=data["incremental"], - # FIXME (su) 本地数据源导入也要改成异步行为,但是要解决 excel 如何传递的问题 async_run=True, trigger=SyncTaskTrigger.MANUAL, ) diff --git a/src/bk-user/bkuser/apps/sync/managers.py b/src/bk-user/bkuser/apps/sync/managers.py index 890fd0cd0..101a75a27 100644 --- a/src/bk-user/bkuser/apps/sync/managers.py +++ b/src/bk-user/bkuser/apps/sync/managers.py @@ -32,7 +32,7 @@ def __init__(self, data_source: DataSource, sync_options: DataSourceSyncOptions) self.data_source = data_source self.sync_options = sync_options self.sync_timeout = data_source.sync_timeout - self.cache = Cache(CacheEnum.REDIS, CacheKeyPrefixEnum.DATA_SOURCE_ASYNC) + self.cache = Cache(CacheEnum.REDIS, CacheKeyPrefixEnum.DATA_SOURCE_SYNC_RAW_DATA) def execute(self, plugin_init_extra_kwargs: Optional[Dict[str, Any]] = None) -> DataSourceSyncTask: """同步数据源数据到数据库中,注意该方法不可用于 DB 事务中,可能导致异步任务获取 Task 失败""" @@ -53,31 +53,26 @@ def execute(self, plugin_init_extra_kwargs: Optional[Dict[str, Any]] = None) -> ) if self.sync_options.async_run: - workbook = plugin_init_extra_kwargs["workbook"] - with io.BytesIO() as buffer: - workbook.save(buffer) - content = buffer.getvalue() - encoded_data = base64.b64encode(content).decode("utf-8") - task_key = f"data_source: {self.data_source.id}: {task.id}" - self.cache.set(task_key, encoded_data, 2 * self.sync_timeout) - sync_data_source.apply_async(args=[task.id, task_key], soft_time_limit=self.sync_timeout) + if self.data_source.is_local and self.data_source.is_real_type: + self._process_workbook(plugin_init_extra_kwargs, task.id) + sync_data_source.apply_async(args=[task.id, plugin_init_extra_kwargs], soft_time_limit=self.sync_timeout) else: # 同步的方式,不需要序列化/反序列化,因此不需要检查基础类型 DataSourceSyncTaskRunner(task, plugin_init_extra_kwargs).run() return task - @staticmethod - def _ensure_only_basic_type_in_kwargs(kwargs: Dict[str, Any]): - """确保 插件初始化额外参数 中只有基础类型""" - if not kwargs: - return - - for v in kwargs.values(): - if isinstance(v, (int, float, str, bytes, bool, dict, list)): - continue - - raise TypeError("only basic type allowed in plugin_init_extra_kwargs!") + def _process_workbook(self, plugin_init_extra_kwargs, task_id): + workbook = plugin_init_extra_kwargs.get("workbook") + if workbook: + with io.BytesIO() as buffer: + workbook.save(buffer) + content = buffer.getvalue() + encoded_data = base64.b64encode(content).decode("utf-8") + task_key = f"data_source:{self.data_source.id}:{task_id}" + self.cache.set(task_key, encoded_data, 2 * self.sync_timeout) + plugin_init_extra_kwargs["task_key"] = task_key + plugin_init_extra_kwargs.pop("workbook") class TenantSyncManager: diff --git a/src/bk-user/bkuser/apps/sync/tasks.py b/src/bk-user/bkuser/apps/sync/tasks.py index da8a7500b..8481420f8 100644 --- a/src/bk-user/bkuser/apps/sync/tasks.py +++ b/src/bk-user/bkuser/apps/sync/tasks.py @@ -12,6 +12,7 @@ import base64 import logging from io import BytesIO +from typing import Any, Dict from openpyxl import load_workbook @@ -29,24 +30,27 @@ from bkuser.common.task import BaseTask logger = logging.getLogger(__name__) -cache = Cache(CacheEnum.REDIS, CacheKeyPrefixEnum.DATA_SOURCE_ASYNC) +cache = Cache(CacheEnum.REDIS, CacheKeyPrefixEnum.DATA_SOURCE_SYNC_RAW_DATA) @app.task(base=BaseTask, ignore_result=True) -def sync_data_source(task_id: int, task_key: str): +def sync_data_source(task_id: int, plugin_init_extra_kwargs: Dict[str, Any]): """同步数据源数据""" logger.info("[celery] receive data source sync task: %s", task_id) - encoded_data = cache.get(task_key) - if not encoded_data: - logger.error("[celery] data source sync task file not found: %s", task_id) - task = DataSourceSyncTask.objects.get(id=task_id) - task.status = SyncTaskStatus.FAILED.value - task.logs = "data source sync task file not found: %s" % task_id - task.save() - return - cache.delete(task_key) - workbook = load_workbook(filename=BytesIO(base64.b64decode(encoded_data))) - plugin_init_extra_kwargs = {"workbook": workbook} + if plugin_init_extra_kwargs.get("task_key"): + task_key = plugin_init_extra_kwargs["task_key"] + plugin_init_extra_kwargs.pop("task_key") + encoded_data = cache.get(task_key) + if not encoded_data: + logger.error("[celery] data source sync task file not found: %s", task_id) + task = DataSourceSyncTask.objects.get(id=task_id) + task.status = SyncTaskStatus.FAILED.value + task.logs = "data source sync task file not found: %s" % task_id + task.save() + return + cache.delete(task_key) + workbook = load_workbook(filename=BytesIO(base64.b64decode(encoded_data))) + plugin_init_extra_kwargs["workbook"] = workbook task = DataSourceSyncTask.objects.get(id=task_id) DataSourceSyncTaskRunner(task, plugin_init_extra_kwargs).run() diff --git a/src/bk-user/bkuser/common/cache.py b/src/bk-user/bkuser/common/cache.py index 9f5c17e1c..4616e790b 100644 --- a/src/bk-user/bkuser/common/cache.py +++ b/src/bk-user/bkuser/common/cache.py @@ -39,8 +39,8 @@ class CacheKeyPrefixEnum(str, StructuredEnum): VERIFICATION_CODE = "vc" # 用户重置密码用 Token RESET_PASSWORD_TOKEN = "rpt" - # 数据源同步任务 - DATA_SOURCE_ASYNC = "dsa" + # 数据源同步任务原始数据 + DATA_SOURCE_SYNC_RAW_DATA = "dssrd" def _default_key_function(*args, **kwargs): diff --git a/src/bk-user/tests/apis/web/data_source/test_data_source.py b/src/bk-user/tests/apis/web/data_source/test_data_source.py index dd7fd62ee..03e074e8b 100644 --- a/src/bk-user/tests/apis/web/data_source/test_data_source.py +++ b/src/bk-user/tests/apis/web/data_source/test_data_source.py @@ -463,6 +463,5 @@ def test_data_source_import_success(self, api_client, data_source): sync_task = DataSourceSyncTask.objects.get(data_source=data_source) assert response.status_code == status.HTTP_200_OK assert sync_task.status == SyncTaskStatus.SUCCESS - assert DataSource.objects.filter(id=data_source.id).exists() assert DataSourceUser.objects.filter(data_source_id=data_source.id).exists() assert DataSourceDepartment.objects.filter(data_source_id=data_source.id).exists()