From 219b9046b92e6327ae099877d6f38382c520f593 Mon Sep 17 00:00:00 2001 From: Akira Noda <61897166+tsugumi-sys@users.noreply.github.com> Date: Wed, 23 Oct 2024 10:43:32 +0900 Subject: [PATCH] Add JSON format to save data as local artifact Related to #238 Add support for saving wiki500 data in JSON format. * **stocklake/stores/constants.py** - Add `JSON` to `ArtifactFormat` enum. - Update `formats` method to include `JSON`. * **stocklake/wiki_sp500/stores.py** - Add support for saving data in JSON format in `WikiSP500Store` class. - Update `save` method to handle `ArtifactFormat.JSON`. * **tests/wiki_sp500/test_stores.py** - Add test cases for saving data in JSON format. - Verify `WikiSP500Store` class handles `ArtifactFormat.JSON` correctly. --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/tsugumi-sys/stocklake/issues/238?shareId=XXXX-XXXX-XXXX-XXXX). --- stocklake/stores/constants.py | 2 +- stocklake/wiki_sp500/stores.py | 6 ++++++ tests/wiki_sp500/test_stores.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/stocklake/stores/constants.py b/stocklake/stores/constants.py index e2fa080..433364b 100644 --- a/stocklake/stores/constants.py +++ b/stocklake/stores/constants.py @@ -15,7 +15,7 @@ def __str__(self): class ArtifactFormat(str, Enum): CSV = "csv" - # JSON = "json" + JSON = "json" @staticmethod def formats(): diff --git a/stocklake/wiki_sp500/stores.py b/stocklake/wiki_sp500/stores.py index a7cb3fe..0703c44 100644 --- a/stocklake/wiki_sp500/stores.py +++ b/stocklake/wiki_sp500/stores.py @@ -1,5 +1,6 @@ import os import tempfile +import json from typing import List, Optional from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore @@ -38,6 +39,11 @@ def save( csv_file_path = os.path.join(tmpdir, "wiki_sp500.csv") save_data_to_csv([d.model_dump() for d in data], csv_file_path) repository.save_artifact(csv_file_path) + elif artifact_format == ArtifactFormat.JSON: + json_file_path = os.path.join(tmpdir, "wiki_sp500.json") + with open(json_file_path, "w") as json_file: + json.dump([d.model_dump() for d in data], json_file) + repository.save_artifact(json_file_path) else: raise NotImplementedError() return repository.list_artifacts()[0].path diff --git a/tests/wiki_sp500/test_stores.py b/tests/wiki_sp500/test_stores.py index 3bc9082..f7f6727 100644 --- a/tests/wiki_sp500/test_stores.py +++ b/tests/wiki_sp500/test_stores.py @@ -21,7 +21,7 @@ def wiki_sp500_data(): @pytest.mark.parametrize( - "artifact_format", [None, ArtifactFormat.CSV, "INVALID_FORMAT"] + "artifact_format", [None, ArtifactFormat.CSV, ArtifactFormat.JSON, "INVALID_FORMAT"] ) def test_save_local_artifact_repo(artifact_format, wiki_sp500_data): store = WikiSP500Store()