Skip to content

Commit

Permalink
Add JSON format to save data as local artifact
Browse files Browse the repository at this point in the history
Related to #238

Add support for saving wiki500 data in JSON format.

* **stocklake/stores/constants.py**
  - Add `JSON` to `ArtifactFormat` enum.
  - Update `formats` method to include `JSON`.

* **stocklake/wiki_sp500/stores.py**
  - Add support for saving data in JSON format in `WikiSP500Store` class.
  - Update `save` method to handle `ArtifactFormat.JSON`.

* **tests/wiki_sp500/test_stores.py**
  - Add test cases for saving data in JSON format.
  - Verify `WikiSP500Store` class handles `ArtifactFormat.JSON` correctly.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/tsugumi-sys/stocklake/issues/238?shareId=XXXX-XXXX-XXXX-XXXX).
  • Loading branch information
tsugumi-sys committed Oct 23, 2024
1 parent 90328c9 commit 219b904
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion stocklake/stores/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __str__(self):

class ArtifactFormat(str, Enum):
CSV = "csv"
# JSON = "json"
JSON = "json"

@staticmethod
def formats():
Expand Down
6 changes: 6 additions & 0 deletions stocklake/wiki_sp500/stores.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import tempfile
import json
from typing import List, Optional

from stocklake.core.base_sqlalchemy_store import SQLAlchemyStore
Expand Down Expand Up @@ -38,6 +39,11 @@ def save(
csv_file_path = os.path.join(tmpdir, "wiki_sp500.csv")
save_data_to_csv([d.model_dump() for d in data], csv_file_path)
repository.save_artifact(csv_file_path)
elif artifact_format == ArtifactFormat.JSON:
json_file_path = os.path.join(tmpdir, "wiki_sp500.json")
with open(json_file_path, "w") as json_file:
json.dump([d.model_dump() for d in data], json_file)
repository.save_artifact(json_file_path)
else:
raise NotImplementedError()
return repository.list_artifacts()[0].path
Expand Down
2 changes: 1 addition & 1 deletion tests/wiki_sp500/test_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def wiki_sp500_data():


@pytest.mark.parametrize(
"artifact_format", [None, ArtifactFormat.CSV, "INVALID_FORMAT"]
"artifact_format", [None, ArtifactFormat.CSV, ArtifactFormat.JSON, "INVALID_FORMAT"]
)
def test_save_local_artifact_repo(artifact_format, wiki_sp500_data):
store = WikiSP500Store()
Expand Down

0 comments on commit 219b904

Please sign in to comment.