
Commit fb775e0

Replace NumPy with Torch in examples/fabric/ (#17279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ef7da5c commit fb775e0

7 files changed (+39, -41 lines)

examples/fabric/meta_learning/train_fabric.py
Lines changed: 3 additions & 5 deletions

@@ -16,7 +16,6 @@
 """
 import cherry
 import learn2learn as l2l
-import numpy as np
 import torch

 from lightning.fabric import Fabric, seed_everything
@@ -31,10 +30,9 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways):
     data, labels = batch

     # Separate data into adaptation/evalutation sets
-    adaptation_indices = np.zeros(data.size(0), dtype=bool)
-    adaptation_indices[np.arange(shots * ways) * 2] = True
-    evaluation_indices = torch.from_numpy(~adaptation_indices)
-    adaptation_indices = torch.from_numpy(adaptation_indices)
+    adaptation_indices = torch.zeros(data.size(0), dtype=bool)
+    adaptation_indices[torch.arange(shots * ways) * 2] = True
+    evaluation_indices = ~adaptation_indices
     adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
     evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
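For readers comparing the two versions, here is a minimal standalone sketch (not part of the commit) of the torch-only adaptation/evaluation split; `shots`, `ways`, and the fake batch are illustrative values.

```python
# Illustrative sketch of the torch-only mask split above; shapes and values are made up.
import torch

shots, ways = 1, 5
data = torch.randn(2 * shots * ways, 1, 28, 28)   # interleaved support/query samples
labels = torch.arange(2 * shots * ways) // 2

adaptation_indices = torch.zeros(data.size(0), dtype=torch.bool)   # replaces np.zeros(..., dtype=bool)
adaptation_indices[torch.arange(shots * ways) * 2] = True          # even positions -> adaptation
evaluation_indices = ~adaptation_indices                           # no torch.from_numpy round-trip needed

adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]
assert adaptation_data.shape[0] == evaluation_data.shape[0] == shots * ways
```

The commit passes the Python builtin `bool` as the dtype; PyTorch accepts that and maps it to `torch.bool`, so the sketch's explicit `torch.bool` is equivalent.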

examples/fabric/meta_learning/train_torch.py
Lines changed: 3 additions & 6 deletions

@@ -20,7 +20,6 @@

 import cherry
 import learn2learn as l2l
-import numpy as np
 import torch
 import torch.distributed as dist

@@ -35,10 +34,9 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
     data, labels = data.to(device), labels.to(device)

     # Separate data into adaptation/evalutation sets
-    adaptation_indices = np.zeros(data.size(0), dtype=bool)
-    adaptation_indices[np.arange(shots * ways) * 2] = True
-    evaluation_indices = torch.from_numpy(~adaptation_indices)
-    adaptation_indices = torch.from_numpy(adaptation_indices)
+    adaptation_indices = torch.zeros(data.size(0), dtype=bool)
+    adaptation_indices[torch.arange(shots * ways) * 2] = True
+    evaluation_indices = ~adaptation_indices
     adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices]
     evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices]

@@ -76,7 +74,6 @@ def main(
     seed = seed + rank

     random.seed(seed)
-    np.random.seed(seed)
     torch.manual_seed(seed)
     device = torch.device("cpu")
     if cuda and torch.cuda.device_count():
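As a side note on the seeding change, a small sketch of per-rank seeding once NumPy's global RNG no longer needs to be touched; `seed_all` is a hypothetical helper and the `rank` offset is illustrative, not code from the commit.

```python
# Hypothetical helper (not from the example): seeds only the RNGs the script still uses.
import random
import torch

def seed_all(seed: int, rank: int = 0) -> None:
    seed = seed + rank          # distinct stream per distributed worker
    random.seed(seed)           # Python stdlib RNG
    torch.manual_seed(seed)     # seeds PyTorch's CPU and CUDA generators

seed_all(42, rank=0)
print(random.random(), torch.rand(1).item())
```

If a third-party dependency still draws from NumPy's global RNG, that RNG is no longer seeded after this change; the example itself no longer uses it.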

examples/fabric/reinforcement_learning/rl/agent.py
Lines changed: 13 additions & 9 deletions

@@ -1,7 +1,7 @@
+import math
 from typing import Dict, Tuple

 import gymnasium as gym
-import numpy as np
 import torch
 import torch.nn.functional as F
 from rl.loss import entropy_loss, policy_loss, value_loss
@@ -24,7 +24,8 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
             raise ValueError("Unrecognized activation function: `act_fun` must be either `relu` or `tanh`")
         self.critic = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -33,7 +34,8 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
         )
         self.actor = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -81,10 +83,10 @@ def estimate_returns_and_advantages(
         lastgaelam = 0
         for t in reversed(range(num_steps)):
             if t == num_steps - 1:
-                nextnonterminal = 1.0 - next_done
+                nextnonterminal = torch.logical_not(next_done)
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones[t + 1]
+                nextnonterminal = torch.logical_not(dones[t + 1])
                 nextvalues = values[t + 1]
             delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
             advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
@@ -119,7 +121,8 @@ def __init__(
         self.normalize_advantages = normalize_advantages
         self.critic = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -128,7 +131,8 @@ def __init__(
         )
         self.actor = torch.nn.Sequential(
             layer_init(
-                torch.nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64), ortho_init=ortho_init
+                torch.nn.Linear(math.prod(envs.single_observation_space.shape), 64),
+                ortho_init=ortho_init,
             ),
             act_fun,
             layer_init(torch.nn.Linear(64, 64), ortho_init=ortho_init),
@@ -179,10 +183,10 @@ def estimate_returns_and_advantages(
         lastgaelam = 0
         for t in reversed(range(num_steps)):
             if t == num_steps - 1:
-                nextnonterminal = 1.0 - next_done
+                nextnonterminal = torch.logical_not(next_done)
                 nextvalues = next_value
             else:
-                nextnonterminal = 1.0 - dones[t + 1]
+                nextnonterminal = torch.logical_not(dones[t + 1])
                 nextvalues = values[t + 1]
             delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
             advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
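Two substitutions recur in this file: `math.prod` flattens the observation shape where `np.array(shape).prod()` was used, and `torch.logical_not` replaces `1.0 - done` for the non-terminal mask in the GAE loop. A short sketch with made-up shapes and values, not taken from the example:

```python
# Made-up shapes/values illustrating the two substitutions above.
import math
import torch

obs_shape = (4, 84, 84)
in_features = math.prod(obs_shape)            # replaces np.array(obs_shape).prod()
critic_input = torch.nn.Linear(in_features, 64)

done = torch.tensor([0.0, 1.0, 0.0])          # per-env done flags
nextnonterminal = torch.logical_not(done)     # replaces 1.0 - done, yields a bool tensor
values = torch.tensor([0.5, 0.2, 0.1])
print(values * nextnonterminal)               # bools promote to float: tensor([0.5000, 0.0000, 0.1000])
```

One small difference: `math.prod` returns a plain Python int, whereas the NumPy expression returned a NumPy integer; `torch.nn.Linear` accepts either. Likewise, the bool mask promotes to float in the GAE arithmetic, so the advantage computation is unchanged.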

examples/fabric/reinforcement_learning/rl/utils.py
Lines changed: 10 additions & 6 deletions

@@ -1,12 +1,11 @@
 import argparse
+import math
 import os
 from distutils.util import strtobool
 from typing import Optional, TYPE_CHECKING, Union

 import gymnasium as gym
-import numpy as np
 import torch
-from torch import Tensor
 from torch.utils.tensorboard import SummaryWriter

 if TYPE_CHECKING:
@@ -119,7 +118,12 @@ def parse_args():
     return args


-def layer_init(layer: torch.nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0, ortho_init: bool = True):
+def layer_init(
+    layer: torch.nn.Module,
+    std: float = math.sqrt(2),
+    bias_const: float = 0.0,
+    ortho_init: bool = True,
+):
     if ortho_init:
         torch.nn.init.orthogonal_(layer.weight, std)
         torch.nn.init.constant_(layer.bias, bias_const)
@@ -157,16 +161,16 @@ def test(
     step = 0
     done = False
     cumulative_rew = 0
-    next_obs = Tensor(env.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(env.reset(seed=args.seed)[0], device=device)
     while not done:
         # Act greedly through the environment
         action = agent.get_greedy_action(next_obs)

         # Single environment step
         next_obs, reward, done, truncated, info = env.step(action.cpu().numpy())
-        done = np.logical_or(done, truncated)
+        done = done or truncated
         cumulative_rew += reward
-        next_obs = Tensor(next_obs).to(device)
+        next_obs = torch.tensor(next_obs, device=device)
         step += 1
     logger.add_scalar("Test/cumulative_reward", cumulative_rew, 0)
     env.close()
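In this single-environment test loop, `done` and `truncated` come back as plain Python booleans, so `or` is enough, and `torch.tensor(..., device=device)` builds the tensor on the target device in one step. A brief sketch with stand-in values for what `env.reset()` / `env.step()` return:

```python
# Stand-in values; a real gymnasium env returns the observation as a NumPy array.
import numpy as np
import torch

device = torch.device("cpu")
next_obs = torch.tensor(np.zeros(4, dtype=np.float32), device=device)  # replaces Tensor(obs).to(device)

done, truncated = False, True       # scalars from a single (non-vectorised) env
done = done or truncated            # replaces np.logical_or(done, truncated)
print(done, next_obs.dtype)         # True torch.float32
```

Note that `torch.Tensor(...)` always produced a float32 tensor, whereas `torch.tensor(...)` preserves the source array's dtype; the two are interchangeable here only because the observations are already float32.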

examples/fabric/reinforcement_learning/train_fabric.py
Lines changed: 3 additions & 4 deletions

@@ -24,7 +24,6 @@
 from typing import Dict

 import gymnasium as gym
-import numpy as np
 import torch
 import torchmetrics
 from rl.agent import PPOLightningAgent
@@ -128,7 +127,7 @@ def main(args: argparse.Namespace):
     num_updates = args.total_timesteps // single_global_rollout

     # Get the first environment observation and start the optimization
-    next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
     next_done = torch.zeros(args.num_envs, device=device)
     for update in range(1, num_updates + 1):
         # Learning rate annealing
@@ -150,9 +149,9 @@ def main(args: argparse.Namespace):

            # Single environment step
            next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
-            done = np.logical_or(done, truncated)
+            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
            rewards[step] = torch.tensor(reward, device=device).view(-1)
-            next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
+            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

            if "final_info" in info:
                for agent_final_info in info["final_info"]:
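With a vectorised environment, `done` and `truncated` are per-env NumPy arrays rather than scalars, so the commit wraps them in tensors and uses `torch.logical_or`. A sketch with illustrative arrays standing in for the outputs of `envs.step(...)`:

```python
# Illustrative per-env arrays, not real environment output.
import numpy as np
import torch

device = torch.device("cpu")
done = np.array([False, True, False])
truncated = np.array([False, False, True])

done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))  # bool tensor
next_done = done.to(device)          # reused directly instead of rebuilding with Tensor(done)
print(next_done)                     # tensor([False,  True,  True])
```

Compared with the old `Tensor(done).to(device)`, `next_done` is now a bool tensor rather than float32; the `torch.logical_not` used in the agent's GAE computation handles either dtype.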

examples/fabric/reinforcement_learning/train_fabric_decoupled.py
Lines changed: 4 additions & 6 deletions

@@ -23,11 +23,9 @@
 from datetime import datetime

 import gymnasium as gym
-import numpy as np
 import torch
 from rl.agent import PPOLightningAgent
 from rl.utils import linear_annealing, make_env, parse_args, test
-from torch import Tensor
 from torch.utils.data import BatchSampler, DistributedSampler
 from torchmetrics import MeanMetric

@@ -108,7 +106,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T
        world_collective.broadcast(update_t, src=0)

    # Get the first environment observation and start the optimization
-    next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
    next_done = torch.zeros(args.num_envs).to(device)
    for update in range(1, num_updates + 1):
        for step in range(0, args.num_steps):
@@ -124,9 +122,9 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: T

            # Single environment step
            next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
-            done = np.logical_or(done, truncated)
-            rewards[step] = torch.tensor(reward).to(device).view(-1)
-            next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
+            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
+            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

            if "final_info" in info:
                for agent_final_info in info["final_info"]:
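Beyond the same done/truncated handling as above, this file also switches the reward conversion from `torch.tensor(reward).to(device)` to `torch.tensor(reward, device=device)`, which constructs the tensor on the target device in one step instead of creating it on the CPU and then moving it. A two-line sketch with a made-up reward array:

```python
import numpy as np
import torch

device = torch.device("cpu")                       # would typically be a CUDA device in the player process
reward = np.array([1.0, 0.0, 0.5], dtype=np.float32)  # stand-in for per-env rewards from envs.step(...)
rewards_step = torch.tensor(reward, device=device).view(-1)  # single construction on `device`
print(rewards_step)
```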

examples/fabric/reinforcement_learning/train_torch.py
Lines changed: 3 additions & 5 deletions

@@ -25,7 +25,6 @@
 from typing import Dict

 import gymnasium as gym
-import numpy as np
 import torch
 import torch.distributed as distributed
 import torch.nn as nn
@@ -118,7 +117,6 @@ def main(args: argparse.Namespace):

    # Seed everything
    random.seed(args.seed)
-    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic
@@ -181,7 +179,7 @@ def main(args: argparse.Namespace):
    num_updates = args.total_timesteps // single_global_step

    # Get the first environment observation and start the optimization
-    next_obs = Tensor(envs.reset(seed=args.seed)[0]).to(device)
+    next_obs = torch.tensor(envs.reset(seed=args.seed)[0], device=device)
    next_done = torch.zeros(args.num_envs, device=device)
    for update in range(1, num_updates + 1):
        # Learning rate annealing
@@ -204,9 +202,9 @@ def main(args: argparse.Namespace):

            # Single environment step
            next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
-            done = np.logical_or(done, truncated)
+            done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
            rewards[step] = torch.tensor(reward, device=device).view(-1)
-            next_obs, next_done = Tensor(next_obs).to(device), Tensor(done).to(device)
+            next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)

            if "final_info" in info:
                for agent_final_info in info["final_info"]:
