Skip to content

Commit

Permalink
demo files are now saved and loaded recursively from state dictionary…
Browse files Browse the repository at this point in the history
… with torch.Tensor values
  • Loading branch information
tomato1mule committed Apr 25, 2023
1 parent 88b7668 commit 10f220d
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 49 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,15 @@ python pick_train.py
python place_train.py
```

If you want load already trained checkpoints, please rename 'checkpoint_example' folder to 'checkpoint'.
If you want to load already trained checkpoints, please rename 'checkpoint_example' folder to 'checkpoint'.
## Evaluate
Please run the example notebook codes for visualizing sampled poses from trained models (evaluate_pick.ipynb and evaluate_place.ipynb)

## View train log
```shell
python train_log_viewer.py --logdir="checkpoint/mug_10_demo/ {pick or place} /trainlog_iter_{iter}.gzip"
```




Binary file modified demo/test_demo/data/demo_0.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_1.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_2.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_3.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_4.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_5.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_6.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_7.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_8.gzip
Binary file not shown.
Binary file modified demo/test_demo/data/demo_9.gzip
Binary file not shown.
147 changes: 101 additions & 46 deletions edf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import gzip
import pickle
from typing import Union, Optional, List, Tuple, Dict, Any, Iterable, TypeVar
from typing import Union, Optional, List, Tuple, Dict, Any, Iterable, TypeVar, Type
from abc import ABCMeta, abstractmethod
import warnings

Expand All @@ -23,27 +23,20 @@ def load_yaml(file_path: str):
config = yaml.load(file, Loader=yaml.FullLoader)
return config

def gzip_save(data, dir, filename):
def gzip_save(data, path: str):
dir = os.path.dirname(path)

if not os.path.exists(dir):
os.makedirs(dir)

with gzip.open(os.path.join(dir, filename), 'wb') as f:
with gzip.open(path, 'wb') as f:
pickle.dump(data, f)

def gzip_load(dir, filename):
with gzip.open(os.path.join(dir, filename), 'rb') as f:
def gzip_load(path: str):
with gzip.open(path, 'rb') as f:
data = pickle.load(f)
return data



# class SE3():
# pass

# class PointCloud():
# pass


class SE3():
def __init__(self, poses: Union[torch.Tensor, Iterable], device: Optional[Union[str, torch.device]] = None, renormalize: bool = True):
if not isinstance(poses, torch.Tensor):
Expand Down Expand Up @@ -237,15 +230,6 @@ def to_o3d(self) -> o3d.geometry.PointCloud:
pcd.colors = o3d.utility.Vector3dVector(self.colors.detach().cpu())

return pcd

@staticmethod
def from_pcd(pcd: o3d.geometry.PointCloud, device: Union[str, torch.device] = 'cpu') -> PointCloud:
warnings.warn("PointCloud.from_pcd() is deprecated.")
return PointCloud.from_o3d(pcd=pcd, device=device)

def to_pcd(self) -> o3d.geometry.PointCloud:
warnings.warn("PointCloud.to_pcd() is deprecated.")
return self.to_o3d()

def get_data_dict(self) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
data_dict = {"points": self.points.detach().cpu(),
Expand Down Expand Up @@ -301,7 +285,12 @@ def merge(*args) -> PointCloud:
return PointCloud(points=points, colors=colors)

@staticmethod
def points_to_plotly(pcd: Union[PointCloud, torch.Tensor], point_size: float = 1.0, name: Optional[str] = None, opacity: Union[float, torch.Tensor] = 1.0, colors: Optional[Iterable] = None, custom_data: Optional[Dict] = None) -> go.Scatter3d:
def points_to_plotly(pcd: Union[PointCloud, torch.Tensor],
point_size: float = 1.0,
name: Optional[str] = None,
opacity: Union[float, torch.Tensor] = 1.0,
colors: Optional[Iterable] = None,
custom_data: Optional[Dict] = None) -> go.Scatter3d:
if colors is not None:
colors = torch.tensor(colors)
if isinstance(pcd, PointCloud):
Expand Down Expand Up @@ -347,8 +336,28 @@ def points_to_plotly(pcd: Union[PointCloud, torch.Tensor], point_size: float = 1

return go.Scatter3d(**plotly_kwargs)

@staticmethod
def show_pcd(pcd: Union[PointCloud, torch.Tensor],
point_size: float = 1.0,
name: Optional[str] = None,
opacity: Union[float, torch.Tensor] = 1.0,
colors: Optional[Iterable] = None,
custom_data: Optional[Dict] = None,
width = 1600,
height = 1200,
) -> go.Figure:

data = PointCloud.points_to_plotly(pcd=pcd, point_size=point_size, name=name, opacity=opacity, colors=colors, custom_data=custom_data)
fig = go.Figure(data=[data], layout=dict(width=width, height=height))
return fig

def plotly(self, point_size: float = 1.0, name: Optional[str] = None, opacity: Union[float, torch.Tensor] = 1.0, colors: Optional[torch.Tensor] = None, custom_data: Optional[dict] = None) -> go.Scatter3d:
return PointCloud.points_to_plotly(pcd=self, point_size=point_size, name=name, opacity=opacity, colors=colors, custom_data=custom_data)

def show(self, point_size: float = 1.0, name: Optional[str] = None, opacity: Union[float, torch.Tensor] = 1.0, colors: Optional[torch.Tensor] = None, custom_data: Optional[dict] = None, width = 1600, height=1200):
return PointCloud.show_pcd(pcd=self, point_size=point_size, name=name, opacity=opacity, colors=colors, custom_data=custom_data, width=width, height=height)





Expand All @@ -360,12 +369,14 @@ def plotly(self, point_size: float = 1.0, name: Optional[str] = None, opacity: U


class Demo(metaclass=ABCMeta):
name: str = ''

@abstractmethod
def to(self, device: Union[str, torch.device]) -> Demo:
pass

@abstractmethod
def get_data(self) -> Dict[str, Any]:
def get_data_dict(self) -> Dict[str, Any]:
pass

@staticmethod
Expand All @@ -378,7 +389,10 @@ class TargetPoseDemo(Demo):
def __init__(self, target_poses: SE3,
scene_pc: Optional[PointCloud] = None,
grasp_pc: Optional[PointCloud] = None,
name: Optional[str] = None,
device: Optional[Union[str, torch.device]] = None):

self.name = name

if device is None:
device = target_poses.device
Expand All @@ -402,31 +416,57 @@ def to(self, device: Union[str, torch.device]) -> TargetPoseDemo:
if device == self.device:
return self
else:
return TargetPoseDemo(scene_pc=self.scene_pc.to(device=device), target_poses=self.target_poses.to(device=device), grasp_pc=self.grasp_pc.to(device=device), device=device)
return TargetPoseDemo(scene_pc=self.scene_pc.to(device=device), target_poses=self.target_poses.to(device=device), grasp_pc=self.grasp_pc.to(device=device), name=self.name, device=device)


def get_data(self) -> Dict:
data_dict = {}
def get_data_dict(self) -> Dict:
data_dict = {'name': self.name}
for k,v in self.scene_pc.get_data_dict().items():
data_dict["scene_"+k] = v
for k,v in self.grasp_pc.get_data_dict().items():
data_dict["grasp_"+k] = v
for k,v in self.target_poses.get_data_dict().items():
data_dict["target_"+k] = v
data_dict["demo_type"] = self.__class__
data_dict["demo_type"]: str = self.__class__.__name__
return data_dict

@staticmethod
def from_data_dict(data_dict: Dict, device: Union[str, torch.device] = 'cpu'):
assert data_dict["demo_type"] == TargetPoseDemo
def from_data_dict(data_dict: Dict, device: Union[str, torch.device] = 'cpu', rename: Union[bool, Optional[str]] = False):
assert data_dict["demo_type"] == TargetPoseDemo.__name__
scene_pc = PointCloud.from_data_dict(data_dict={"points": data_dict["scene_points"], "colors": data_dict["scene_colors"]}, device=device)
grasp_pc = PointCloud.from_data_dict(data_dict={"points": data_dict["grasp_points"], "colors": data_dict["grasp_colors"]}, device=device)
target_poses = SE3.from_data_dict(data_dict={"poses": data_dict["target_poses"]}, device=device)

return TargetPoseDemo(scene_pc=scene_pc, grasp_pc=grasp_pc, target_poses=target_poses, device=device)
if rename is False:
try:
name = data_dict['name']
except KeyError:
name = ''
elif rename is True:
raise ValueError(f"The 'rename' argument must be a boolean 'False', or 'None', or string, but {rename} is given.")
else:
assert (isinstance(rename, str) or rename is None), f"The 'rename' argument must be a boolean 'False', or 'None', or string, but {rename} is given."
name = rename

return TargetPoseDemo(scene_pc=scene_pc, grasp_pc=grasp_pc, target_poses=target_poses, name = name, device=device)

def __repr__(self) -> str:
string = f"TargetPoseDemo"
if self.name:
string = string + f" (name: {self.name})"
return string

def __str__(self) -> str:
string = f"TargetPoseDemo"
if self.name:
string = string + f" (name: {self.name})"
return string


class DemoSequence():
valid_demo_type = ['TargetPoseDemo']
str_to_demo_class = {'TargetPoseDemo': TargetPoseDemo}

def __init__(self, demo_seq: List[Demo] = [], device: Optional[Union[str, torch.device]] = None):
if device is None:
if demo_seq:
Expand All @@ -451,28 +491,49 @@ def to(self, device: Union[str, torch.device]) -> DemoSequence:

def __getitem__(self, idx) -> Union[Demo, List[Demo]]:
return self.demo_seq[idx]

def __len__(self) -> int:
return len(self.demo_seq)

def get_data_seq(self) -> List[Dict]:
return [demo.get_data() for demo in self.demo_seq]
return [demo.get_data_dict() for demo in self.demo_seq]

def save_data(self, dir: str, filename: str) -> bool:
gzip_save(data=self.get_data_seq(), dir=dir, filename=filename)
def save_data(self, path: str) -> bool:
gzip_save(data=self.get_data_seq(), path = path)
return True

@staticmethod
def from_data_seq(data_dict_seq: List[Dict], device: Union[str, torch.device] = 'cpu'):
demo_seq: List[Demo] = []
for data_dict in data_dict_seq:
demo_type: Demo = data_dict["demo_type"]
demo_type: str = data_dict["demo_type"]
if demo_type in DemoSequence.valid_demo_type:
demo_type: Type[Demo] = DemoSequence.str_to_demo_class[demo_type]
else:
raise ValueError(f"Unknown demo type: {demo_type}")
demo_seq.append(demo_type.from_data_dict(data_dict=data_dict, device=device))

return DemoSequence(demo_seq=demo_seq, device=device)

@staticmethod
def from_file(dir: str, filename: str, device: Union[str, torch.device] = 'cpu'):
data_dict_seq: List[Dict] = gzip_load(dir=dir, filename=filename)
def from_file(path: str, device: Union[str, torch.device] = 'cpu') -> DemoSequence:
data_dict_seq: List[Dict] = gzip_load(path=path)
return DemoSequence.from_data_seq(data_dict_seq=data_dict_seq, device=device)

def __repr__(self) -> str:
string = f"Demonstration sequence of length: {self.__len__()}"
for i in range(self.__len__()):
string += f"\n\tDemo {i}: {self.demo_seq[i].__str__()}"

return string

def __str__(self) -> str:
string = f"Demonstration sequence of length: {self.__len__()}"
for i in range(self.__len__()):
string += f"\n\tDemo {i}: {self.demo_seq[i].__str__()}"

return string


def save_demos(demos: List[DemoSequence], dir: str):
if not os.path.exists(dir):
Expand All @@ -482,15 +543,15 @@ def save_demos(demos: List[DemoSequence], dir: str):
for i, demo in enumerate(demos):
data_dir = "data"
filename = f"demo_{i}.gzip"
demo.save_data(dir=os.path.join(dir, data_dir), filename=filename)
demo.save_data(path=os.path.join(dir, data_dir, filename))
f.write("- \""+os.path.join(data_dir, filename)+"\"\n")

def load_demos(dir: str, annotation_file = "data.yaml") -> List[DemoSequence]:
files = load_yaml(file_path=os.path.join(dir, annotation_file))

demos: List[DemoSequence] = []
for file in files:
demos.append(DemoSequence.from_file(dir=dir, filename=file))
demos.append(DemoSequence.from_file(os.path.join(dir, file)))

return demos

Expand Down Expand Up @@ -524,10 +585,4 @@ def __getitem__(self, idx):
return data









2 changes: 1 addition & 1 deletion pick_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
pick_agent.save(agent_param_dir, filename)

log_filename = f'trainlog_iter_{iter}.gzip'
gzip_save(train_logs, dir=agent_param_dir, filename=log_filename)
gzip_save(train_logs, path=os.path.join(agent_param_dir, log_filename))

if verbose:
print("===============================", flush=True)
Expand Down
2 changes: 1 addition & 1 deletion place_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
place_agent.save(agent_param_dir, filename)

log_filename = f'trainlog_iter_{iter}.gzip'
gzip_save(train_logs, dir=agent_param_dir, filename=log_filename)
gzip_save(train_logs, path=os.path.join(agent_param_dir, log_filename))

if verbose:
print("===============================", flush=True)
Expand Down
Loading

0 comments on commit 10f220d

Please sign in to comment.