Batched inference CEBRA & padding at the Solver level #168

Open: wants to merge 45 commits into base main from batched-inference-and-padding
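The user-facing change proposed here is batched inference through transform(), with padding handled at the solver level. Below is a minimal usage sketch, assuming the batch_size keyword that this PR threads through the sklearn API (data shapes and parameter values are illustrative, not taken from the PR):

```python
import numpy as np
from cebra import CEBRA

# Hypothetical example data: 10k timesteps, 50 features.
X = np.random.randn(10_000, 50).astype("float32")

model = CEBRA(model_architecture="offset10-model", max_iterations=100)
model.fit(X)

# Batched inference: transform the data chunk by chunk instead of all at once;
# the solver pads each chunk so convolutional models keep their full receptive field.
embedding = model.transform(X, batch_size=512)
print(embedding.shape)  # (10000, output_dimension)
```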
Changes from 1 commit

Commits (45)
283de06  first proposal for batching in tranform method (gonlairo, Jun 21, 2023)
202e379  first running version of padding with batched inference (gonlairo, Jun 22, 2023)
1f1989d  start tests (gonlairo, Jun 23, 2023)
8665660  add pad_before_transform to fit function and add support for convolut… (gonlairo, Sep 27, 2023)
8d5b114  remove print statements (gonlairo, Sep 27, 2023)
32c5ecd  first passing test (gonlairo, Sep 27, 2023)
9928f63  add support for hybrid models (gonlairo, Sep 28, 2023)
be5630a  rewrite transform in sklearn API (gonlairo, Sep 28, 2023)
1300b20  baseline version of a torch.Datset (gonlairo, Oct 16, 2023)
bc6af24  move batching logic outside solver (gonlairo, Oct 20, 2023)
ec377b9  move functionality to base file in solver and separate in functions (gonlairo, Oct 27, 2023)
6f9ca98  add test_select_model for single session (gonlairo, Oct 27, 2023)
fbe7eb4  add checks and test for _process_batch (gonlairo, Oct 27, 2023)
463b0f8  add test_select_model for multisession (gonlairo, Oct 30, 2023)
5219171  make self.num_sessions compatible with single session training (gonlairo, Oct 31, 2023)
f9bd1a6  improve test_batched_transform_singlesession (gonlairo, Nov 1, 2023)
e23a7ef  make it work with small batches (gonlairo, Nov 7, 2023)
19c3f87  make test with multisession work (gonlairo, Nov 8, 2023)
87bebac  change to torch padding (gonlairo, Nov 9, 2023)
f0303e0  add argument to sklearn api (gonlairo, Nov 9, 2023)
8c8be85  add torch padding to _transform (gonlairo, Nov 9, 2023)
59df402  convert to torch if numpy array as inputs (gonlairo, Nov 9, 2023)
1aadc8b  add distinction between pad with data and pad with zeros and modify t… (gonlairo, Nov 15, 2023)
bc8ee25  differentiate between data padding and zero padding (gonlairo, Nov 17, 2023)
5e7a14c  remove float16 (gonlairo, Nov 24, 2023)
928d882  change argument position (gonlairo, Nov 27, 2023)
07bac1c  clean test (gonlairo, Nov 27, 2023)
0823b54  clean test (gonlairo, Nov 27, 2023)
9fe3af3  Fix warning (CeliaBenquet, Mar 26, 2024)
b417a23  Improve modularity remove duplicate code and todos (CeliaBenquet, Aug 21, 2024)
83c1669  Add tests to solver (CeliaBenquet, Aug 22, 2024)
9c46eb9  Remove unused import in solver/utils (CeliaBenquet, Aug 22, 2024)
c845ec3  Fix test plot (CeliaBenquet, Aug 22, 2024)
9db3e37  Add some coverage (CeliaBenquet, Aug 22, 2024)
8e5f933  Fix save/load (CeliaBenquet, Aug 22, 2024)
d08e400  Remove duplicate configure_for in multi dataset (CeliaBenquet, Aug 22, 2024)
0c693dd  Make save/load cleaner (CeliaBenquet, Aug 22, 2024)
ae056b2  Merge branch 'main' into batched-inference-and-padding (CeliaBenquet, Sep 18, 2024)
794867b  Fix codespell errors (CeliaBenquet, Sep 18, 2024)
0bb6549  Fix docs compilation errors (CeliaBenquet, Sep 18, 2024)
04a102f  Fix formatting (CeliaBenquet, Sep 18, 2024)
7aab282  Fix extra docs errors (CeliaBenquet, Sep 18, 2024)
ffa66eb  Fix offset in docs (CeliaBenquet, Sep 18, 2024)
7f58607  Remove attribute ref (CeliaBenquet, Sep 18, 2024)
c2544c7  Add review updates (CeliaBenquet, Sep 19, 2024)
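The commit history above distinguishes "pad with data" from "pad with zeros". A rough standalone illustration of that distinction follows; the offsets, names, and slicing here are made up for the example, and the actual solver implementation may differ:

```python
import torch
import torch.nn.functional as F

offset_left, offset_right = 5, 4
data = torch.randn(200, 3)          # (time, features) for one session
start, end = 50, 150                # batch boundaries inside the session

# Zero padding: used where no neighbouring samples exist (session edges).
batch = data[start:end]
zero_padded = F.pad(batch, (0, 0, offset_left, offset_right))  # pads the time axis with zeros

# Data padding: for interior batches, reuse the real neighbouring samples so the
# convolutional receptive field sees actual data instead of zeros.
data_padded = data[start - offset_left:end + offset_right]

print(zero_padded.shape, data_padded.shape)  # both torch.Size([109, 3])
```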
Commit 5e7a14c3cc80f3d35887a38cccb6a33b580bef3a: remove float16
gonlairo authored and stes committed Aug 23, 2024
cebra/integrations/sklearn/cebra.py (9 changes: 5 additions & 4 deletions)

@@ -1235,7 +1235,7 @@ def transform(self,
         # Input validation
         #TODO: if inputs are in cuda, then it throws an error, deal with this.
         X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
-        input_dtype = X.dtype
+        #input_dtype = X.dtype

         if isinstance(X, np.ndarray):
             X = torch.from_numpy(X)
@@ -1248,10 +1248,11 @@ def transform(self,
             session_id=session_id,
             batch_size=batch_size)

-        if input_dtype == "float64":
-            return output.astype(input_dtype)
+        #TODO: check if this is safe.
+        return output.numpy(force=True)

-        return output
+        #if input_dtype == "float64":
+        #    return output.astype(input_dtype)

     def fit_transform(
         self,
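For context on the new return path: Tensor.numpy(force=True) (available since PyTorch 1.13) detaches the tensor and copies it to host memory when needed, so the conversion also works when the embedding lives on the GPU or tracks gradients. A small standalone illustration, not CEBRA code:

```python
import torch

output = torch.randn(10, 8, requires_grad=True)
if torch.cuda.is_available():
    output = output.cuda()

# force=True handles detach + move-to-CPU in one call; plain .numpy() would
# raise for CUDA tensors or tensors that require grad.
embedding = output.numpy(force=True)
print(type(embedding), embedding.dtype)  # <class 'numpy.ndarray'> float32
```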
cebra/integrations/sklearn/utils.py (3 changes: 2 additions & 1 deletion)

@@ -78,7 +78,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
         X,
         accept_sparse=False,
         accept_large_sparse=False,
-        dtype=("float16", "float32", "float64"),
+        # NOTE: remove float16 because F.pad does not allow float16.
+        dtype=("float32", "float64"),
         order=None,
         copy=False,
         force_all_finite=True,
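As the NOTE in the diff says, float16 is dropped from the accepted dtypes because the torch padding path (F.pad) does not allow float16. With the narrowed dtype tuple, scikit-learn's check_array upcasts float16 inputs instead of passing them through. A standalone sketch of that effect (plain scikit-learn, not the CEBRA helper itself):

```python
import numpy as np
from sklearn.utils import check_array

X16 = np.random.randn(100, 5).astype("float16")

# float16 is not in the accepted list, so check_array converts to the first
# accepted dtype; the torch padding path therefore never sees a float16 array.
X = check_array(X16, dtype=("float32", "float64"))
print(X.dtype)  # float32
```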