Skip to content

Commit

Permalink
fixing vec field plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
arora-tushar committed Oct 10, 2022
1 parent 958d3de commit 6850511
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions code_pack/plotting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_two_d_vector_field_from_data(dynamics_func, axs, axs_range, P=None):
x = np.linspace(min(axs_range['x_min'], -2), max(axs_range['x_max'], 2), 25)
Expand All @@ -16,11 +17,13 @@ def plot_two_d_vector_field_from_data(dynamics_func, axs, axs_range, P=None):
y = Y[i, j]

vec_in = np.array([x, y])
# ode always needs 0th time point, so we take the first mapping which is not 0
try:

if('torch.nn.modules' in str(type(dynamics_func))):
vec_out = np.asarray(dynamics_func(torch.tensor(vec_in, dtype=torch.float32)))
else:
# ode always needs 0th time point, so we take the first mapping which is not 0
vec_out = dynamics_func(vec_in)[1]
except:
vec_out = dynamics_fn(vec_in)


if P is None:
s = (vec_out - vec_in)
Expand All @@ -39,4 +42,4 @@ def raster_to_events(raster):
row = raster[:, i]
rowidx = np.nonzero(row)[0]
events.append(rowidx)
return events
return events

0 comments on commit 6850511

Please sign in to comment.