Renamed Variable.data as Variable.tensor (#229)
* Renamed Variable.data as Variable.tensor. Passes lint and core unit tests.

* Fixed embodied unit tests.

* Fixed geometry unit tests.

* Fixed optimizer unit tests.

* Fixed remaining unit tests.

* Fixed tutorials and example scripts.

* Fixed some bugs from rebasing.

* Fixed some lingering .data bugs.
luisenp authored Jul 12, 2022
1 parent 53edd27 commit 5ffb436
Showing 78 changed files with 822 additions and 722 deletions.
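
Before the per-file diffs, a minimal usage sketch of the rename (not part of the diff itself; the Vector shape and names here are illustrative, while the tensor= keyword and .tensor attribute are taken from the changes below):

import torch
import theseus as th

# The tensor wrapped by a Variable is now exposed as `.tensor` instead of `.data`,
# and constructors take `tensor=` instead of `data=`.
x = th.Vector(tensor=torch.zeros(1, 2), name="x")
print(x.tensor.shape)        # torch.Size([1, 2])
x.tensor = torch.ones(1, 2)  # direct assignment, as used in several examples below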
4 changes: 2 additions & 2 deletions examples/backward_modes.py
@@ -46,8 +46,8 @@ def generate_data(num_points=10, a=1.0, b=0.5, noise_factor=0.01):
def quad_error_fn(optim_vars, aux_vars):
a, b = optim_vars
x, y = aux_vars
est = a.data * x.data.square() + b.data
err = y.data - est
est = a.tensor * x.tensor.square() + b.tensor
err = y.tensor - est
return err


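For illustration (not from the diff): a sketch of how quad_error_fn above is typically wrapped in an autodiff cost function. th.AutoDiffCostFunction, th.Vector, and th.Variable usage here is assumed; only the .tensor accesses inside the error function come from this commit.

import torch
import theseus as th

# Hypothetical wiring for quad_error_fn (defined in the diff above).
a = th.Vector(1, name="a")                    # optimization variables
b = th.Vector(1, name="b")
x = th.Variable(torch.rand(1, 10), name="x")  # auxiliary data
y = th.Variable(torch.rand(1, 10), name="y")
cost_fn = th.AutoDiffCostFunction(
    (a, b), quad_error_fn, 10, aux_vars=(x, y), name="quad_cost"
)
print(cost_fn.error().shape)                  # (1, 10), computed from each variable's .tensor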
12 changes: 5 additions & 7 deletions examples/bundle_adjustment.py
@@ -24,8 +24,6 @@
"truncated": th.BackwardMode.TRUNCATED,
}

# Smaller values result in error
th.SO3.SO3_EPS = 1e-6

# Logger
log = logging.getLogger(__name__)
@@ -38,14 +36,14 @@ def print_histogram(
histogram = theg.ba_histogram(
cameras=[
theg.Camera(
th.SE3(data=var_dict[c.pose.name]),
th.SE3(tensor=var_dict[c.pose.name]),
c.focal_length,
c.calib_k1,
c.calib_k2,
)
for c in ba.cameras
],
points=[th.Point3(data=var_dict[pt.name]) for pt in ba.points],
points=[th.Point3(tensor=var_dict[pt.name]) for pt in ba.points],
observations=ba.observations,
)
for line in histogram.split("\n"):
@@ -79,7 +77,7 @@ def _clone(t_):
return t_.detach().cpu().clone()

results = {
"log_loss_radius": _clone(log_loss_radius.data),
"log_loss_radius": _clone(log_loss_radius.tensor),
"theseus_outputs": dict((s, _clone(t)) for s, t in theseus_outputs.items()),
"err_history": info.err_history, # type: ignore
"loss": loss_value,
@@ -187,8 +185,8 @@ def run(cfg: omegaconf.OmegaConf, results_path: pathlib.Path):
theseus_optim = th.TheseusLayer(optimizer)

# copy the poses/pts to feed them to each outer iteration
orig_poses = {cam.pose.name: cam.pose.data.clone() for cam in ba.cameras}
orig_points = {pt.name: pt.data.clone() for pt in ba.points}
orig_poses = {cam.pose.name: cam.pose.tensor.clone() for cam in ba.cameras}
orig_points = {pt.name: pt.tensor.clone() for pt in ba.points}

# Outer optimization loop
loss_radius_tensor = torch.nn.Parameter(torch.tensor([3.0], dtype=torch.float64))
4 changes: 2 additions & 2 deletions examples/motion_planning_2d.py
@@ -218,9 +218,9 @@ def run_learning_loop(cfg):
)

collision_w = (
motion_planner.objective.aux_vars["collision_w"].data.mean().item()
motion_planner.objective.aux_vars["collision_w"].tensor.mean().item()
)
cost_eps = motion_planner.objective.aux_vars["cost_eps"].data.mean().item()
cost_eps = motion_planner.objective.aux_vars["cost_eps"].tensor.mean().item()
print("collision weight", collision_w)
print("cost_eps", cost_eps)
print("OBJECTIVE MEAN LOSS", epoch_mean_objective_loss)
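Outside the diff: a standalone sketch of reading an auxiliary variable's tensor back out of an objective by name, as done above; the toy error function and variable names are placeholders.

import torch
import theseus as th

def err_fn(optim_vars, aux_vars):
    # toy residual: optimization variable minus the auxiliary weight
    return optim_vars[0].tensor - aux_vars[0].tensor

v = th.Vector(1, name="v")
collision_w = th.Variable(torch.full((1, 1), 0.5), name="collision_w")
objective = th.Objective()
objective.add(th.AutoDiffCostFunction((v,), err_fn, 1, aux_vars=(collision_w,), name="c"))
print(objective.aux_vars["collision_w"].tensor.mean().item())  # 0.5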
6 changes: 3 additions & 3 deletions examples/pose_graph/pose_graph_benchmark.py
@@ -57,7 +57,7 @@ def main(cfg):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

inputs = {var.name: var.data for var in verts}
inputs = {var.name: var.tensor for var in verts}
optimizer.objective.update(inputs)

start_event.record()
@@ -76,10 +76,10 @@ def main(cfg):
objective.error_squared_norm().detach().cpu().numpy().sum() / 2
)
results["R"] = torch.cat(
[pose.data[:, :, :d].detach().cpu() for pose in verts]
[pose.tensor[:, :, :d].detach().cpu() for pose in verts]
).numpy()
results["t"] = torch.cat(
[pose.data[:, :, d].detach().cpu() for pose in verts]
[pose.tensor[:, :, d].detach().cpu() for pose in verts]
).numpy()

savemat(dataset_name + ".mat", results)
13 changes: 7 additions & 6 deletions examples/pose_graph/pose_graph_cube.py
@@ -31,11 +31,12 @@

def get_batch_data(pg_batch: theg.PoseGraphDataset, pose_indices: List[int]):
batch = {
pg_batch.poses[index].name: pg_batch.poses[index].data for index in pose_indices
pg_batch.poses[index].name: pg_batch.poses[index].tensor
for index in pose_indices
}
batch.update({pg_batch.poses[0].name + "__PRIOR": pg_batch.poses[0].data.clone()})
batch.update({pg_batch.poses[0].name + "__PRIOR": pg_batch.poses[0].tensor.clone()})
batch.update(
{edge.relative_pose.name: edge.relative_pose.data for edge in pg_batch.edges}
{edge.relative_pose.name: edge.relative_pose.tensor for edge in pg_batch.edges}
)
return batch

@@ -153,11 +154,11 @@ def main(cfg):
edges = edges_n
else:
for pose, pose_n in zip(poses, poses_n):
pose.data = torch.cat((pose.data, pose_n.data))
pose.tensor = torch.cat((pose.tensor, pose_n.tensor))

for edge, edge_n in zip(edges, edges_n):
edge.relative_pose.data = torch.cat(
(edge.relative_pose.data, edge_n.relative_pose.data)
edge.relative_pose.tensor = torch.cat(
(edge.relative_pose.tensor, edge_n.relative_pose.tensor)
)

# create (or load) dataset
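Outside the diff: the concatenation above grows each pose's batch dimension in place by writing back to .tensor; a standalone sketch (th.SE3.rand is assumed here):

import torch
import theseus as th

p1 = th.SE3.rand(1)  # assumed random constructor; SE3 tensors have shape (batch, 3, 4)
p2 = th.SE3.rand(1)
p1.tensor = torch.cat((p1.tensor, p2.tensor))
print(p1.shape)      # torch.Size([2, 3, 4]) -- batch dimension doubled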
5 changes: 2 additions & 3 deletions examples/pose_graph/pose_graph_g2o.py
@@ -13,14 +13,13 @@
file_path = "datasets/tinyGrid3D.g2o"
dtype = torch.float64

th.SO3.SO3_EPS = 1e-6

num_verts, verts, edges = theg.pose_graph.read_3D_g2o_file(file_path, dtype=dtype)

objective = th.Objective(dtype)

log_loss_radius = th.Vector(
data=torch.tensor([[0]], dtype=dtype), name="log_loss_radius"
tensor=torch.tensor([[0]], dtype=dtype), name="log_loss_radius"
)
loss_cls = th.HuberLoss

@@ -53,5 +52,5 @@

theseus_optim = th.TheseusLayer(optimizer)

inputs = {var.name: var.data for var in verts}
inputs = {var.name: var.tensor for var in verts}
theseus_optim.forward(inputs, optimizer_kwargs={"verbose": True})
17 changes: 10 additions & 7 deletions examples/pose_graph/pose_graph_synthetic.py
@@ -53,7 +53,7 @@ def print_histogram(
):
log.info(msg)
with torch.no_grad():
poses = [th.SE3(data=var_dict[pose.name]) for pose in pg.poses]
poses = [th.SE3(tensor=var_dict[pose.name]) for pose in pg.poses]
histogram = theg.pg_histogram(poses=poses, edges=pg.edges)
for line in histogram.split("\n"):
log.info(line)
@@ -63,17 +63,18 @@ def get_batch_data(
pg_batch: theg.PoseGraphDataset, pose_indices: List[int], gt_pose_indices: List[int]
):
batch = {
pg_batch.poses[index].name: pg_batch.poses[index].data for index in pose_indices
pg_batch.poses[index].name: pg_batch.poses[index].tensor
for index in pose_indices
}
batch.update({pg_batch.poses[0].name + "__PRIOR": pg_batch.poses[0].data.clone()})
batch.update({pg_batch.poses[0].name + "__PRIOR": pg_batch.poses[0].tensor.clone()})
batch.update(
{
pg_batch.gt_poses[index].name: pg_batch.gt_poses[index].data
pg_batch.gt_poses[index].name: pg_batch.gt_poses[index].tensor
for index in gt_pose_indices
}
)
batch.update(
{edge.relative_pose.name: edge.relative_pose.data for edge in pg_batch.edges}
{edge.relative_pose.name: edge.relative_pose.tensor for edge in pg_batch.edges}
)
return batch

@@ -85,8 +86,10 @@ def pose_loss(
loss: torch.Tensor = torch.zeros(
1, dtype=pose_vars[0].dtype, device=pose_vars[0].device
)
poses_batch = th.SE3(data=torch.cat([pose.data for pose in pose_vars]))
gt_poses_batch = th.SE3(data=torch.cat([gt_pose.data for gt_pose in gt_pose_vars]))
poses_batch = th.SE3(tensor=torch.cat([pose.tensor for pose in pose_vars]))
gt_poses_batch = th.SE3(
tensor=torch.cat([gt_pose.tensor for gt_pose in gt_pose_vars])
)
pose_loss = th.local(poses_batch, gt_poses_batch).norm(dim=1)
loss += pose_loss.sum()
return loss
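Outside the diff: pose_loss above relies on th.local for the tangent-space difference between pose batches; a minimal sketch (th.SE3.rand is assumed, the th.local call itself appears in the diff):

import theseus as th

g1 = th.SE3.rand(2)
g2 = th.SE3.rand(2)
delta = th.local(g1, g2)  # shape (2, 6): se(3) coordinates of the relative pose
print(delta.norm(dim=1))  # per-sample pose error, summed in pose_loss above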
6 changes: 3 additions & 3 deletions examples/se2_inverse.py
@@ -24,10 +24,10 @@


def run(x1: LieGroup, x2: LieGroup, num_iters=10, use_lie_tangent=True):
x1.data = LieGroupTensor(x1)
x1.data.requires_grad = True
x1.tensor = LieGroupTensor(x1)
x1.tensor.requires_grad = True

optim = torch.optim.Adam([x1.data], lr=1e-1)
optim = torch.optim.Adam([x1.tensor], lr=1e-1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optim, milestones=[250, 600], gamma=0.01
)
4 changes: 2 additions & 2 deletions examples/state_estimation_2d.py
@@ -240,7 +240,7 @@ def cost_weights_model():
th.Difference(
poses[i],
gps_cost_weights[i],
th.Point2(data=gps_targets_[i]),
th.Point2(tensor=gps_targets_[i]),
name=f"gps_{i}",
)
)
@@ -251,7 +251,7 @@ def cost_weights_model():
poses[i],
poses[i + 1],
between_cost_weights[i],
th.Point2(data=measurements_[i]),
th.Point2(tensor=measurements_[i]),
name=f"between_{i}",
)
)
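Outside the diff: a sketch of one Difference cost as constructed above; the tensors and weight are placeholders, and the (variable, cost weight, target) argument order follows this file's usage.

import torch
import theseus as th

pose = th.Point2(torch.zeros(1, 2), name="pose_0")
weight = th.ScaleCostWeight(torch.ones(1))
gps_target = th.Point2(tensor=torch.tensor([[1.0, 2.0]]), name="gps_target_0")
cost = th.Difference(pose, weight, gps_target, name="gps_0")
print(cost.error())  # pose-vs-target residual, read from the variables' .tensor attributes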
20 changes: 10 additions & 10 deletions theseus/core/cost_function.py
@@ -96,8 +96,8 @@ def __call__(

# The error function is assumed to receive variables in the format
# err_fn(
# optim_vars=(optim_vars[0].data, ..., optim_vars[N - 1].data),
# aux_vars=(aux_vars[0].data, ..., aux_vars[M -1].data)
# optim_vars=(optim_vars[0].tensor, ..., optim_vars[N - 1].tensor),
# aux_vars=(aux_vars[0].tensor, ..., aux_vars[M -1].tensor)
# )
#
# The user also needs to explicitly specify the output's dimension
@@ -144,10 +144,10 @@ def __init__(
self._tmp_aux_vars_for_loop = tuple(v.copy() for v in aux_vars)

for i, optim_var in enumerate(optim_vars):
self._tmp_optim_vars_for_loop[i].update(optim_var.data)
self._tmp_optim_vars_for_loop[i].update(optim_var.tensor)

for i, aux_var in enumerate(aux_vars):
self._tmp_aux_vars_for_loop[i].update(aux_var.data)
self._tmp_aux_vars_for_loop[i].update(aux_var.tensor)

self._autograd_loop_over_batch = autograd_loop_over_batch

@@ -169,9 +169,9 @@ def _make_jac_fn(
def _make_jac_fn(
self, tmp_optim_vars: Tuple[Manifold, ...], tmp_aux_vars: Tuple[Variable, ...]
) -> Callable:
def jac_fn(*optim_vars_data_):
assert len(optim_vars_data_) == len(tmp_optim_vars)
for i, tensor in enumerate(optim_vars_data_):
def jac_fn(*optim_vars_tensors_):
assert len(optim_vars_tensors_) == len(tmp_optim_vars)
for i, tensor in enumerate(optim_vars_tensors_):
tmp_optim_vars[i].update(tensor)

return self._err_fn(optim_vars=tmp_optim_vars, aux_vars=tmp_aux_vars)
@@ -197,10 +197,10 @@ def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
jacobians_raw_loop: List[Tuple[torch.Tensor, ...]] = []
for n in range(optim_vars[0].shape[0]):
for i, aux_var in enumerate(aux_vars):
self._tmp_aux_vars_for_loop[i].update(aux_var.data[n : n + 1])
self._tmp_aux_vars_for_loop[i].update(aux_var.tensor[n : n + 1])

jacobians_n = self._compute_autograd_jacobian(
tuple(v.data[n : n + 1] for v in optim_vars),
tuple(v.tensor[n : n + 1] for v in optim_vars),
self._make_jac_fn(
self._tmp_optim_vars_for_loop, self._tmp_aux_vars_for_loop
),
@@ -213,7 +213,7 @@ def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]:
]
else:
jacobians_raw = self._compute_autograd_jacobian(
tuple(v.data for v in optim_vars),
tuple(v.tensor for v in optim_vars),
self._make_jac_fn(self._tmp_optim_vars, aux_vars),
)
aux_idx = torch.arange(err.shape[0]) # batch_size
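Outside the diff: the jacobians above are computed by re-evaluating the error as a pure function of the raw tensors and differentiating with torch autograd; a standalone sketch of that pattern (the use of torch.autograd.functional.jacobian is assumed, names are illustrative):

import torch

def err_fn(a_tensor: torch.Tensor, b_tensor: torch.Tensor) -> torch.Tensor:
    # pure-tensor version of a simple residual, so autograd can trace it
    x = torch.linspace(0.0, 1.0, 10).unsqueeze(0)
    y = 2.0 * x.square() + 0.5
    return y - (a_tensor * x.square() + b_tensor)

a = torch.ones(1, 1)
b = torch.zeros(1, 1)
jac_a, jac_b = torch.autograd.functional.jacobian(err_fn, (a, b), vectorize=True)
print(jac_a.shape)  # (1, 10, 1, 1): batch x err_dim x batch x var_dof, before selecting the batch diagonal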
34 changes: 18 additions & 16 deletions theseus/core/cost_weight.py
@@ -65,23 +65,25 @@ def __init__(
self.scale = Variable(scale)
else:
self.scale = scale
if not self.scale.data.squeeze().ndim in [0, 1]:
raise ValueError("ScaleCostWeight only accepts 0- or 1-dim (batched) data.")
self.scale.data = self.scale.data.view(-1, 1)
if not self.scale.tensor.squeeze().ndim in [0, 1]:
raise ValueError(
"ScaleCostWeight only accepts 0- or 1-dim (batched) tensors."
)
self.scale.tensor = self.scale.tensor.view(-1, 1)
self.register_aux_vars(["scale"])

def weight_error(self, error: torch.Tensor) -> torch.Tensor:
return error * self.scale.data
return error * self.scale.tensor

def weight_jacobians_and_error(
self,
jacobians: List[torch.Tensor],
error: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
error = error * self.scale.data
error = error * self.scale.tensor
new_jacobians = []
for jac in jacobians:
new_jacobians.append(jac * self.scale.data.view(-1, 1, 1))
new_jacobians.append(jac * self.scale.tensor.view(-1, 1, 1))
return new_jacobians, error

def _copy_impl(self, new_name: Optional[str] = None) -> "ScaleCostWeight":
@@ -103,32 +105,32 @@ def __init__(
self.diagonal = Variable(diagonal)
else:
self.diagonal = diagonal
if not self.diagonal.data.squeeze().ndim < 3:
raise ValueError("DiagonalCostWeight only accepts data with ndim < 3.")
if self.diagonal.data.ndim == 0:
self.diagonal.data = self.diagonal.data.view(1, 1)
if self.diagonal.data.ndim == 1:
if not self.diagonal.tensor.squeeze().ndim < 3:
raise ValueError("DiagonalCostWeight only accepts tensors with ndim < 3.")
if self.diagonal.tensor.ndim == 0:
self.diagonal.tensor = self.diagonal.tensor.view(1, 1)
if self.diagonal.tensor.ndim == 1:
warnings.warn(
"1-D diagonal input is ambiguous. Dimension will be "
"interpreted as data dimension and not batch dimension."
"interpreted as dof dimension and not batch dimension."
)
self.diagonal.data = self.diagonal.data.view(1, -1)
self.diagonal.tensor = self.diagonal.tensor.view(1, -1)
self.register_aux_vars(["diagonal"])

def weight_error(self, error: torch.Tensor) -> torch.Tensor:
return error * self.diagonal.data
return error * self.diagonal.tensor

def weight_jacobians_and_error(
self,
jacobians: List[torch.Tensor],
error: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
error = error * self.diagonal.data
error = error * self.diagonal.tensor
new_jacobians = []
for jac in jacobians:
# Jacobian is batch_size x cost_fuction_dim x var_dim
# This left multiplies the weights to jacobian
new_jacobians.append(jac * self.diagonal.data.unsqueeze(2))
new_jacobians.append(jac * self.diagonal.tensor.unsqueeze(2))
return new_jacobians, error

def _copy_impl(self, new_name: Optional[str] = None) -> "DiagonalCostWeight":
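Outside the diff: a small shape sketch for the two cost weights above; the constructors are called with raw tensors as allowed by the Variable wrapping shown in this file, and the exact public signatures are otherwise assumed.

import torch
import theseus as th

scale_w = th.ScaleCostWeight(torch.tensor([2.0]))
print(scale_w.scale.tensor.shape)    # torch.Size([1, 1]) after the view(-1, 1) above

diag_w = th.DiagonalCostWeight(torch.tensor([[1.0, 0.5, 0.1]]))
print(diag_w.diagonal.tensor.shape)  # torch.Size([1, 3]): batch x dof

err = torch.ones(1, 3)
print(diag_w.weight_error(err))      # elementwise product with the diagonal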
