diff --git a/policy/agent.py b/policy/agent.py new file mode 100644 index 0000000..8fb21bc --- /dev/null +++ b/policy/agent.py @@ -0,0 +1,124 @@ +import numpy as np + + +class BlackJackAgent: + def __init__(self, method, env, function='V', gamma=0.99, epsilon=0.1): + self.method = method + self.values = {(i, j, b): 0 for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False]} + self.vreturns = {(i, j, b): [] for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False]} + self.qs = {(i, j, b, a): 10 for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False] for a in range(env.action_space.n)} + self.qreturns = {(i, j, b, a): [] for i in range(env.observation_space[0].n) for j in range(env.observation_space[1].n) for b in [True, False] for a in range(env.action_space.n)} + self.value_function = lambda i, j, k: self.values[(i, j, k)] + self.q_function = lambda i, j, k, l: self.qs[(i, j, k, l)] + self.get_state_name = lambda state: (state[0], state[1], state[2]) + self.get_state_action_name = lambda state, action: (state[0], state[1], state[2], action) + self.gamma = gamma + self.actions = list(range(env.action_space.n)) + self.policy = {state: 0 for state in self.values.keys()} + self.epsilon = epsilon + self.function = function + + def choose_action(self, state): + sum_, show, ace = state + if self.method == 'lucky': + return self.feeling_lucky(sum_) + if self.method == 'egreedy': + return self.epsilon_greedy(state) + + def epsilon_greedy(self, state): + if np.random.random() < self.epsilon: + return np.random.choice(self.actions) + else: + state_name = self.get_state_name(state) + return self.policy[state_name] + + def feeling_lucky(self, sum_): + if sum_ < 20: + return 1 + return 0 + + def update(self, rewards, states, actions, function='V'): + visited = set() + if self.function == 'V': + for i, state in enumerate(states): + state_name = self.get_state_name(state) + if state_name in visited: + continue + G = 0 + for j, reward in enumerate(rewards[i:], 1): + G += self.gamma ** j * reward + self.vreturns[state_name].append(G) + self.values[state_name] = np.mean(self.vreturns[state_name]) + visited.add(state_name) + elif self.function == 'Q': + for i, (state, action) in enumerate(zip(states, actions)): + state_action_name = self.get_state_action_name(state, action) + if state_action_name in visited: + continue + G = 0 + for j, reward in enumerate(rewards[i:], 1): + G += self.gamma ** j * reward + self.qreturns[state_action_name].append(G) + self.qs[state_action_name] = np.mean(self.qreturns[state_action_name]) + visited.add(state_action_name) + for state in states: + Q_prime, A_prime = -np.inf, None + for action in actions: + state_action_name = self.get_state_action_name(state, action) + curr_Q = self.qs[state_action_name] + if curr_Q > Q_prime: + Q_prime = curr_Q + A_prime = action + state_name = self.get_state_name(state) + self.policy[state_name] = A_prime + else: + raise NotImplementedError + + +class CartPoleNoob: + def __init__(self, method, env, function='V', alpha=0.1, gamma=0.99, epsilon=0.1, n_bins=10): + self.method = method + self.alpha = alpha + self.gamma = gamma + self.epsilon = epsilon + self.function = function + self.actions = list(range(env.action_space.n)) + self.rad = np.linspace(-0.2094, 0.2094, n_bins) + self.values = {r: 0 for r in range(len(self.rad) + 1)} + self.qs = {(r, a): 10 for r in range(len(self.rad) + 1) for a in 
self.actions} + + def choose_action(self, state): + if self.method == 'naive': + return self.naive_action(state) + if self.method == 'egreedy': + return self.epsilon_greedy(state) + + def naive_action(self, state): + if state[2] < 0: + return 0 + return 1 + + def epsilon_greedy(self, state): + if np.random.random() < self.epsilon: + return np.random.choice(self.actions) + else: + s = self.get_bucket_index([state[2]])[0] + action = np.array([self.qs[(s, a)] for a in self.actions]).argmax() + return action + + def get_bucket_index(self, states): + inds = np.digitize(states, self.rad) + return inds + + def update(self, state, action, reward, state_): + r, r_ = self.get_bucket_index([state[2], state_[2]]) + if self.function == 'V': + # TD update w/ bootstrap + self.values[r] += self.alpha * (reward + self.gamma * self.values[r_] - self.values[r]) + elif self.function == 'Q': + Q_ = np.array([self.qs[(r_, a)] for a in self.actions]).max() + self.qs[(r, action)] += self.alpha * (reward + self.gamma * Q_ - self.qs[(r, action)]) + self.decrease_eps() + + def decrease_eps(self): + self.epsilon = max(0.01, self.epsilon - 1e-5) diff --git a/policy/blackjack/main.py b/policy/blackjack/main.py new file mode 100644 index 0000000..4cda350 --- /dev/null +++ b/policy/blackjack/main.py @@ -0,0 +1,34 @@ +import gym +import argparse +from tqdm import trange +from policy.agent import BlackJackAgent + + +parser = argparse.ArgumentParser(description='Black Jack Agents') +parser.add_argument('--method', type=str, default='lucky', help='The name of the policy you wish to evaluate') +parser.add_argument('--function', type=str, default='Q', help='The function to evaluate') +parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes you wish to run for') +args = parser.parse_args() + + +def first_visit_monte_carlo(): + env = gym.make('Blackjack-v0') + agent = BlackJackAgent(args.method, env, args.function) + for _ in trange(args.n_episodes): + state, done = env.reset(), False + states, actions, rewards = [state], [], [] + while not done: + action = agent.choose_action(state) + state_, reward, done, _ = env.step(action) + states.append(state) + rewards.append(reward) + actions.append(action) + state = state_ + agent.update(rewards, states, actions) + + print(agent.value_function(21, 2, True)) + print(agent.q_function(16, 2, False, 0)) + + +if __name__ == '__main__': + first_visit_monte_carlo() diff --git a/policy/cartpole/main.py b/policy/cartpole/main.py new file mode 100644 index 0000000..1d07c19 --- /dev/null +++ b/policy/cartpole/main.py @@ -0,0 +1,27 @@ +import gym +import argparse +from tqdm import trange +from policy.agent import CartPoleNoob + + +parser = argparse.ArgumentParser(description='Cartpole Agents') +parser.add_argument('--method', type=str, default='egreedy', help='The name of the policy you wish to evaluate') +parser.add_argument('--function', type=str, default='Q', help='The function to evaluate') +parser.add_argument('--n_episodes', type=int, default=500000, help='Number of episodes you wish to run for') +args = parser.parse_args() + + +def td(): + env = gym.make('CartPole-v0') + agent = CartPoleNoob(args.method, env, args.function) + for _ in trange(args.n_episodes): + state, done = env.reset(), False + while not done: + action = agent.choose_action(state) + state_, reward, done, _ = env.step(action) + agent.update(state, action, reward, state_) + state = state_ + print(agent.values) + +if __name__ == '__main__': + td() diff --git a/qlearning/agent.py 
b/qlearning/agent.py index b303a53..32625c6 100644 --- a/qlearning/agent.py +++ b/qlearning/agent.py @@ -3,7 +3,8 @@ import torch from torch import optim from copy import deepcopy -from qlearning.networks import QNaive, QBasic, QDueling +from qlearning import networks +from qlearning.networks import QNaive from qlearning.experience_replay import ReplayBuffer @@ -95,48 +96,40 @@ def decrease_epsilon(self): class DQNAgent(BaseAgent): def __init__(self, *args, **kwargs): super().__init__(*args) - self.algorithm = kwargs['algorithm'] - self.batch_size = kwargs['batch_size'] - self.grad_clip = kwargs['grad_clip'] - self.prioritize = kwargs['prioritize'] - self.alpha = kwargs['alpha'] - self.beta = kwargs['beta'] - self.eps = kwargs['eps'] - self.memory = ReplayBuffer(kwargs['max_size'], self.state_dim) - self.target_update_interval = kwargs['target_update_interval'] + for k, v in kwargs.items(): + setattr(self, k, v) + self.memory = ReplayBuffer(self.max_size, self.state_dim) self.n_updates = 0 - self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - if self.algorithm.startswith('Dueling'): - self.Q_function = QDueling( - kwargs['input_channels'], - self.n_actions, - kwargs['cpt_dir'], - kwargs['algorithm'] + '_' + kwargs['env_name'], - kwargs['img_size'], - kwargs['hidden_dim'], - noised=kwargs['noised']).to(self.device) - else: - self.Q_function = QBasic( - kwargs['input_channels'], - self.n_actions, - kwargs['cpt_dir'], - kwargs['algorithm'] + '_' + kwargs['env_name'], - kwargs['img_size'], - kwargs['hidden_dim'], - noised=kwargs['noised']).to(self.device) + + network = self.algorithm + if 'DD' in network: + import re + network = re.sub('DDQN', 'DQN', network) + network = getattr(networks, network) + self.Q_function = network( + input_channels=self.input_channels, + out_features=self.n_actions, + cpt_dir=self.cpt_dir, + name=self.algorithm + '_' + self.env_name, + img_size=self.img_size, + hidden_dim=self.hidden_dim, + n_repeats=self.n_repeats, + noised=self.noised, + num_atoms=self.num_atoms).to(self.device) # instanciate target network self.target_Q = deepcopy(self.Q_function) self.freeze_network(self.target_Q) - self.target_Q.name = kwargs['algorithm'] + '_' + kwargs['env_name'] + '_target' + self.target_Q.name = self.algorithm + '_' + self.env_name + '_target' self.optimizer = torch.optim.RMSprop(self.Q_function.parameters(), lr=self.lr, alpha=0.95) self.criterion = torch.nn.MSELoss(reduction='none') def greedy_action(self, observation): - observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device) - next_action = self.Q_function(observation).argmax() + with torch.no_grad(): + observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0).to(self.device) + next_action = self.Q_function(observation).argmax() return next_action.item() def update_target_network(self): @@ -161,18 +154,18 @@ def update(self): # double DQN uses online network to select action for Q' if self.algorithm.endswith('DDQN'): next_actions = self.Q_function(next_observations).argmax(-1) - q_prime = self.target_Q(next_observations)[list(range(self.batch_size)), next_actions] + q_prime = self.target_Q(next_observations).gather(1, next_actions.unsqueeze(1)) elif self.algorithm.endswith('DQN'): q_prime = self.target_Q(next_observations).max(-1)[0] # calculate target + estimate - q_target = rewards + self.gamma * q_prime * (~dones) - q_pred = self.Q_function(observations)[list(range(self.batch_size)), actions] - loss = self.criterion(q_target.detach(), q_pred) 
+ q_target = rewards + self.gamma * q_prime.squeeze() * (~dones) + q_pred = self.Q_function(observations).gather(1, actions.unsqueeze(1)) + loss = self.criterion(q_target.detach(), q_pred.squeeze()) # for updating priorities if using priority replay if self.prioritize: - priorities = (idx, loss.clone().detach() + self.eps) + priorities = (idx, loss.detach().cpu() + self.eps) else: priorities = None @@ -182,13 +175,16 @@ def update(self): if self.grad_clip is not None: torch.nn.utils.clip_grad_norm_(self.Q_function.parameters(), self.grad_clip) self.optimizer.step() - self.decrease_epsilon() + self.adjust_epsilon_and_beta() self.n_updates += 1 if self.n_updates % self.target_update_interval == 0: self.update_target_network() return priorities - def decrease_epsilon(self): + def adjust_epsilon_and_beta(self): + self.beta = min( + self.beta_min, + self.beta + self.beta_dec) self.epsilon = max( self.epsilon_min, self.epsilon - self.epsilon_desc) @@ -198,7 +194,8 @@ def store_transition(self, state, reward, action, next_state, done, priority=Non self.memory.store(state, reward, action, next_state, done, priority=priority) def sample_transitions(self): - return self.memory.sample(self.batch_size, self.device) + transition = self.memory.sample(self.batch_size, self.device, self.beta) + return transition def save_models(self): self.target_Q.check_point() diff --git a/qlearning/atari/main.py b/qlearning/atari/main.py index 9f4e8ac..732d462 100644 --- a/qlearning/atari/main.py +++ b/qlearning/atari/main.py @@ -25,16 +25,17 @@ AtlantisNoFrameskip-v4\n \ BankHeistNoFrameskip-v4\n \ FlappyBird-v0') -parser.add_argument('--n_repeats', type=int, default=4, help='The number of repeated actions') +parser.add_argument('--n_repeats', type=int, default=4, help='Frames stack size') +parser.add_argument('--action_repeats', type=int, default=4, help='The number of repeated actions') parser.add_argument('--img_size', type=int, default=84, help='The height and width of images after resizing') parser.add_argument('--input_channels', type=int, default=1, help='The input channels after preprocessing') parser.add_argument('--hidden_dim', type=int, default=512, help='The hidden size for second fc layer') -parser.add_argument('--max_size', type=int, default=100000, help='Buffer size') -parser.add_argument('--target_update_interval', type=int, default=1000, help='Interval for updating target network') +parser.add_argument('--max_size', type=int, default=300000, help='Buffer size') # training parser.add_argument('-e', '--n_episodes', '--epochs', type=int, default=1000, help='Number of episodes agent interacts with env') parser.add_argument('--lr', type=float, default=0.00025, help='Learning rate') +parser.add_argument('--target_update_interval', type=int, default=1000, help='Interval for updating target network') parser.add_argument('--gamma', type=float, default=0.99, help='Discount factor') parser.add_argument('--epsilon_init', type=float, default=1.0, help='Initial epsilon value') parser.add_argument('--epsilon_min', type=float, default=0.1, help='Minimum epsilon value to decay to') @@ -43,9 +44,12 @@ parser.add_argument('-b', '--batch_size', type=int, default=32, help='Batch size') parser.add_argument('--no_prioritize', action="store_true", default=False, help='Use Prioritized Experience Replay') parser.add_argument('--alpha', type=float, default=0.6, help='Prioritized Experience Replay alpha') -parser.add_argument('--beta', type=float, default=0.4, help='Prioritized Experience Replay beta') 
+parser.add_argument('--beta_init', type=float, default=0.4, help='Initial beta value') +parser.add_argument('--beta_min', type=float, default=1.0, help='Maximum beta value to grow to') +parser.add_argument('--beta_dec', type=float, default=1e-5, help='Beta increase') parser.add_argument('--eps', type=float, default=1e-5, help='Prioritized Experience Replay epsilon') parser.add_argument('--noised', action="store_true", default=False, help='Using noisy networks') +parser.add_argument('--num_atoms', type=int, default=51, help='Number of atoms used for Categorical DQN') # logging parser.add_argument('--progress_window', type=int, default=100, help='Window of episodes for progress') @@ -65,7 +69,7 @@ if __name__ == '__main__': - env = processed_atari(args.env_name, args.img_size, args.input_channels, args.n_repeats) + env = processed_atari(args.env_name, args.img_size, args.input_channels, args.n_repeats, args.action_repeats) # if testing agent and want to output videos, make dir & wrap env to auto output video files if args.test and args.video: @@ -76,7 +80,7 @@ # force some parameters depending on if using priority replay, following paper protocols if args.no_prioritize: - args.alpha, args.beta, args.epsilon = 1, 0, 0 + args.alpha, args.beta_init, args.epsilon = 1, 0, 0 else: args.lr /= 4 @@ -99,9 +103,13 @@ grad_clip=args.grad_clip, prioritize=not args.no_prioritize, alpha=args.alpha, - beta=args.beta, + beta=args.beta_init, + beta_min=args.beta_min, + beta_dec=args.beta_dec, eps=args.eps, noised=args.noised, + n_repeats=args.n_repeats, + num_atoms=args.num_atoms, env_name=args.env_name) # load weights & make sure model in eval mode during test, only need online network for testings @@ -109,7 +117,7 @@ agent.load_models() agent.Q_function.eval() - scores, best_score = deque(maxlen=args.progress_window), -np.inf + scores, best_score, best_avg = deque(maxlen=args.progress_window), -np.inf, -np.inf pbar = tqdm(range(args.n_episodes)) for e in pbar: # reset every episode and make sure functions are in training mode @@ -135,8 +143,10 @@ writer.add_scalars('Performance and training', {'Score': score, 'Epsilon': agent.epsilon}) scores.append(score) avg_score = np.mean(scores) - if avg_score > best_score and not args.test: + if avg_score > best_avg and not args.test: agent.save_models() - best_score = avg_score + best_avg = avg_score + if score > best_score: + best_score = score if (e + 1) % args.print_every == 0: - tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, Average Score: {avg_score}, Best Score {best_score}, Epsilon: {agent.epsilon}') + tqdm.write(f'Episode: {e + 1}/{args.n_episodes}, Average Score: {avg_score}, Best Score {best_score}, Epsilon: {agent.epsilon}, Beta: {agent.beta}') diff --git a/qlearning/atari/utils.py b/qlearning/atari/utils.py index b6e8d6e..29ee4b8 100644 --- a/qlearning/atari/utils.py +++ b/qlearning/atari/utils.py @@ -84,9 +84,9 @@ def observation(self, observation): self.stack.append(observation) return np.array(self.stack).reshape(self.observation_shape) -def processed_atari(env_name, shape=84, input_channels=1, n_repeats=4, clip_rewards=False, no_ops=0, fire_first=False): +def processed_atari(env_name, shape=84, input_channels=1, n_repeats=4, action_repeats=4, clip_rewards=False, no_ops=0, fire_first=False): env = gym.make(env_name) - env = RepeatAction(env, n_repeats, clip_rewards, no_ops, fire_first) + env = RepeatAction(env, action_repeats, clip_rewards, no_ops, fire_first) env = Preprocess(env, (shape, shape, input_channels)) env = FrameStacker(env, n_repeats) 
return env diff --git a/qlearning/experience_replay.py b/qlearning/experience_replay.py index 2509de3..75d3daf 100644 --- a/qlearning/experience_replay.py +++ b/qlearning/experience_replay.py @@ -2,7 +2,7 @@ class ReplayBuffer: - def __init__(self, max_size, state_dim, alpha=1): + def __init__(self, max_size, state_dim, alpha=1, rank=False): self.states = torch.empty(max_size, *state_dim) self.rewards = torch.empty(max_size) self.actions = torch.zeros(max_size, dtype=torch.long) @@ -11,6 +11,7 @@ def __init__(self, max_size, state_dim, alpha=1): self.priorities = torch.zeros(max_size) self.max_size = max_size self.alpha = alpha + self.rank = rank self.ctr = 0 def store(self, state, reward, action, next_state, done, priority=None): @@ -22,7 +23,9 @@ def store(self, state, reward, action, next_state, done, priority=None): self.dones[i] = done if priority is not None: idx, priority = priority - self.priorities[idx] = priority.cpu() + self.priorities[idx] = priority.cpu().pow(0.5) + # setting the new transition to max of priorities to increase proba of using it to update + self.priorities[i] = self.priorities.max().item() else: self.priorities[i] = 1 self.ctr += 1 @@ -31,13 +34,22 @@ def sample(self, batch_size, device, beta=0): max_mem = min(self.ctr, self.max_size) assert max_mem > 0 sample_distribution = self.priorities ** self.alpha + # p_i = 1 / rank(i) + if self.rank: + sample_distribution = 1 / reversed(sample_distribution.argsort()) + + # normalize sample_distribution /= sample_distribution.sum() + + # sample idx = torch.multinomial(sample_distribution, batch_size) states = self.states[idx].to(device) rewards = self.rewards[idx].to(device) - actions = self.actions[idx] + actions = self.actions[idx].to(device) next_states = self.next_states[idx].to(device) dones = self.dones[idx].to(device) + + # importance sampling weights to renormalize sample distribution weights = ((max_mem * sample_distribution[idx]) ** (- beta)).to(device) weights /= weights.max() return states, rewards, actions, next_states, dones, idx, weights diff --git a/qlearning/networks.py b/qlearning/networks.py index cb516e0..0587a26 100644 --- a/qlearning/networks.py +++ b/qlearning/networks.py @@ -17,31 +17,34 @@ def forward(self, state): return self.fc(state_emd) -class QBasic(nn.Module): - def __init__(self, input_channels, n_actions, cpt_dir, name, +class DQN(nn.Module): + def __init__(self, input_channels, out_features, cpt_dir, name, img_size=84, hidden_dim=512, n_repeats=4, channels=[32, 64, 64], - kernel_sizes=[8, 4, 3], strides=[4, 2, 1], noised=False): + kernel_sizes=[8, 4, 3], strides=[4, 2, 1], noised=False, **kwargs): super().__init__() - q_network = [] + feature_extractor = [] # CNN layers prev_ch = input_channels * n_repeats for ch, ks, sd in zip(channels, kernel_sizes, strides): - q_network.append(nn.Conv2d(prev_ch, ch, kernel_size=ks, stride=sd)) - q_network.append(nn.ReLU()) + feature_extractor.append(nn.Conv2d(prev_ch, ch, kernel_size=ks, stride=sd)) + feature_extractor.append(nn.ReLU()) prev_ch = ch - q_network.append(nn.Flatten()) + feature_extractor.append(nn.Flatten()) + + self.feature_extractor = nn.Sequential(*feature_extractor) + q_network = [self.feature_extractor] # find the feature dimension after CNN and flatten dummy_img = torch.empty(1, input_channels * n_repeats, img_size, img_size) - fc_size = nn.Sequential(*q_network)(dummy_img).size(-1) + self.fc_size = self.feature_extractor(dummy_img).size(-1) # FC layers if noised: q_network.extend( - [NoisedLinear(fc_size, hidden_dim), nn.ReLU(), 
NoisedLinear(hidden_dim, n_actions)]) + [NoisedLinear(self.fc_size, hidden_dim), nn.ReLU(), NoisedLinear(hidden_dim, out_features)]) else: q_network.extend( - [nn.Linear(fc_size, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, n_actions)]) + [nn.Linear(self.fc_size, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, out_features)]) self.q_network = nn.Sequential(*q_network) # training @@ -59,10 +62,10 @@ def load_checkpoint(self): self.load_state_dict(torch.load(self.cpt + '.pth')) -class QDueling(nn.Module): - def __init__(self, input_channels, n_actions, cpt_dir, name, +class DuelingDQN(nn.Module): + def __init__(self, input_channels, out_features, cpt_dir, name, img_size=84, hidden_dim=512, n_repeats=4, channels=[32, 64, 64], - kernel_sizes=[8, 4, 3], strides=[4, 2, 1], noised=False): + kernel_sizes=[8, 4, 3], strides=[4, 2, 1], noised=False, **kwargs): super().__init__() feature_extractor = [] # CNN layers @@ -88,14 +91,13 @@ def __init__(self, input_channels, n_actions, cpt_dir, name, # value & advantage fns if noised: self.value = NoisedLinear(hidden_dim, 1) - self.advantage = NoisedLinear(hidden_dim, n_actions) + self.advantage = NoisedLinear(hidden_dim, out_features) else: self.value = nn.Linear(hidden_dim, 1) - self.advantage = nn.Linear(hidden_dim, n_actions) + self.advantage = nn.Linear(hidden_dim, out_features) # training self.name = name - self.n_actions = n_actions self.cpt = os.path.join(cpt_dir, name) def forward(self, observations): @@ -124,7 +126,7 @@ def __init__(self, in_features, out_features, sigma_init=0.017): init_range = math.sqrt(3 / in_features) self.init_weights(init_range) - def combine_parameters(self): + def assemble_parameters(self): self.reset_epsilon() return self.matrix_mu + self.matrix_sigma * self.matrix_epsilon @@ -149,10 +151,30 @@ def __init__(self, in_features, out_features): def forward(self, state): if self.training: - weight = self.weight.combine_parameters() - bias = self.bias.combine_parameters().squeeze() + weight = self.weight.assemble_parameters() + bias = self.bias.assemble_parameters().squeeze() else: weight = self.weight.matrix_mu bias = self.bias.matrix_mu.squeeze() Qs = state @ weight + bias return Qs + +class CategoricalDQN(DQN): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for k, v in kwargs.items(): + setattr(self, k, v) + categorical_network = [self.feature_extractor] + if self.noised: + linear_layer = NoisedLinear + else: + linear_layer = nn.Linear + categorical_network.extend( + [linear_layer(self.fc_size, self.hidden_dim), + nn.ReLU(), + linear_layer(self.hidden_dim, self.out_features * self.num_atoms)]) + self.categorical_network = nn.Sequential(*categorical_network) + + def forward(self, state): + logits = self.categorical_network(state).view(-1, self.num_atoms) + return torch.softmax(logits, dim=-1).view(-1, self.out_features, self.num_atoms)
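
For reference, a minimal sketch of tabular first-visit Monte Carlo prediction, the technique BlackJackAgent.update implements. It assumes the episode is stored as aligned lists in which rewards[t] is the reward received after leaving states[t], and it uses the textbook discounting where the immediate reward enters the return with weight gamma**0. The helper name first_visit_mc_values is illustrative, not part of the repo.

import numpy as np


def first_visit_mc_values(episodes, gamma=0.99):
    """Tabular first-visit Monte Carlo prediction.

    episodes: list of (states, rewards) pairs where rewards[t] is the reward
    received after leaving states[t]; states must be hashable, e.g. the
    blackjack (player_sum, dealer_card, usable_ace) tuple.
    """
    returns = {}                              # state -> list of observed returns
    for states, rewards in episodes:
        first_visit = {}                      # state -> first time step it was visited
        for t, s in enumerate(states):
            if s not in first_visit:
                first_visit[s] = t
        G = 0.0
        # walk backwards so G accumulates the discounted tail of the episode
        for t in reversed(range(len(states))):
            G = rewards[t] + gamma * G        # G_t = r_{t+1} + gamma * G_{t+1}
            if first_visit[states[t]] == t:
                returns.setdefault(states[t], []).append(G)
    return {s: np.mean(rs) for s, rs in returns.items()}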
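
CartPoleNoob discretizes the pole angle (state[2]) into buckets with np.digitize and runs a tabular TD(0) update on those buckets. Below is a compact sketch of the same idea with one refinement the agent above omits, zeroing the bootstrap term on terminal transitions; the names BINS and td0_update and the bin count are illustrative.

import numpy as np

# discretize the pole angle into buckets; +/- 0.2094 rad (about 12 degrees) is the
# CartPole termination threshold, so everything interesting lies inside it
BINS = np.linspace(-0.2094, 0.2094, 10)
V = np.zeros(len(BINS) + 1)                   # one value per bucket


def bucket(angle):
    return np.digitize(angle, BINS)


def td0_update(V, state, reward, next_state, done, alpha=0.1, gamma=0.99):
    s, s_ = bucket(state[2]), bucket(next_state[2])
    target = reward + (0.0 if done else gamma * V[s_])   # no bootstrap past a terminal state
    V[s] += alpha * (target - V[s])
    return V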
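
The double-DQN branch of DQNAgent.update selects the next action with the online network and evaluates it with the frozen target network. A shape-annotated sketch of that target computation, assuming q_online and q_target are callables mapping a (B, *state_dim) batch to (B, n_actions) Q-values (both names are placeholders):

import torch


@torch.no_grad()
def double_dqn_target(q_online, q_target, rewards, next_obs, dones, gamma=0.99):
    # rewards: (B,), next_obs: (B, *state_dim), dones: (B,) bool
    next_actions = q_online(next_obs).argmax(dim=-1, keepdim=True)     # (B, 1), chosen online
    q_prime = q_target(next_obs).gather(1, next_actions).squeeze(1)    # (B,), evaluated by target
    return rewards + gamma * q_prime * (~dones).float()                # (B,)

With reduction='none' on the MSE criterion, the per-sample losses built from this target can be reweighted by the importance-sampling weights the buffer returns before being reduced and backpropagated; that weighting step is not visible in the hunks above.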
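
ReplayBuffer draws samples in proportion to priorities**alpha (or 1/rank(i) when rank=True) and returns importance-sampling weights controlled by beta; in the training loop beta grows from --beta_init toward the cap given by --beta_min (which, despite its name, is the 1.0 ceiling) in steps of --beta_dec. A self-contained sketch of those two pieces; note that per-transition ranks come from a double argsort, which is not the same as taking reciprocals of argsort output directly. The helper names per_probabilities and importance_weights are illustrative.

import torch


def per_probabilities(priorities, alpha=0.6, rank_based=False):
    """priorities: 1-D tensor of |TD error|-like values for the stored transitions."""
    if rank_based:
        # rank(i) = position of transition i when sorted by priority, best first
        order = priorities.argsort(descending=True)   # indices sorted by priority
        ranks = order.argsort() + 1                   # 1-based rank of each transition
        scaled = 1.0 / ranks.float()
    else:
        scaled = priorities ** alpha
    return scaled / scaled.sum()


def importance_weights(probs, idx, beta):
    # w_i = (N * P(i)) ** (-beta), normalized by the max for stability
    n = probs.numel()
    w = (n * probs[idx]) ** (-beta)
    return w / w.max()

Annealing beta toward 1.0 over training gradually removes the bias that prioritized sampling introduces, which matches the min(self.beta_min, self.beta + self.beta_dec) schedule in adjust_epsilon_and_beta.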
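
NoisedLinear perturbs its mean weights with learned Gaussian noise during training (assemble_parameters above) and falls back to the mean weights at evaluation time. For orientation, a minimal independent-Gaussian NoisyNet linear layer with the same sigma_init as the diff; the class name NoisyLinear is illustrative, and nothing is assumed about the repo's noised parameter wrapper beyond what the hunks show.

import torch
from torch import nn


class NoisyLinear(nn.Module):
    """Linear layer with learned per-weight Gaussian noise (independent variant)."""

    def __init__(self, in_features, out_features, sigma_init=0.017):
        super().__init__()
        bound = (3 / in_features) ** 0.5
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features).uniform_(-bound, bound))
        self.weight_sigma = nn.Parameter(torch.full((out_features, in_features), sigma_init))
        self.bias_mu = nn.Parameter(torch.empty(out_features).uniform_(-bound, bound))
        self.bias_sigma = nn.Parameter(torch.full((out_features,), sigma_init))

    def forward(self, x):
        if self.training:
            # resample noise every forward pass so exploration comes from the weights
            weight = self.weight_mu + self.weight_sigma * torch.randn_like(self.weight_sigma)
            bias = self.bias_mu + self.bias_sigma * torch.randn_like(self.bias_sigma)
        else:
            # act greedily w.r.t. the mean weights at evaluation time
            weight, bias = self.weight_mu, self.bias_mu
        return nn.functional.linear(x, weight, bias)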
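
DuelingDQN splits the head into a scalar state value and per-action advantages; its forward aggregation is not visible in the hunks above, but the standard choice (Wang et al., 2016) subtracts the mean advantage so the value/advantage decomposition is identifiable. A minimal sketch of such a head, with illustrative layer sizes and the mean-subtraction presented as an assumption:

import torch
from torch import nn


class DuelingHead(nn.Module):
    def __init__(self, hidden_dim=512, n_actions=6):
        super().__init__()
        self.value = nn.Linear(hidden_dim, 1)               # V(s)
        self.advantage = nn.Linear(hidden_dim, n_actions)   # A(s, a)

    def forward(self, features):
        v = self.value(features)                            # (B, 1)
        a = self.advantage(features)                        # (B, n_actions)
        # Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return v + a - a.mean(dim=-1, keepdim=True)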
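
CategoricalDQN outputs, for each action, a softmax distribution over num_atoms support points (51 by default, i.e. C51). To act greedily, that distribution is collapsed to expected values against a fixed support; in the sketch below, v_min/v_max and the helper name greedy_action_from_distribution are assumptions, not flags or functions that exist in this diff. Note that training a categorical head also requires the projected distributional Bellman target, which does not appear in these hunks (DQNAgent.update only builds scalar DQN/DDQN targets).

import torch


def greedy_action_from_distribution(dist, v_min=-10.0, v_max=10.0):
    """dist: (B, n_actions, num_atoms) per-action return distributions (softmax output)."""
    num_atoms = dist.size(-1)
    support = torch.linspace(v_min, v_max, num_atoms, device=dist.device)   # atoms z_1 ... z_N
    q_values = (dist * support).sum(dim=-1)    # E[Z(s, a)] per action, shape (B, n_actions)
    return q_values.argmax(dim=-1)             # (B,) greedy actions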