Closed
Description
Is there an existing issue for this?
- I have searched the existing issues
Bug description
If you pass a function as `callback` together with a `callback_frequency` when adapting a model (`adapt=True`, which goes through `_adapt_fit`), fitting fails with a `TypeError` because no `logdir` is specified when the `save` method in `cebra/solver/base.py` tries to write a checkpoint. Example:
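```python
# `embedding` is an already-fitted cebra.CEBRA model; `X` is the new data to adapt it to
progress_bar = lambda num_steps, solver: print(f"Embedding... {num_steps}")
U = embedding.fit_transform(X, adapt=True, callback_frequency=100, callback=progress_bar)
```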
Results in:
```
Traceback (most recent call last):
  File "/project/.venv/lib/python3.12/site-packages/streamlit/runtime/scriptrunner/exec_code.py", line 88, in exec_func_with_error_handling
    result = func()
             ^^^^^^
  File "/project/.venv/lib/python3.12/site-packages/streamlit/runtime/scriptrunner/script_runner.py", line 579, in code_to_exec
    exec(code, module.__dict__)
  File "/project/app.py", line 119, in <module>
    main()
  File "/project/app.py", line 78, in main
    U = embedding.fit_transform(X, adapt=True, callback_frequency=100, callback=progress_bar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/.venv/lib/python3.12/site-packages/sklearn/utils/_set_output.py", line 157, in wrapped
    data_to_wrap = f(self, X, *args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/.venv/lib/python3.12/site-packages/cebra/integrations/sklearn/cebra.py", line 1289, in fit_transform
    self.fit(X,
  File "/project/.venv/lib/python3.12/site-packages/cebra/integrations/sklearn/cebra.py", line 1188, in fit
    self._adapt_fit(X,
  File "/project/.venv/lib/python3.12/site-packages/cebra/integrations/sklearn/cebra.py", line 1127, in _adapt_fit
    self._partial_fit(*self.state_,
  File "/project/.venv/lib/python3.12/site-packages/cebra/integrations/sklearn/cebra.py", line 1036, in _partial_fit
    solver.fit(
  File "/project/.venv/lib/python3.12/site-packages/cebra/solver/base.py", line 214, in fit
    self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
  File "/project/.venv/lib/python3.12/site-packages/cebra/solver/base.py", line 346, in save
    if not os.path.exists(os.path.dirname(logdir)):
                          ^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen posixpath>", line 181, in dirname
TypeError: expected str, bytes or os.PathLike object, not NoneType
```
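The last frame does not depend on CEBRA at all. As a minimal standard-library sketch (the helper `reproduce_dirname_error` is hypothetical, written only for illustration), calling `os.path.dirname` with `None`, which is what `save` receives here, raises the same kind of `TypeError`:

```python
import os


def reproduce_dirname_error():
    """Return the error os.path.dirname raises for logdir=None, or None if it does not raise."""
    logdir = None  # the value solver.save receives in the traceback
    try:
        os.path.dirname(logdir)
    except TypeError as err:
        return err
    return None


print(reproduce_dirname_error())
```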
Following the traceback, the call to `solver.fit` in `_partial_fit` sets `logdir=None`. `fit` then passes that `None` to `save` in `cebra/solver/base.py` when writing a checkpoint, and `save` fails as soon as it calls `os.path.dirname(None)`.
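One way such a loop could avoid the crash is to skip checkpointing when no `logdir` was given while still firing the callback on schedule. The sketch below uses a hypothetical `ToySolver` class as a stand-in for the solver's `fit`/`save` pair; it is not CEBRA's actual code or the project's fix, only an illustration of the guard under that assumption. The callback is called as `callback(num_steps, solver)`, matching the `progress_bar` signature in the example.

```python
import os
from typing import Callable, List, Optional


class ToySolver:
    """Hypothetical stand-in for a solver's fit loop; not CEBRA's actual code."""

    def __init__(self) -> None:
        # Checkpoint paths that would have been written to disk.
        self.saved: List[str] = []

    def save(self, logdir: str, filename: str) -> None:
        # Like os.path.dirname, os.path.join raises TypeError if logdir is None.
        self.saved.append(os.path.join(logdir, filename))

    def fit(self,
            num_iters: int,
            save_frequency: Optional[int] = None,
            logdir: Optional[str] = None,
            callback: Optional[Callable[[int, "ToySolver"], None]] = None) -> None:
        for num_steps in range(1, num_iters + 1):
            # (one optimisation step would run here)
            if not save_frequency or num_steps % save_frequency != 0:
                continue
            # Guard: only checkpoint when a logdir was given, so None never
            # reaches save(); the callback still fires on its schedule.
            if logdir is not None:
                self.save(logdir, f"checkpoint_{num_steps:#07d}.pth")
            if callback is not None:
                callback(num_steps, self)


if __name__ == "__main__":
    steps = []
    solver = ToySolver()
    solver.fit(300, save_frequency=100, logdir=None,
               callback=lambda num_steps, s: steps.append(num_steps))
    print(steps, solver.saved)  # [100, 200, 300] []
```

With `logdir=None` the callback runs at steps 100, 200 and 300 and nothing is saved; passing a directory string instead would record a checkpoint path at each of those steps.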
Operating System
Kubuntu 24.04
CEBRA version
0.4.0
Device type
GPU
Code of Conduct
- I agree to follow this project's Code of Conduct