Vectorization refactor #205

Merged Jun 9, 2022 · 27 commits

Commits
424c714
Created a wrapper cost function class that combines the aux vars for …
luisenp Jun 1, 2022
ee9e235
Disabled support for optimization variables in cost weights.
luisenp Jun 1, 2022
ea74465
Changed Objective to iterate over CFWrapper if available, and Theseus…
luisenp Jun 1, 2022
1b3af0b
Added a Vectorizer class and moved CFWrappers there.
luisenp Jun 1, 2022
2d3f9d2
Renamed vectorizer as Vectorize, added logic to replace Objective ite…
luisenp Jun 1, 2022
6a146cb
Added a CostFunctionSchema -> List[CostFunction] to use for vectoriza…
luisenp Jun 2, 2022
6c6a887
_CostFunctionWrapper is now meant to just store a cached value coming…
luisenp Jun 2, 2022
77ac280
Added code to automatically compute shared vars in Vectorize.
luisenp Jun 2, 2022
31237da
Changed vectorized costs construction to ensure that their weight is …
luisenp Jun 2, 2022
d30e1af
Implemented main cost function vectorization logic.
luisenp Jun 6, 2022
36e89c7
Updated bug that was causing detached gradients.
luisenp Jun 6, 2022
376e8ef
Fixed invalid check in theseus end-to-end unit tests.
luisenp Jun 6, 2022
ae6db18
Added unit test for schema and shared var computation.
luisenp Jun 6, 2022
0a2ee0a
Added a test to check that computed vectorized errors are correct.
luisenp Jun 6, 2022
58cee83
Moved vectorization update call to base linearization class.
luisenp Jun 7, 2022
7e60f87
Changed code to allow batch_size > 1 in shared variables.
luisenp Jun 7, 2022
399bb90
Fixed unit test and added call to Objective.update() in update_vector…
luisenp Jun 7, 2022
10cbf1c
Added new private iterator for vectorized costs.
luisenp Jun 7, 2022
10b208a
Replaced _register_vars_in_list with TheseusFunction.register_vars.
luisenp Jun 9, 2022
db5f366
Renamed vectorize_cost_fns kwarg as vectorize.
luisenp Jun 9, 2022
bb83db3
Added license headers.
luisenp Jun 9, 2022
1d0cd20
Small refactor.
luisenp Jun 9, 2022
e902924
Fixed bug that was preventing vectorized costs to work with to(). End…
luisenp Jun 9, 2022
0ec439f
Renamed the private Objective cost function iterator to _get_iterator().
luisenp Jun 9, 2022
aab9ead
Renamed kwarg in register_vars.
luisenp Jun 9, 2022
e57f310
Set vectorize=True for inverse kinematics and backward tests.
luisenp Jun 9, 2022
d6a434f
Remove lingering comments.
luisenp Jun 9, 2022
Changed code to allow batch_size > 1 in shared variables.
luisenp committed Jun 7, 2022
commit 7e60f879f56e1b7c92b69c8bdb686947013d0136
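
This commit relaxes the restriction, enforced by the RuntimeError removed in the diff below, that shared variables must have batch size 1 to be vectorized. A minimal standalone sketch of the two regimes the diff implements (shapes and names are made up for illustration, not code from the PR):

import torch

batch_size, num_cost_fns, dof = 4, 3, 2  # hypothetical sizes

# Shared variable with batch size 1: a single tensor serves every cost
# function, and a broadcasting expand reaches the vectorized batch size
# without copying any data.
shared = torch.randn(1, dof)
vectorized = shared.expand(batch_size * num_cost_fns, dof)

# Shared variable with batch size > 1 (previously a RuntimeError): rows
# differ across the batch, so expand is invalid and the per-cost-function
# tensors are concatenated instead.
per_cost_fn = [torch.randn(batch_size, dof)] * num_cost_fns  # same tensor per cost fn
vectorized = torch.cat(per_cost_fn, dim=0)  # (batch_size * num_cost_fns, dof)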
62 changes: 36 additions & 26 deletions theseus/core/vectorizer.py
@@ -179,7 +179,7 @@ def _expand(tensor: torch.Tensor, size: int) -> torch.Tensor:
     def _update_all_cost_fns_var_data(
         cf_wrappers: List[_CostFunctionWrapper],
         var_names: List[str],
-        batch_size: int,
+        objective_batch_size: int,
         names_to_data: Dict[str, List[torch.Tensor]],
     ):
         # Get all the data from individual variables
@@ -194,18 +194,31 @@ def _update_all_cost_fns_var_data(
             name = var_names[var_idx]
             if name in seen_vars:
                 continue
-            # If the variable is not shared, always append the data.
-            # If the variable is shared, we only need data for one of the cost
-            # functions, and we can just extend later to complete the
-            # vectorized batch.
-            if Vectorize._SHARED_TOKEN not in name or name not in names_to_data:
-                # if not a shared variable, expand to batch size if needed
-                data = (
-                    var.data
-                    if (var.data.shape[0] > 1 or Vectorize._SHARED_TOKEN in name)
-                    else Vectorize._expand(var.data, batch_size)
-                )
-                names_to_data[name].append(data)
+            # If the variable is shared and batch_size == 1, we only need data
+            # for one of the cost functions, and we can just extend later to
+            # complete the vectorized batch, so here we can do `continue`.
+            var_batch_size = var.data.shape[0]
+            if (
+                name in names_to_data
+                and Vectorize._SHARED_TOKEN in name
+                and var_batch_size == 1
+            ):
+                continue
+
+            # Otherwise, we need to append to the list of data tensors, since
+            # we cannot extend to a full batch w/o copying.
+
+            # If not a shared variable, expand to batch size if needed, because
+            # those we will copy no matter what.
+            # For shared variables, just append; we will handle the expansion
+            # when updating the vectorized variable containers.
+            data = (
+                var.data
+                if (var_batch_size > 1 or Vectorize._SHARED_TOKEN in name)
+                else Vectorize._expand(var.data, objective_batch_size)
+            )
+            names_to_data[name].append(data)
             seen_vars.add(name)

     # Goes through the list of vectorized variables and updates their data with the
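
Since the gathering logic above is spread across nested loops, here is a simplified, self-contained re-implementation of the same decision table; the function name, the dict-based input format, and the SHARED_TOKEN constant are stand-ins invented for this sketch:

from collections import defaultdict
from typing import Dict, List

import torch

SHARED_TOKEN = "__SHARED__"  # stand-in for Vectorize._SHARED_TOKEN

def gather_var_data(
    per_cost_fn_vars: List[Dict[str, torch.Tensor]],  # one dict per cost function
    objective_batch_size: int,
) -> Dict[str, List[torch.Tensor]]:
    names_to_data: Dict[str, List[torch.Tensor]] = defaultdict(list)
    for cf_vars in per_cost_fn_vars:
        for name, tensor in cf_vars.items():
            is_shared = SHARED_TOKEN in name
            if is_shared and name in names_to_data and tensor.shape[0] == 1:
                # Shared with batch size 1: one copy is enough; it is expanded
                # to the full vectorized batch later, without copying.
                continue
            if not is_shared and tensor.shape[0] == 1:
                # Non-shared variables get copied regardless, so expand them
                # to the objective's batch size right away.
                tensor = tensor.expand(objective_batch_size, *tensor.shape[1:])
            names_to_data[name].append(tensor)
    return names_to_data

Note that shared variables with batch size > 1 fall through to the append, once per cost function, which is what enables the torch.cat path in the next hunk.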
@@ -226,20 +239,17 @@ def _update_vectorized_vars(
                 var.update(names_to_data[name][0])
                 continue

-            if Vectorize._SHARED_TOKEN in name:
-                data = names_to_data[name][0]
-                if data.shape[0] > 1:
-                    original_name = name[len(Vectorize._SHARED_TOKEN) :]
-                    raise RuntimeError(
-                        f"Cannot vectorize shared variables with "
-                        f"batch size > 1, but variable named {original_name} has "
-                        f"batch size = {data.shape[0]}. If this is unavoidable for a "
-                        f"batch, consider setting the batch size of your problem to 1, "
-                        f"or turning cost function vectorization off."
-                    )
-                var.update(Vectorize._expand(data, batch_size * num_cost_fns))
+            all_var_data = names_to_data[name]
+            if Vectorize._SHARED_TOKEN in name and all_var_data[0].shape[0] == 1:
+                # This is a shared variable, so all_var_data[i] is the same for
+                # every value of i, and we can just expand to the full vectorized
+                # size. Sadly, this doesn't work if batch_size > 1.
+                var_tensor = Vectorize._expand(
+                    all_var_data[0], batch_size * num_cost_fns
+                )
             else:
-                var.update(torch.cat(names_to_data[name], dim=0))
+                var_tensor = torch.cat(all_var_data, dim=0)
+            var.update(var_tensor)

     # Computes the error of the vectorized cost function and distributes the error
     # to the cost function wrappers. The list of wrappers must correspond to the
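
For reference, a compact sketch of the expand-vs-cat branch in this last hunk, together with a plausible stand-in for Vectorize._expand (its body is not shown in this diff, so its exact behavior here is an assumption):

from typing import List

import torch

def _expand(tensor: torch.Tensor, size: int) -> torch.Tensor:
    # Assumed behavior: broadcast a batch-1 tensor along dim 0, copy-free.
    return tensor.expand(size, *tensor.shape[1:])

def vectorized_var_tensor(
    all_var_data: List[torch.Tensor],  # tensors collected for one variable name
    is_shared: bool,
    batch_size: int,
    num_cost_fns: int,
) -> torch.Tensor:
    if is_shared and all_var_data[0].shape[0] == 1:
        # Every collected tensor is identical, so one expand suffices.
        return _expand(all_var_data[0], batch_size * num_cost_fns)
    # Non-shared variables, and shared variables with batch_size > 1 (the
    # case that raised a RuntimeError before this commit), are concatenated.
    return torch.cat(all_var_data, dim=0)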