Skip to content

Commit

Permalink
improve the module importer's robustness
Browse files Browse the repository at this point in the history
This fixes the following issues:

- The old code inserted the "current working dir" into the module search paths, which has no correlation to the OneTrainer directory if the user executes the scripts with another active working directory. Now we use the actual OneTrainer directory regardless of its location, by leveraging the excellent Path library.

- The previous code inserted our path *last*, which meant that every import prioritized all other paths before looking in OneTrainer. This could lead to shadowing issues if a user has a system where a path contains another module named "modules", such as their own homemade project. We now insert ourselves at the top, as the highest priority.

- `scripts/install_zluda.py` was duplicating the efforts of importing modules. It clearly did that to avoid trying to load ZLUDA before it has been installed, but that's better handled by a loader flag instead.

- `scripts/sample.py` was permanently broken, since it attempted to import a module before fixing the import path.
  • Loading branch information
Arcitec committed Oct 21, 2024
1 parent d9291ce commit 4f4e8e4
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
7 changes: 4 additions & 3 deletions scripts/install_zluda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import sys
from util.import_util import script_imports

script_imports(allow_zluda=False)

sys.path.append(os.getcwd())
import sys

from modules.zluda import ZLUDAInstaller

Expand Down
3 changes: 1 addition & 2 deletions scripts/sample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from modules.util.ModelNames import ModelNames

from util.import_util import script_imports

script_imports()
Expand All @@ -9,6 +7,7 @@
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.ImageFormat import ImageFormat
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util.ModelNames import ModelNames
from modules.util.torch_util import default_device


Expand Down
12 changes: 8 additions & 4 deletions scripts/util/import_util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
def script_imports():
def script_imports(allow_zluda: bool = True):
import logging
import os
import sys
from pathlib import Path

# filter out the triton warning on startup
# Filter out the Triton warning on startup.
logging \
.getLogger("xformers") \
.addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

sys.path.append(os.getcwd())
# Insert ourselves as the highest-priority library path, so our modules are
# always found without any risk of being shadowed by another import path.
onetrainer_lib_path = Path(__file__).absolute().parent.parent.parent
sys.path.insert(0, str(onetrainer_lib_path))

if sys.platform.startswith('win'):
if allow_zluda and sys.platform.startswith('win'):
from modules.zluda import ZLUDAInstaller

zluda_path = ZLUDAInstaller.get_path()
Expand Down

0 comments on commit 4f4e8e4

Please sign in to comment.