Skip to content

Commit

Permalink
refactor: use None instead of literal none in hv compression
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Aug 6, 2024
1 parent 8500dba commit 269cf90
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
4 changes: 1 addition & 3 deletions open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ class HvConfig(BaseConfig):
announce_maddrs: list[str] | None = None
matchmaking_time: float | None = None
averaging_timeout: float | None = None
hivemind_compression: Literal["none", "fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] = (
"none"
)
hivemind_compression: Literal["fp16", "scaled-fp16", "uniform8bit", "quantile8bit", "blockwise8bit"] | None = None
all_reduce_strategy: AllReduceStrategy = AllReduceStrategy.WAIT_FOR_ALL
timeout_waiting_for_peers: float | None = None
skip_load_from_peers: bool = False
Expand Down
16 changes: 9 additions & 7 deletions open_diloco/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,17 @@ def hash_tensor_content(a: torch.Tensor, max_size: int = 1000) -> str:
return hashlib.md5(_round_flatten(a, max_size=max_size).encode("utf-8")).hexdigest()


def get_compression_kwargs(hivemind_compression: str) -> dict:
def get_compression_kwargs(hivemind_compression: str | None) -> dict:
"""Return the compression kwargs for hivemind optimizer based on the hivemind_compression argument."""
ret_kwargs = {}
if hivemind_compression == "fp16":

if hivemind_compression is None:
from hivemind import NoCompression

ret_kwargs["grad_compression"] = NoCompression()
ret_kwargs["state_averaging_compression"] = NoCompression()

elif hivemind_compression == "fp16":
from hivemind import Float16Compression

ret_kwargs["grad_compression"] = Float16Compression()
Expand All @@ -103,11 +110,6 @@ def get_compression_kwargs(hivemind_compression: str) -> dict:

ret_kwargs["grad_compression"] = ScaledFloat16Compression()
ret_kwargs["state_averaging_compression"] = ScaledFloat16Compression()
elif hivemind_compression == "none":
from hivemind import NoCompression

ret_kwargs["grad_compression"] = NoCompression()
ret_kwargs["state_averaging_compression"] = NoCompression()
elif hivemind_compression == "uniform8bit":
from hivemind import Uniform8BitQuantization

Expand Down

0 comments on commit 269cf90

Please sign in to comment.