Skip to content

Commit

Permalink
Fix GPU index parsing (chaiNNer-org#2165)
Browse files Browse the repository at this point in the history
* Fix GPU index parsing

* Update backend/src/packages/chaiNNer_ncnn/settings.py

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>

* Update backend/src/packages/chaiNNer_onnx/settings.py

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>

* Update backend/src/packages/chaiNNer_pytorch/settings.py

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>

---------

Co-authored-by: Michael Schmidt <mitchi5000.ms@googlemail.com>
  • Loading branch information
joeyballentine and RunDevelopment authored Aug 30, 2023
1 parent bff1cae commit 1e891d2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,10 @@ def get_bool(self, key: str, default: bool) -> bool:
return value
raise ValueError(f"Invalid bool value for {key}: {value}")

def get_int(self, key: str, default: int) -> int:
def get_int(self, key: str, default: int, parse_str: bool = False) -> int:
value = self.__settings.get(key, default)
if parse_str and isinstance(value, str):
return int(value)
if isinstance(value, int) and not isinstance(value, bool):
return value
raise ValueError(f"Invalid str value for {key}: {value}")
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_ncnn/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def get_settings() -> NcnnSettings:
settings = package.get_settings()

return NcnnSettings(
gpu_index=settings.get_int("gpu_index", 0),
gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
)
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_onnx/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_settings() -> OnnxSettings:
os.makedirs(tensorrt_cache_path)

return OnnxSettings(
gpu_index=settings.get_int("gpu_index", 0),
gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
execution_provider=settings.get_str("execution_provider", default_provider),
tensorrt_cache_path=tensorrt_cache_path,
tensorrt_fp16_mode=settings.get_bool("tensorrt_fp16_mode", False),
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_pytorch/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ def get_settings() -> PyTorchSettings:
return PyTorchSettings(
use_cpu=settings.get_bool("use_cpu", False),
use_fp16=settings.get_bool("use_fp16", False),
gpu_index=settings.get_int("gpu_index", 0),
gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
)

0 comments on commit 1e891d2

Please sign in to comment.