-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbase_task.py
More file actions
106 lines (84 loc) · 3.9 KB
/
base_task.py
File metadata and controls
106 lines (84 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Adapted from legged_gym: https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/envs/base/base_task.py
import sys
from isaacgym import gymapi
from isaacgym import gymutil
import numpy as np
import torch
from rsl_rl.env import VecEnv
# Base class for RL tasks
class BaseTask(VecEnv):
    """Base class for Isaac Gym vectorized RL tasks.

    Chooses the tensor device, allocates the per-environment buffers shared by
    all tasks, builds the simulation through the subclass-provided
    ``create_sim()`` and, unless running headless, opens a viewer with keyboard
    shortcuts (ESC quits, V toggles viewer sync).

    Subclasses must provide ``create_sim()`` (which must set ``self.sim``),
    ``reset_idx()`` and ``step()``.

    Args:
        cfg: Task config; ``cfg.env.num_envs``, ``cfg.env.num_observations``,
            ``cfg.env.num_privileged_obs`` (``None`` disables privileged
            observations) and ``cfg.env.num_actions`` are read here.
        sim_params: Simulation parameters; only ``use_gpu_pipeline`` is read
            here, the object is stored for subclasses.
        physics_engine: Stored as-is (presumably consumed by the subclass's
            ``create_sim()``).
        sim_device: Device string parsed by ``gymutil.parse_device_str``
            (e.g. ``"cuda:0"`` or ``"cpu"``).
        headless: If true, no viewer is created and the graphics device is
            set to -1.
    """

    def __init__(self, cfg, sim_params, physics_engine, sim_device, headless):
        self.gym = gymapi.acquire_gym()

        self.sim_params = sim_params
        self.physics_engine = physics_engine
        self.sim_device = sim_device
        sim_device_type, self.sim_device_id = gymutil.parse_device_str(self.sim_device)
        self.headless = headless

        # Buffers live on the sim device only when the sim runs on CUDA *and*
        # the GPU pipeline is enabled; otherwise everything stays on the CPU.
        if sim_device_type == 'cuda' and sim_params.use_gpu_pipeline:
            self.device = self.sim_device
        else:
            self.device = 'cpu'

        # Graphics device for rendering; -1 means "no rendering" when headless.
        self.graphics_device_id = -1 if self.headless else self.sim_device_id

        self.num_envs = cfg.env.num_envs
        self.num_obs = cfg.env.num_observations
        self.num_privileged_obs = cfg.env.num_privileged_obs
        self.num_actions = cfg.env.num_actions

        # Turn off TorchScript's profiling mode/executor (private torch._C hooks).
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)

        # Per-environment buffers shared by every task.
        self.obs_buf = torch.zeros(self.num_envs, self.num_obs, device=self.device, dtype=torch.float)
        self.rew_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.float)
        self.reset_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        self.episode_length_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.long)
        self.time_out_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        # Privileged observations are optional: no buffer when the config sets
        # num_privileged_obs to None.
        if self.num_privileged_obs is not None:
            self.privileged_obs_buf = torch.zeros(
                self.num_envs, self.num_privileged_obs, device=self.device, dtype=torch.float)
        else:
            self.privileged_obs_buf = None
        self.extras = {}

        # Create envs, sim and viewer. create_sim() is supplied by the subclass
        # and must set self.sim before prepare_sim() is called.
        self.create_sim()
        self.gym.prepare_sim(self.sim)

        self.enable_viewer_sync = True
        self.viewer = None
        if not self.headless:
            self.viewer = self.gym.create_viewer(self.sim, gymapi.CameraProperties())
            # Keyboard shortcuts, handled in render().
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_ESCAPE, "QUIT")
            self.gym.subscribe_viewer_keyboard_event(self.viewer, gymapi.KEY_V, "toggle_viewer_sync")

    def get_observations(self):
        """Return the observation buffer (num_envs x num_obs)."""
        return self.obs_buf

    def get_privileged_observations(self):
        """Return the privileged observation buffer, or None when disabled."""
        return self.privileged_obs_buf

    def reset_idx(self, env_ids):
        """Reset the robots whose environment indices are in ``env_ids``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def reset(self):
        """Reset all robots.

        Returns:
            ``(obs, privileged_obs)``: the observation buffer and the privileged
            observation buffer (``None`` when privileged observations are
            disabled).
        """
        self.reset_idx(torch.arange(self.num_envs, device=self.device))
        # NOTE(review): the buffers are returned as left by reset_idx(); they are
        # only fresh if the subclass recomputes observations there.
        return self.get_observations(), self.get_privileged_observations()

    def step(self, actions):
        """Apply ``actions`` and advance the simulation. Must be implemented by subclasses."""
        raise NotImplementedError

    def render(self, sync_frame_time=True):
        """Update the viewer and handle its keyboard events; no-op when headless.

        Exits the process when the viewer window is closed or ESC is pressed.
        V toggles viewer sync: while sync is off, graphics are not stepped and
        only viewer events are polled.

        Args:
            sync_frame_time: If true, call ``gym.sync_frame_time`` after drawing
                (only while viewer sync is enabled).
        """
        if not self.viewer:
            return

        if self.gym.query_viewer_has_closed(self.viewer):
            sys.exit()

        for evt in self.gym.query_viewer_action_events(self.viewer):
            if evt.action == "QUIT" and evt.value > 0:
                sys.exit()
            elif evt.action == "toggle_viewer_sync" and evt.value > 0:
                self.enable_viewer_sync = not self.enable_viewer_sync

        # With the GPU pipeline, results must be fetched before drawing.
        if self.device != 'cpu':
            self.gym.fetch_results(self.sim, True)

        if self.enable_viewer_sync:
            self.gym.step_graphics(self.sim)
            self.gym.draw_viewer(self.viewer, self.sim, True)
            if sync_frame_time:
                self.gym.sync_frame_time(self.sim)
        else:
            self.gym.poll_viewer_events(self.viewer)