-
Notifications
You must be signed in to change notification settings - Fork 3k
[Newton] Add Warp based inhand_manipulation env #4413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev/newton
Are you sure you want to change the base?
[Newton] Add Warp based inhand_manipulation env #4413
Conversation
1. con success 10.67 2. used the same MJWarpSolver as torch env
|
Attaching performance data Performance summary (torch → warp)
Δ% is computed as ((\text{warp} - \text{torch}) / \text{torch} \times 100%), so negative time Δ% = less time (better). |
Greptile OverviewGreptile SummaryThis PR implements a Warp-accelerated in-hand manipulation environment for the Allegro Hand robot, enabling high-performance parallel simulation across 8192 environments using GPU kernels. Key changes:
Architecture:
Minor issues found:
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant Gym
participant InHandManipulationWarpEnv
participant DirectRLEnvWarp
participant WarpKernels
participant Hand as Articulation (Hand)
participant Object as Articulation (Object)
User->>Gym: register environment
Gym->>InHandManipulationWarpEnv: create with AllegroHandWarpEnvCfg
InHandManipulationWarpEnv->>DirectRLEnvWarp: __init__(cfg)
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _setup_scene()
InHandManipulationWarpEnv->>Hand: create Articulation(robot_cfg)
InHandManipulationWarpEnv->>Object: create Articulation(object_cfg)
InHandManipulationWarpEnv->>WarpKernels: initialize_rng_state()
InHandManipulationWarpEnv->>WarpKernels: initialize_goal_constants()
InHandManipulationWarpEnv->>WarpKernels: initialize_xyz_unit_vecs()
User->>InHandManipulationWarpEnv: step(actions)
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _pre_physics_step(actions)
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _apply_action()
InHandManipulationWarpEnv->>WarpKernels: apply_actions_to_targets()
InHandManipulationWarpEnv->>Hand: set_joint_position_target()
InHandManipulationWarpEnv->>DirectRLEnvWarp: simulate physics
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _get_dones()
InHandManipulationWarpEnv->>WarpKernels: compute_intermediate_values()
InHandManipulationWarpEnv->>WarpKernels: get_dones()
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _get_observations()
InHandManipulationWarpEnv->>WarpKernels: compute_full_observations()
InHandManipulationWarpEnv->>WarpKernels: sanitize_and_print_once()
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _get_rewards()
InHandManipulationWarpEnv->>WarpKernels: compute_rewards()
InHandManipulationWarpEnv->>WarpKernels: update_consecutive_successes_from_stats()
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _reset_target_pose()
InHandManipulationWarpEnv->>WarpKernels: reset_target_pose()
alt Reset Required
InHandManipulationWarpEnv->>InHandManipulationWarpEnv: _reset_idx(mask)
InHandManipulationWarpEnv->>WarpKernels: reset_object()
InHandManipulationWarpEnv->>Object: update root_link_pose_w
InHandManipulationWarpEnv->>WarpKernels: reset_hand()
InHandManipulationWarpEnv->>Hand: update joint_pos/joint_vel
InHandManipulationWarpEnv->>WarpKernels: reset_successes()
end
InHandManipulationWarpEnv-->>User: obs, reward, done, info
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 files reviewed, 2 comments
| def rotation_distance(object_rot: wp.quatf, target_rot: wp.quatf) -> wp.float32: | ||
| # Orientation alignment for the cube in hand and goal cube | ||
| quat_diff = quat_mul(object_rot, quat_conjugate(target_rot)) | ||
| # Match Torch env convention: uses indices [1:4] for the vector part (see `rotation_distance` in Torch env). | ||
| v_norm = wp.sqrt(quat_diff[1] * quat_diff[1] + quat_diff[2] * quat_diff[2] + quat_diff[3] * quat_diff[3]) | ||
| v_norm = wp.min(v_norm, wp.float32(1.0)) | ||
| return wp.float32(2.0) * wp.asin(v_norm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vector part indexing in comment inconsistent with implementation
The comment says "uses indices [1:4]" but the implementation correctly uses indices 1, 2, 3 (which is [1:4) in Python slicing). The comment should clarify this is the xyz components (indices 1,2,3) not including the w component (index 3 in some conventions).
| def rotation_distance(object_rot: wp.quatf, target_rot: wp.quatf) -> wp.float32: | |
| # Orientation alignment for the cube in hand and goal cube | |
| quat_diff = quat_mul(object_rot, quat_conjugate(target_rot)) | |
| # Match Torch env convention: uses indices [1:4] for the vector part (see `rotation_distance` in Torch env). | |
| v_norm = wp.sqrt(quat_diff[1] * quat_diff[1] + quat_diff[2] * quat_diff[2] + quat_diff[3] * quat_diff[3]) | |
| v_norm = wp.min(v_norm, wp.float32(1.0)) | |
| return wp.float32(2.0) * wp.asin(v_norm) | |
| # Orientation alignment for the cube in hand and goal cube | |
| quat_diff = quat_mul(object_rot, quat_conjugate(target_rot)) | |
| # Match Torch env convention: uses xyz components (indices 1, 2, 3) for the vector part (see `rotation_distance` in Torch env). | |
| v_norm = wp.sqrt(quat_diff[1] * quat_diff[1] + quat_diff[2] * quat_diff[2] + quat_diff[3] * quat_diff[3]) | |
| v_norm = wp.min(v_norm, wp.float32(1.0)) | |
| return wp.float32(2.0) * wp.asin(v_norm) |
| # unit vectors | ||
| self.x_unit_vecs = wp.zeros(self.num_envs, dtype=wp.vec3f, device=self.device) | ||
| self.y_unit_vecs = wp.zeros(self.num_envs, dtype=wp.vec3f, device=self.device) | ||
| self.z_unit_vecs = wp.zeros(self.num_envs, dtype=wp.vec3f, device=self.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
z_unit_vecs initialized but never used
The z-axis unit vector is initialized but never referenced in any kernel or method. Consider removing it or adding a comment explaining why it's reserved for future use.
AntoineRichard
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple of nits around the warp code. Otherwise it looks good to me.
| @wp.func | ||
| def quat_mul(q1: wp.quatf, q2: wp.quatf) -> wp.quatf: | ||
| # Hamilton product for quaternions in (x, y, z, w). | ||
| x1, y1, z1, w1 = q1[0], q1[1], q1[2], q1[3] | ||
| x2, y2, z2, w2 = q2[0], q2[1], q2[2], q2[3] | ||
| x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 | ||
| y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2 | ||
| z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2 | ||
| w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 | ||
| return wp.quatf(x, y, z, w) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is already native support for that in warp. q1*q2 does it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
| def quat_conjugate(q: wp.quatf) -> wp.quatf: | ||
| return wp.quatf(-q[0], -q[1], -q[2], q[3]) | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is already warp support for that: wp.quat_inverse()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
| @wp.func | ||
| def quat_from_angle_axis(angle: wp.float32, axis: wp.vec3f) -> wp.quatf: | ||
| # axis assumed to be unit-length in this task. | ||
| half = angle * wp.float32(0.5) | ||
| s = wp.sin(half) | ||
| c = wp.cos(half) | ||
| return wp.quatf(axis[0] * s, axis[1] * s, axis[2] * s, c) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is already a function for that: quat_from_axis_angle
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
addressed
| wp.launch( | ||
| initialize_xyz_unit_vecs, | ||
| dim=self.num_envs, | ||
| inputs=[ | ||
| self.x_unit_vecs, | ||
| self.y_unit_vecs, | ||
| self.z_unit_vecs, | ||
| ], | ||
| device=self.device, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be replaced by
self.x_unit_vec = wp.vec3f(1.0,0.0,0.0)
self.y_unit_vec = wp.vec3f(0.0,1.0,0.0)
self.z_unit_vec = wp.vec3f(0.0,0.0,1.0)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The torch env do have per env unit vector for some reason. Confirming if that's unnecessary now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
| x_unit_vecs: wp.array(dtype=wp.vec3f), | ||
| y_unit_vecs: wp.array(dtype=wp.vec3f), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These do not need to be arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
| object_pos: wp.array(dtype=wp.vec3f), | ||
| object_rot: wp.array(dtype=wp.quatf), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These could be fed as object_pose and a transformf. The risk is that if the tensor if not contiguous we are launching kernels to split the poses.
| object_linvel: wp.array(dtype=wp.vec3f), | ||
| object_angvel: wp.array(dtype=wp.vec3f), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These could be fed as a spatial vector directly. The risk is that if the tensor if not contiguous we are launching kernels to split the velocities.
| object_pos: wp.array(dtype=wp.vec3f), | ||
| object_rot: wp.array(dtype=wp.quatf), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here could be a transform directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Albeit there is no much gain since there is no transform ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea. In this env, it looks like there's no direct operation on transformf. Converting to a transform type would require extracting pos and rot everywhere, which might be less convenient. Do you think it's required?
| target_pos: wp.array(dtype=wp.vec3f), | ||
| target_rot: wp.array(dtype=wp.quatf), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could also be a transform
| float(self.cfg.dist_reward_scale), | ||
| float(self.cfg.rot_reward_scale), | ||
| float(self.cfg.rot_eps), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this float conversion needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
Performance summary (
|
| Metric | Torch | Warp | Δ% (torch→warp) |
|---|---|---|---|
| Action processing mean (us, N=9600) | 1502.42 | 36.77 | -97.55% |
| Newton simulation mean (us, N=9600) | 17604.23 | 17201.77 | -2.29% |
| Post-processing mean (us, N=4800) | 6708.68 | 168.72 | -97.49% |
| Total step mean (us, N=4800) | 45722.28 | 35529.03 | -22.29% |
| Throughput (steps/s) | 70185 | 85147 | +21.32% |
| Iteration time (s) | 0.93 | 0.77 | -17.20% |
| Collection time (s) | 0.762 | 0.600 | -21.26% |
| Learning time (s) | 0.171 | 0.170 | -0.58% |
Δ% is computed as (((\text{warp}-\text{torch})/\text{torch})\times 100%), so negative time Δ% = less time (better).
The potential joint sampling issue is summarized in #4404
It does seem that updating sampling, which previous included in the performance stats, puts initial configuration into a harder case such that Newton simulation takes more time. Removing the sampling fix puts warp and torch into the same condition for comparision.
Description
Add warp env for inhand_manipulation
Fixes # (issue)
Type of change
Screenshots
Please attach before and after screenshots of the change if applicable.
Checklist
pre-commitchecks with./isaaclab.sh --formatconfig/extension.tomlfileCONTRIBUTORS.mdor my name already exists there