-
Notifications
You must be signed in to change notification settings - Fork 16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updates from conformer #338
base: main
Are you sure you want to change the base?
Changes from 1 commit
9b5f1ce
8651fc3
fe1261c
822a623
c76e966
3ea7eef
8473f84
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,6 @@ sh/ | |
*.txt | ||
.vscode/ | ||
external/ | ||
playground/ | ||
!requirements.txt | ||
!docs/requirements-docs.txt | ||
.DS_Store | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,14 @@ | |
from torch.distributions import Categorical | ||
from torchtyping import TensorType | ||
|
||
from gflownet.utils.common import copy, set_device, set_float_precision, tbool, tfloat | ||
from gflownet.utils.common import ( | ||
copy, | ||
set_device, | ||
set_float_precision, | ||
tbool, | ||
tfloat, | ||
torch2np, | ||
) | ||
|
||
CMAP = mpl.colormaps["cividis"] | ||
""" | ||
|
@@ -48,7 +55,7 @@ def __init__( | |
# Call reset() to set initial state, done, n_actions | ||
self.reset() | ||
# Device | ||
self.device = set_device(device) | ||
self.set_device(set_device(device)) | ||
# Float precision | ||
self.float = set_float_precision(float_precision) | ||
# Flag to skip checking if action is valid (computing mask) before step | ||
|
@@ -72,6 +79,17 @@ def __init__( | |
self.policy_output_dim = len(self.fixed_policy_output) | ||
self.policy_input_dim = len(self.state2policy()) | ||
|
||
def set_device(self, device: torch.device): | ||
""" | ||
Set the device of the environment. | ||
|
||
Parameters | ||
---------- | ||
device : torch.device | ||
The device to set the environment to. | ||
""" | ||
self.device = device | ||
|
||
@abstractmethod | ||
def get_action_space(self): | ||
""" | ||
|
@@ -757,6 +775,15 @@ def traj2readable(self, traj=None): | |
""" | ||
return str(traj).replace("(", "[").replace(")", "]").replace(",", "") | ||
|
||
def states2kde( | ||
self, states: Union[List, TensorType["batch", "state_dim"]] | ||
) -> Union[List, npt.NDArray, TensorType["batch", "kde_dim"]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe the return type is always |
||
""" | ||
Converts a batch of states into a batch of states suitable for the KDE computations. | ||
""" | ||
states_kde = self.states2proxy(states) | ||
return torch2np(states_kde) | ||
|
||
def reset(self, env_id: Union[int, str] = None): | ||
""" | ||
Resets the environment. | ||
|
@@ -1249,6 +1276,7 @@ def top_k_metrics_and_plots( | |
|
||
return metrics, figs, fig_names | ||
|
||
@torch.no_grad() | ||
def plot_reward_distribution( | ||
self, states=None, scores=None, ax=None, title=None, proxy=None, **kwargs | ||
): | ||
|
@@ -1269,7 +1297,7 @@ def plot_reward_distribution( | |
states_proxy = self.states2proxy(states) | ||
scores = self.proxy(states_proxy) | ||
if isinstance(scores, TensorType): | ||
scores = scores.cpu().detach().numpy() | ||
scores = scores.detach().cpu().numpy() | ||
ax.hist(scores) | ||
ax.set_title(title) | ||
ax.set_ylabel("Number of Samples") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1449,7 +1449,6 @@ def fit_kde( | |
bandwidth : float | ||
The bandwidth of the kernel. | ||
""" | ||
samples = torch2np(samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this change because Sklearns supports tensortypes when fitting the KernelDensity? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is because |
||
return KernelDensity(kernel=kernel, bandwidth=bandwidth).fit(samples) | ||
|
||
def plot_reward_samples( | ||
|
@@ -1489,8 +1488,6 @@ def plot_reward_samples( | |
""" | ||
if self.n_dim != 2: | ||
return None | ||
samples = torch2np(samples) | ||
samples_reward = torch2np(samples_reward) | ||
rewards = torch2np(rewards) | ||
# Create mesh grid from samples_reward | ||
n_per_dim = int(np.sqrt(samples_reward.shape[0])) | ||
|
@@ -1543,7 +1540,6 @@ def plot_kde( | |
""" | ||
if self.n_dim != 2: | ||
return None | ||
samples = torch2np(samples) | ||
# Create mesh grid from samples | ||
n_per_dim = int(np.sqrt(samples.shape[0])) | ||
assert n_per_dim**2 == samples.shape[0] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change? Is this correct? :/