Skip to content

Commit

Permalink
add support for hjson config files (microsoft#2783)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
2 people authored and CodeSkull-1 committed Mar 6, 2023
1 parent fe5da3f commit 7244db0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
5 changes: 3 additions & 2 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch
import json
import hjson
import copy
import base64

Expand Down Expand Up @@ -705,14 +706,14 @@ def __init__(self, config: Union[str, dict], mpu=None):
if isinstance(config, dict):
self._param_dict = config
elif os.path.exists(config):
self._param_dict = json.load(
self._param_dict = hjson.load(
open(config,
"r"),
object_pairs_hook=dict_raise_error_on_duplicate_keys)
else:
try:
config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
self._param_dict = json.loads(config_decoded)
self._param_dict = hjson.loads(config_decoded)
except (UnicodeDecodeError, AttributeError):
raise ValueError(
f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. Received: {config}"
Expand Down
34 changes: 33 additions & 1 deletion tests/unit/runtime/test_ds_config_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# A test on its own
import os
import torch
import pytest
import json
import hjson
import argparse

from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
Expand Down Expand Up @@ -158,11 +160,41 @@ def test_get_bfloat16_enabled(bf16_key):
assert get_bfloat16_enabled(cfg) == True


class TestConfigLoad(DistributedTest):
world_size = 1

def test_dict(self, base_config):
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=base_config,
model=model,
model_parameters=model.parameters())

def test_json(self, base_config, tmpdir):
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
json.dump(base_config, fp)
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=config_path,
model=model,
model_parameters=model.parameters())

def test_hjson(self, base_config, tmpdir):
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, 'w') as fp:
hjson.dump(base_config, fp)
hidden_dim = 10
model = SimpleModel(hidden_dim)
model, _, _, _ = deepspeed.initialize(config=config_path,
model=model,
model_parameters=model.parameters())


class TestDeprecatedDeepScaleConfig(DistributedTest):
world_size = 1

def test(self, base_config, tmpdir):

config_path = create_config_from_dict(tmpdir, base_config)
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
Expand Down

0 comments on commit 7244db0

Please sign in to comment.