-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinference.py
More file actions
34 lines (27 loc) · 1.03 KB
/
inference.py
File metadata and controls
34 lines (27 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import sys
from isaacgym import gymapi
from rsl_rl.runners import OnPolicyRunner
from utils.helpers import class_to_dict, get_log_dir
from mtss_gym.mtss_gym import MotionTrackingFromSparseSensor
from mtss_cfg.mtss_cfg import MtssCfg, MtssPPOCfg
if __name__ == "__main__":
model_path = sys.argv[1]
# load environment
env_cfg = MtssCfg()
env_cfg.env.num_envs = 1
env_cfg.motion.dir = "resources/motions"
env_cfg.motion.files = os.listdir(env_cfg.motion.dir)
env = MotionTrackingFromSparseSensor(env_cfg, gymapi.SIM_PHYSX, "cuda", False)
# load model
rl_cfg = MtssPPOCfg()
rl_cfg.runner.resume = True
ppo_runner = OnPolicyRunner(env, class_to_dict(MtssPPOCfg), get_log_dir("logs", "mtss"), device="cuda")
ppo_runner.load(model_path)
policy = ppo_runner.get_inference_policy(device=env.device)
# run
env.reset()
obs = env.get_observations()
while True:
action = policy(obs.detach())
obs, _, rews, dones, infos = env.step(action.detach())