Skip to content

switch lightning-cloud to lightning SDK #369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/extras.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ pillow
viztracer
pyarrow
tqdm
lightning-cloud == 0.5.70 # Must be pinned to ensure compatibility
lightning-sdk ==0.1.17 # Must be pinned to ensure compatibility
google-cloud-storage
1 change: 0 additions & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@ pytest-rerunfailures ==14.0
pytest-random-order ==1.1.1
pandas
lightning
lightning-cloud == 0.5.70 # Must be pinned to ensure compatibility
zstd
numpy < 2.0
1 change: 0 additions & 1 deletion src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
Expand Down
2 changes: 1 addition & 1 deletion src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ def run(self, data_recipe: DataRecipe) -> None:
result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)

if num_nodes == node_rank + 1 and self.output_dir.url and self.output_dir.path is not None and _IS_IN_STUDIO:
from lightning_cloud.openapi import V1DatasetType
from lightning_sdk.lightning_cloud.openapi import V1DatasetType

_create_dataset(
input_dir=self.input_dir.path,
Expand Down
6 changes: 3 additions & 3 deletions src/litdata/processing/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def _create_dataset(
if not storage_dir:
raise ValueError("The storage_dir should be defined.")

from lightning_cloud.openapi import ProjectIdDatasetsBody
from lightning_cloud.openapi.rest import ApiException
from lightning_cloud.rest_client import LightningClient
from lightning_sdk.lightning_cloud.openapi import ProjectIdDatasetsBody
from lightning_sdk.lightning_cloud.openapi.rest import ApiException
from lightning_sdk.lightning_cloud.rest_client import LightningClient

client = LightningClient(retry=False)

Expand Down
6 changes: 3 additions & 3 deletions src/litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _match_studio(target_id: Optional[str], target_name: Optional[str], cloudspa


def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Optional[str]) -> Dir:
from lightning_cloud.rest_client import LightningClient
from lightning_sdk.lightning_cloud.rest_client import LightningClient

client = LightningClient(max_tries=2)

Expand Down Expand Up @@ -140,7 +140,7 @@ def _resolve_studio(dir_path: str, target_name: Optional[str], target_id: Option


def _resolve_s3_connections(dir_path: str) -> Dir:
from lightning_cloud.rest_client import LightningClient
from lightning_sdk.lightning_cloud.rest_client import LightningClient

client = LightningClient(max_tries=2)

Expand All @@ -162,7 +162,7 @@ def _resolve_s3_connections(dir_path: str) -> Dir:


def _resolve_datasets(dir_path: str) -> Dir:
from lightning_cloud.rest_client import LightningClient
from lightning_sdk.lightning_cloud.rest_client import LightningClient

client = LightningClient(max_tries=2)

Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def azure_mock(monkeypatch):

@pytest.fixture()
def lightning_cloud_mock(monkeypatch):
lightning_cloud = ModuleType("lightning_cloud")
monkeypatch.setitem(sys.modules, "lightning_cloud", lightning_cloud)
lightning_cloud = ModuleType("lightning_sdk.lightning_cloud")
monkeypatch.setitem(sys.modules, "lightning_sdk.lightning_cloud", lightning_cloud)
rest_client = ModuleType("rest_client")
monkeypatch.setitem(sys.modules, "lightning_cloud.rest_client", rest_client)
monkeypatch.setitem(sys.modules, "lightning_sdk.lightning_cloud.rest_client", rest_client)
lightning_cloud.rest_client = rest_client
rest_client.LightningClient = Mock()
return lightning_cloud
Expand Down
4 changes: 2 additions & 2 deletions tests/streaming/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from unittest import mock

import pytest
from lightning_cloud import login
from lightning_cloud.openapi import (
from lightning_sdk.lightning_cloud import login
from lightning_sdk.lightning_cloud.openapi import (
Externalv1Cluster,
V1AwsDataConnection,
V1AWSDirectV1,
Expand Down
Loading