Skip to content

Commit 74e08e5 — "v1.5.2" (1 parent: 5940bd1)

File tree

4 files changed: +17 additions, −6 deletions

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
setup(
5151
name="accelerate",
52-
version="1.5.1",
52+
version="1.5.2",
5353
description="Accelerate",
5454
long_description=open("README.md", encoding="utf-8").read(),
5555
long_description_content_type="text/markdown",

src/accelerate/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
__version__ = "1.5.1"
14+
__version__ = "1.5.2"
1515

1616
from .accelerator import Accelerator
1717
from .big_modeling import (

src/accelerate/data_loader.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def __init__(self, *args, **kwargs):
8989

9090
def __iter__(self):
9191
if self.generator is None:
92-
self.generator = torch.Generator(device=torch.get_default_device())
92+
self.generator = torch.Generator(
93+
device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
94+
)
9395
self.generator.manual_seed(self.initial_seed)
9496

9597
# Allow `self.epoch` to modify the seed of the generator
@@ -1156,13 +1158,19 @@ def prepare_data_loader(
11561158
data_source=sampler.data_source,
11571159
replacement=sampler.replacement,
11581160
num_samples=sampler._num_samples,
1159-
generator=getattr(sampler, "generator", torch.Generator(device=torch.get_default_device())),
1161+
generator=getattr(
1162+
sampler,
1163+
"generator",
1164+
torch.Generator(device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"),
1165+
),
11601166
data_seed=data_seed,
11611167
)
11621168

11631169
if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA:
11641170
# isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled.
1165-
generator = torch.Generator(device=torch.get_default_device()).manual_seed(42)
1171+
generator = torch.Generator(
1172+
device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
1173+
).manual_seed(42)
11661174
dataloader.generator = generator
11671175
dataloader.sampler.generator = generator
11681176
# No change if no multiprocess
@@ -1181,7 +1189,9 @@ def prepare_data_loader(
11811189
else:
11821190
if not use_seedable_sampler and hasattr(sampler, "generator"):
11831191
if sampler.generator is None:
1184-
sampler.generator = torch.Generator(device=torch.get_default_device())
1192+
sampler.generator = torch.Generator(
1193+
device=torch.get_default_device() if hasattr(torch, "get_default_device") else "cpu"
1194+
)
11851195
synchronized_generator = sampler.generator
11861196
batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
11871197
new_batch_sampler = BatchSamplerShard(

src/accelerate/test_utils/testing.py

+1
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def run_first(test_case):
536536
"""
537537
if is_pytest_available():
538538
import pytest
539+
539540
return pytest.mark.order(1)(test_case)
540541
return test_case
541542

0 comments on commit 74e08e5