
Commit

Merge pull request #43 from fairinternal/lep.start_embodied
Starts implementing factor and variable types for 2D robot experiments
luisenp authored Jun 17, 2021
2 parents c1831a3 + 2a6bf23 commit fec90d4
Showing 29 changed files with 1,260 additions and 92 deletions.
3 changes: 3 additions & 0 deletions setup.cfg
@@ -2,7 +2,10 @@
max-line-length = 100
# E203: whitespace before ":", incompatible with black
# W503: line break before binary operator (black also)
# F401: imported but unused
ignore=E203, W503
per-file-ignores =
*__init__.py:F401

[mypy]
python_version = 3.7
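The new per-file F401 exemption targets package `__init__.py` files, which typically import names only to re-export them. A hypothetical illustration of the pattern that would otherwise be flagged (actual `__init__.py` contents in this repository may differ):

# Hypothetical theseus/core/__init__.py re-exporting public classes.
# Without the exemption, flake8 reports each import as F401 ("imported but unused").
from .factor import Factor
from .factor_graph import FactorGraph
from .precision import Precision, ScalePrecision, DiagonalPrecision
from .variable import Variable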
1 change: 1 addition & 0 deletions theseus/constants.py
@@ -0,0 +1 @@
EPS = 1e-10
42 changes: 33 additions & 9 deletions theseus/core/factor.py
@@ -1,6 +1,5 @@
import abc
import copy
from typing import Dict, Iterator, List, Optional, Tuple
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple

import torch

@@ -17,11 +16,11 @@
class Factor(abc.ABC):
def __init__(
self,
variables: List[Variable],
precision: Precision,
*args: Any,
name: Optional[str] = None,
**kwargs: Any,
):
self.variables = variables
self.precision = precision
if name:
self.name = name
@@ -61,13 +60,20 @@ def weighted_jacobians_error(
def __len__(self):
return len(self.variables)

def copy(self, new_name: Optional[str] = None) -> "Factor":
# Must copy everything
@abc.abstractmethod
def _copy_impl(self, new_name: Optional[str] = None) -> "Factor":
pass

def copy(
self, new_name: Optional[str] = None, keep_variable_names: bool = False
) -> "Factor":
if not new_name:
new_name = f"{self.name}_copy"
new_factor = copy.copy(self)
new_factor.name = new_name
new_factor.precision = self.precision.copy()
new_factor.variables = copy.deepcopy(self.variables)
new_factor = self._copy_impl(new_name=new_name)
if keep_variable_names:
for i, v in enumerate(new_factor.variables):
v.name = self.variables[i].name
return new_factor

def __deepcopy__(self, memo):
@@ -87,3 +93,21 @@ def to(self, *args, **kwargs):
self.precision.to(*args, **kwargs)
for var in self.variables:
var.to(*args, **kwargs)

@abc.abstractmethod
def _get_variables_impl(self) -> List[Variable]:
pass

@property
def variables(self) -> List[Variable]:
return self._get_variables_impl()

@variables.setter
def variables(self, variables: Sequence[Variable]):
for i, v in enumerate(variables):
self.set_variable_at(v, i)

# Sets the variable at the given index in the order returned by factor.variables
@abc.abstractmethod
def set_variable_at(self, variable: Variable, idx: int):
pass
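Under the reworked interface, a concrete factor owns its variable storage and must supply `_copy_impl`, `_get_variables_impl`, and `set_variable_at` in addition to its error and Jacobian methods. A minimal sketch of a hypothetical subclass (`UnaryFactor` and its zero-target error are illustrative, not part of this commit):

# Hypothetical unary factor showing the abstract methods a subclass must provide.
import torch
import theseus.core


class UnaryFactor(theseus.core.Factor):
    def __init__(self, variable, precision, name=None):
        # Store variables before calling the base constructor (mirrors MockFactor
        # in theseus/core/tests/common.py below).
        self._variables = [variable]
        super().__init__(precision, name=name)

    def error(self):
        # Illustrative residual: distance of the variable to a zero target (batch_size x dim).
        return self._variables[0].data

    def jacobians(self):
        batch_size, dim = self._variables[0].data.shape
        return [torch.eye(dim).unsqueeze(0).expand(batch_size, dim, dim)]

    def dim(self):
        return self._variables[0].data.shape[1]

    def _get_variables_impl(self):
        return self._variables

    def set_variable_at(self, variable, idx):
        self._variables[idx] = variable

    def _copy_impl(self, new_name=None):
        return UnaryFactor(
            self._variables[0].copy(), self.precision.copy(), name=new_name
        )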
8 changes: 4 additions & 4 deletions theseus/core/factor_graph.py
@@ -154,15 +154,15 @@ def copy(self) -> "FactorGraph":
new_graph = FactorGraph()
new_factors = []
for factor in self.factors.values():
new_variables = [v.copy(new_name=v.name) for v in factor.variables]
new_factor = factor.copy(new_name=factor.name)
new_factor.variables = new_variables
new_factor = factor.copy(new_name=factor.name, keep_variable_names=True)
new_factors.append(new_factor)

# Handle the case where a variable appears in 2+ factors, but only a single
# copy should be maintained by the graph
for factor in new_factors:
for i, var in enumerate(factor.variables):
if new_graph.has_variable(var.name):
factor.variables[i] = new_graph.variables[var.name]
factor.set_variable_at(new_graph.variables[var.name], i)
new_graph.add(factor)
return new_graph

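A brief usage sketch of these copy semantics, using the mock types from theseus/core/tests/common.py further below and assuming FactorGraph is exposed as theseus.core.FactorGraph with its factors dictionary keyed by factor name (both assumptions, not confirmed by the hunks shown here):

# Sketch: a copied graph keeps variable names and de-duplicates shared variables.
import torch
import theseus.core
from theseus.core.tests.common import MockFactor, MockVar, NullPrecision

v0 = MockVar(1, data=torch.zeros(1, 1), name="v0")
v1 = MockVar(1, data=torch.ones(1, 1), name="v1")

graph = theseus.core.FactorGraph()
graph.add(MockFactor([v0, v1], NullPrecision(), name="f01"))
graph.add(MockFactor([v1], NullPrecision(), name="f1"))

graph_copy = graph.copy()
# Both copied factors that touch "v1" should point at the same copied variable object.
assert graph_copy.factors["f01"].variables[1] is graph_copy.factors["f1"].variables[0]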
63 changes: 39 additions & 24 deletions theseus/core/precision.py
@@ -1,5 +1,4 @@
import abc
import copy
from itertools import count
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple

@@ -8,7 +7,6 @@


# Abstract class for representing precision functions (inverse covariance)
# (equivalent to LossFunction in minisam).
# Concrete classes must implement two methods:
# - `weight_error`: returns an error tensor weighted by the precision
# - `weight_jacobians_and_error`: returns jacobians and errors weighted by the precision
@@ -21,6 +19,7 @@ def __init__(
param_name: Optional[str] = None,
):
self.data: torch.Tensor
self.learnable = learnable
if data is not None:
self.data = data
else:
@@ -55,10 +54,13 @@ def weight_jacobians_and_error(
) -> Tuple[List[torch.Tensor], torch.Tensor]:
pass

# Must copy everything
@abc.abstractmethod
def _copy_impl(self) -> "Precision":
pass

def copy(self) -> "Precision":
new_cov = copy.copy(self)
new_cov.data = self.data.clone()
return new_cov
return self._copy_impl()

def __deepcopy__(self, memo):
if id(self) in memo:
@@ -82,37 +84,42 @@ def update(self, input_batch: torch.Tensor):
pass


# Note that these operations are in-place
# (consider renaming weight functions as in minisam - or adding underscore suffix)
class ScalePrecision(Precision):
_ids = count(0)

def __init__(self, scale: float, learnable: bool = False):
def __init__(
self, scale: float, data: Optional[torch.Tensor] = None, learnable: bool = False
):
# TODO if we keep this _id mechanism, add test for this
_id = next(self._ids)
param_name = f"scale_cov_{_id}"
super().__init__(scale, learnable=learnable, param_name=param_name)
self.scale = scale
super().__init__(scale, data=data, learnable=learnable, param_name=param_name)

def _init_data(self, scale: float): # type: ignore
self.data = torch.tensor(scale)
self.data = torch.tensor(scale).float()

def weight_error(self, error: torch.Tensor) -> torch.Tensor:
error.mul_(self.data)
return error
return error * self.data

def weight_jacobians_and_error(
self,
jacobians: List[torch.Tensor],
error: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
error.mul_(self.data)
error = error * self.data
new_jacobians = []
for jac in jacobians:
jac.data.mul_(self.data)
return jacobians, error
new_jacobians.append(jac * self.data)
return new_jacobians, error

def _copy_impl(self) -> "ScalePrecision":
new_prec = ScalePrecision(
self.scale, data=self.data.detach(), learnable=self.learnable
)
return new_prec


# Note that these operations are in-place
# (consider renaming weight functions as in minisam - or adding underscore suffix)
class DiagonalPrecision(Precision):
_ids = count(0)

@@ -125,28 +132,36 @@ def __init__(
# TODO if we keep this _id mechanism, add test for this
_id = next(self._ids)
param_name = f"diagonal_cov_{_id}"
super().__init__(diagonal, learnable=learnable, param_name=param_name)
self.diagonal = diagonal
super().__init__(
diagonal, data=data, learnable=learnable, param_name=param_name
)

def _init_data(self, diagonal: float): # type: ignore
if isinstance(diagonal, torch.Tensor):
self.data = diagonal
else:
self.data = torch.tensor(diagonal)
self.data = torch.tensor(diagonal).float()
if self.data.ndim != 1:
raise ValueError("DiagonalPrecision only accepts arrays of dim. 1.")

def weight_error(self, error: torch.Tensor) -> torch.Tensor:
error.mul_(self.data)
return error
return error * self.data

def weight_jacobians_and_error(
self,
jacobians: List[torch.Tensor],
error: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
error.mul_(self.data)
error = error * self.data
new_jacobians = []
for jac in jacobians:
# Jacobian is batch_size x factor_dim x var_dim
# This left-multiplies the jacobian by the weights (inv. cov.)
jac.data.mul_(self.data.view(1, -1, 1))
return jacobians, error
new_jacobians.append(jac * self.data.view(1, -1, 1))
return new_jacobians, error

def _copy_impl(self) -> "DiagonalPrecision":
return DiagonalPrecision(
self.diagonal, data=self.data.detach(), learnable=self.learnable
)
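A hedged sketch of what another concrete precision would need to provide after this change, following the same out-of-place weighting and `_copy_impl` pattern (a hypothetical full-matrix precision, not part of the commit; the `param_name` argument is omitted here for brevity):

# Hypothetical full-matrix precision mirroring ScalePrecision/DiagonalPrecision above.
import torch
import theseus.core


class MatrixPrecision(theseus.core.Precision):
    def __init__(self, matrix, data=None, learnable=False):
        self.matrix = matrix
        super().__init__(matrix, data=data, learnable=learnable)

    def _init_data(self, matrix):
        self.data = matrix.float()
        if self.data.ndim != 2:
            raise ValueError("MatrixPrecision only accepts matrices of dim. 2.")

    def weight_error(self, error):
        # error is batch_size x factor_dim; weight each row by the matrix.
        return error @ self.data.t()

    def weight_jacobians_and_error(self, jacobians, error):
        error = error @ self.data.t()
        # Jacobians are batch_size x factor_dim x var_dim; left-multiply by the weight.
        new_jacobians = [self.data @ jac for jac in jacobians]
        return new_jacobians, error

    def _copy_impl(self):
        return MatrixPrecision(
            self.matrix, data=self.data.detach(), learnable=self.learnable
        )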
47 changes: 44 additions & 3 deletions theseus/core/tests/common.py
@@ -1,11 +1,13 @@
import copy

import torch

import theseus.core


class MockVar(theseus.core.Variable):
def __init__(self, length, data=None, name=None):
super().__init__(length, data=None, name=name)
super().__init__(length, data=data, name=name)

def _init_data(self, length):
self.data = torch.empty(1, length)
@@ -21,6 +23,9 @@ def _local(variable1, variable2):
def _retract(variable, delta):
pass

def _copy_impl(self, new_name=None):
return MockVar(self.data.shape[1], data=self.data.clone(), name=new_name)


class MockPrecision(theseus.core.Precision):
def __init__(self, the_data):
@@ -35,6 +40,9 @@ def weight_error(self, error):
def weight_jacobians_and_error(self, jacobians, error):
pass

def _copy_impl(self, new_name=None):
return MockPrecision(self.data.clone())


class NullPrecision(theseus.core.Precision):
def __init__(self):
@@ -49,22 +57,37 @@ def weight_error(self, error):
def weight_jacobians_and_error(self, jacobians, error):
return jacobians, error

def _copy_impl(self, new_name=None):
return NullPrecision()


class MockFactor(theseus.core.Factor):
def __init__(self, variables, precision, name=None):
super().__init__(variables, precision, name=name)
self._variables = variables
self._dim = 2
super().__init__(precision, name=name)

def error(self):
mu = torch.stack([v.data for v in self.variables]).sum()
return mu * torch.ones(self._dim)

def jacobians(self):
pass
return [self.error()] * len(self.variables)

def dim(self) -> int:
return self._dim

def _get_variables_impl(self):
return self._variables

def set_variable_at(self, variable, idx):
self._variables[idx] = variable

def _copy_impl(self, new_name=None):
return MockFactor(
[v.copy() for v in self._variables], self.precision.copy(), name=new_name
)


def create_mock_factors(data=None, precision=NullPrecision()):
len_data = 1 if data is None else data.shape[1]
@@ -92,3 +115,21 @@ def create_graph_with_mock_factors(data=None, precision=NullPrecision()):
graph.add(factor)

return graph, factors, names, var_to_factors


def check_copy_var(var):
var.name = "old"
new_var = var.copy(new_name="new")
assert var is not new_var
assert var.data is not new_var.data
assert torch.allclose(var.data, new_var.data)
assert new_var.name == "new"
new_var_no_name = copy.deepcopy(var)
assert new_var_no_name.name == f"{var.name}_copy"


def check_another_var_is_copy(var, other_var):
assert isinstance(var, other_var.__class__)
assert var is not other_var
assert var.data is not other_var.data
assert torch.allclose(var.data, other_var.data)
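The new check_copy_var and check_another_var_is_copy helpers would presumably be exercised from variable tests along these lines (a sketch, assuming the tests directory is importable as a package):

# Sketch of a test exercising the new copy helpers with MockVar.
import torch
from theseus.core.tests.common import MockVar, check_another_var_is_copy, check_copy_var


def test_mock_var_copy():
    var = MockVar(3, data=torch.rand(1, 3), name="dummy")
    check_another_var_is_copy(var, var.copy())
    check_copy_var(var)  # renames var to "old" and checks copy/deepcopy behavior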
2 changes: 1 addition & 1 deletion theseus/core/tests/test_factor.py
@@ -38,7 +38,7 @@ def test_default_name():
for var in range(num_vars):
name = str(np.random.randint(111111, high=1000000))
names.append(name)
variables.append(MockVar(1, 1, name=name))
variables.append(MockVar(1, torch.ones(1, 1), name=name))
factor_name = ".".join(names)
factor = MockFactor(variables, MockPrecision(torch.ones(1)))
assert factor.name == factor_name