forked from conglu1997/SynthER
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
33 lines (26 loc) · 1.08 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import gym
import numpy as np
from gym.wrappers.flatten_observation import FlattenObservation
from redq.algos.core import ReplayBuffer
def wrap_gym(env: gym.Env, rescale_actions: bool = True) -> gym.Env:
    """Return ``env`` wrapped with the action/observation wrappers used here.

    Wrappers are applied in this order:

    1. ``gym.wrappers.RescaleAction(env, -1, 1)`` -- only when
       ``rescale_actions`` is true.
    2. ``FlattenObservation`` -- only when the observation space is a
       ``gym.spaces.Dict``.
    3. ``gym.wrappers.ClipAction`` -- always, as the outermost wrapper.

    Args:
        env: The environment to wrap.
        rescale_actions: Whether to add the ``RescaleAction`` wrapper.

    Returns:
        The wrapped environment.
    """
    wrapped = env

    if rescale_actions:
        wrapped = gym.wrappers.RescaleAction(wrapped, -1, 1)

    # Checked on the (possibly rescaled) env, exactly like the wrapper chain
    # is built: the flatten step only applies to dict observation spaces.
    has_dict_obs = isinstance(wrapped.observation_space, gym.spaces.Dict)
    if has_dict_obs:
        wrapped = FlattenObservation(wrapped)

    return gym.wrappers.ClipAction(wrapped)
# Make transition dataset from REDQ replay buffer.
def make_inputs_from_replay_buffer(
    replay_buffer: "ReplayBuffer",
    model_terminals: bool = False,
) -> np.ndarray:
    """Stack the transitions stored in a replay buffer into one 2-D array.

    Each output row is one transition laid out column-wise as
    ``[obs, action, reward, next_obs]``, with a trailing ``terminal`` column
    (cast to ``float32``) appended when ``model_terminals`` is true.

    Args:
        replay_buffer: Buffer exposing ``obs1_buf``, ``acts_buf``,
            ``obs2_buf``, ``rews_buf``, ``done_buf`` and ``ptr`` (and,
            optionally, ``size``). The observation/action buffers are
            concatenated along axis 1 as-is, while rewards and terminals get
            a trailing axis added first -- i.e. this presumably expects 2-D
            obs/action buffers and 1-D reward/done buffers; confirm against
            the buffer implementation in use.
        model_terminals: Whether to include the terminal flag as the last
            column.

    Returns:
        Array with one row per stored transition.
    """
    # ``ptr`` is the write cursor. In a ring buffer it wraps back to 0 once
    # the buffer is full, so slicing by ``ptr`` would silently drop valid
    # transitions after wrap-around (and return nothing when ``ptr == 0`` on
    # a full buffer). Prefer the valid-entry count ``size`` when the buffer
    # provides one; before any wrap-around the two are equal, so the result
    # is unchanged in that case. Fall back to ``ptr`` for buffers without
    # ``size``.
    num_valid = getattr(replay_buffer, "size", replay_buffer.ptr)

    obs = replay_buffer.obs1_buf[:num_valid]
    actions = replay_buffer.acts_buf[:num_valid]
    next_obs = replay_buffer.obs2_buf[:num_valid]
    rewards = replay_buffer.rews_buf[:num_valid]

    # Rewards get a trailing axis so they form a single column.
    inputs = [obs, actions, rewards[:, None], next_obs]
    if model_terminals:
        terminals = replay_buffer.done_buf[:num_valid].astype(np.float32)
        inputs.append(terminals[:, None])
    return np.concatenate(inputs, axis=1)