Cherrypick master (#331)
* Update similarity_model.py

Update verbose printing to display the count of indexed items.

The verbose output was missing the f-string prefix and also printed the entire shape. Now we report just the number of examples.
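
A minimal sketch of the corrected print, with illustrative stand-ins rather than the actual similarity_model.py code:

```python
import numpy as np

# Illustrative stand-ins: `embeddings` mimics the indexed data and `verbose`
# mirrors the method argument; the real code lives in similarity_model.py.
embeddings = np.zeros((128, 64))
verbose = 1

if verbose:
    # The old message lacked the f-string prefix and interpolated the full
    # shape; the fix reports only the number of indexed examples.
    print(f"[Indexing {len(embeddings)} points]")
```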

* 0.17 patches (#325)

* fixes #323 Default indexer distance is now cosine in Sim Model.

Calling the create_index method now defaults to cosine distance.

Additionally, auto distance defaults to cosine if no distance is passed to compile.
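
A hypothetical sketch of the resolution logic described above (not the actual SimilarityModel code; method signatures may differ):

```python
# Hypothetical sketch: "auto", or an omitted distance, now resolves to cosine
# for both compile() and create_index().
DEFAULT_DISTANCE = "cosine"

def resolve_distance(distance: str = "auto") -> str:
    """Return the concrete distance used by compile()/create_index()."""
    return DEFAULT_DISTANCE if distance == "auto" else distance

assert resolve_distance() == "cosine"  # no distance passed -> cosine
assert resolve_distance("l2") == "l2"  # explicit choices still win
```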

* fixes #322 remove all calls to tf.convert_to_tensor in SimModel.

* Update gitignore to exclude models and datasets from the example notebooks.

* Update multi-modal notebook to remove the call to compile.

* Patch bump

* Remove the tf.shape check in index. The input can also be a tuple or dict, so we should use len() here.
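
A small sketch of why len() is the safer check: it accepts every input type index() may receive, whereas tf.shape() only handles tensor-like inputs. Illustrative only:

```python
import tensorflow as tf

x_tensor = tf.zeros((32, 8))
x_tuple = (tf.zeros((32, 8)), tf.zeros((32, 4)))
x_dict = {"image": tf.zeros((32, 8)), "text": tf.zeros((32, 4))}

print(len(x_tensor))  # 32 -> leading dimension of the tensor
print(len(x_tuple))   # 2  -> number of inputs in the tuple
print(len(x_dict))    # 2  -> number of inputs in the dict
```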

* Update github workflow tests to use TF >= 2.8

* Tensor slice sampler (#329)

* Create tfdata_sampler.py

Initial version of new tf.data.Dataset sampler.

* Refactor and clean up the tf data sampler.

* Add initial tests for tfdata_sampler

* Reformat TFDataSampler test file.

* Fix the protobuf dependency issue in the GitHub workflow tests. Python 3.10 breaks with protobuf > 3.20.x.

* Setting the env var didn't work. Trying again by pinning the protobuf version to 3.20.1.

* Check TF version before creating the tf dataset counter.
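
The sketch below illustrates the idea using feature detection in place of an explicit version-string comparison; the exact check in tfdata_sampler.py may differ:

```python
import tensorflow as tf

# Prefer the non-experimental counter when the installed TF release exposes
# it; otherwise fall back to the older experimental API.
if hasattr(tf.data.Dataset, "counter"):
    counter_ds = tf.data.Dataset.counter()
else:
    counter_ds = tf.data.experimental.Counter()
```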

* Format file

* Remove as_numpy_iterator when creating the list of grouped datasets.

* Also move class_list filter to before the group_by function
* Apply the total_examples_per_class as a take() on each grouped dataset
  (see the sketch after this list)
* Remove as much casting as possible from the dataset. Certain functions
  expect an int64 though and require casting.
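
A hypothetical end-to-end sketch of that ordering; names such as class_list and total_examples_per_class follow the bullets above, not the exact sampler code:

```python
import tensorflow as tf

class_list = [0, 2, 3]
total_examples_per_class = 4

x = tf.random.uniform((100, 8))
y = tf.random.uniform((100,), minval=0, maxval=5, dtype=tf.int64)
ds = tf.data.Dataset.from_tensor_slices((x, y))

# 1) Filter by the class list *before* any grouping.
keep = tf.constant(class_list, dtype=tf.int64)
ds = ds.filter(lambda x, y: tf.reduce_any(tf.equal(y, keep)))

# 2) Build one dataset per class and cap each with take() instead of casting
#    and slicing eagerly.
grouped = [
    ds.filter(lambda x, y, c=tf.constant(c, dtype=tf.int64): tf.equal(y, c))
      .take(total_examples_per_class)
    for c in class_list
]
```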

* Refactor to move the filter by class list out of the window_group_by function.

* Add class list filter test.

* Move augment_fn and load_fn to before the repeat and batch functions.

This change means the augment and load functions now apply per example. This
makes it easier to apply random augmentations per example and is more
consistent with how we implemented it in the existing memory sampler.

This change also improves the tests for all parts of the module.
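
A hedged sketch of the new ordering; augment_fn and load_fn below are placeholders, not the sampler's real hooks:

```python
import tensorflow as tf

def load_fn(x, y):
    # Placeholder: e.g. decode an image from a file path.
    return x, y

def augment_fn(x, y):
    # Placeholder: per-example random augmentation.
    return x + tf.random.uniform(tf.shape(x), maxval=0.1), y

ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((100, 8)), tf.range(100, dtype=tf.int64))
)

# map() runs per example here because it comes before repeat() and batch().
ds = ds.map(load_fn, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.repeat().batch(32)
```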

* Add support for handling tuple and dict values for y.

This change adds support for passing a callable that parses out the correct
class id element for batch sampling. By default, y is assumed to be a 1D
tensor of class ids and the parser is lambda y: y. Otherwise, we accept an
int or str and construct a parser that extracts the class id tensor.
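
A hypothetical sketch of that parser construction; the helper and parameter names are illustrative, not the actual tfdata_sampler.py signature:

```python
def make_class_id_parser(key=None):
    """Return a callable that extracts the 1D class-id tensor from y."""
    if key is None:
        return lambda y: y       # y is already the class-id tensor
    if callable(key):
        return key               # user-supplied parser
    if isinstance(key, (int, str)):
        return lambda y: y[key]  # tuple index or dict key into y
    raise ValueError("key must be None, a callable, an int, or a str")

# Example: y is a dict and the class ids live under "class_id".
parse = make_class_id_parser("class_id")
class_ids = parse({"class_id": [1, 2, 3], "bbox": [[0, 0, 1, 1]] * 3})
```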

* Update email for github actions bot to fix CLA errors in PR

* Fix import order and remove typing imports

* Fix import check in search init.

* Small updates to tfdata_sampler doc string
owenvallis authored May 5, 2023
1 parent 55632d3 commit 10480b3
Showing 10 changed files with 2,149 additions and 1,639 deletions.
7 changes: 4 additions & 3 deletions .github/workflows/test.yml
@@ -15,11 +15,10 @@ jobs:
matrix:
include:
- python-version: '3.7'
tf-version: '2.7'
tf-version: '2.8'
- python-version: '3.7'
tf-version: '2.11'
- python-version: '3.10'
# Python 3.10 only supports TF >= 2.8
tf-version: '2.8'
- python-version: '3.10'
tf-version: '2.11'
@@ -34,14 +33,16 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install coveralls
- name: Install dev packages
run: |
pip install ".[dev]"
- name: Install TF package
run: |
pip install tensorflow==${{ matrix.tf-version }}
# Fix proto dep issue in protobuf 4
pip install protobuf==3.20.*
- name: Lint with flake8
run: |
2 changes: 2 additions & 0 deletions .gitignore
@@ -11,6 +11,8 @@ release.sh
benchmark/supervised/datasets/
benchmark/supervised/models/
datasets/
multi_modal_datasets/
*.h5

# Byte-compiled / optimized / DLL files
__pycache__/