Merge pull request #173 from AlessandroFlati/master
Runtime Error in hellokan.ipynb
KindXiaoming authored May 12, 2024
2 parents 1dad6bc + f34caec commit 10c456a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions kan/spline.py
@@ -57,7 +57,8 @@ def extend_grid(grid, k_extend=0):
         value = (x >= grid[:, :-1]) * (x < grid[:, 1:])
     else:
         B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
-        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
+        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (
+            grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
     return value
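
For reference, the long value = ... expression (now wrapped across two lines) is the Cox-de Boor recursion, which builds each order-k B-spline basis as a blend of two neighbouring order-(k-1) bases. Below is a minimal single-spline sketch of that recursion, using a hypothetical bspline_basis helper rather than the batched B_batch from kan/spline.py:

import torch

def bspline_basis(x, grid, k):
    # Cox-de Boor recursion for one spline (illustrative sketch, not the pykan API).
    # x: (batch,) evaluation points; grid: (grid,) knot vector; k: spline order.
    x = x.unsqueeze(1)                                        # (batch, 1)
    if k == 0:
        # Order-0 bases are indicators of the knot intervals.
        return ((x >= grid[:-1]) & (x < grid[1:])).float()    # (batch, grid-1)
    B_km1 = bspline_basis(x.squeeze(1), grid, k - 1)          # (batch, grid-k)
    # Each order-k basis blends two neighbouring order-(k-1) bases, weighted by
    # where x sits inside the supporting knot spans.
    left = (x - grid[:-(k + 1)]) / (grid[k:-1] - grid[:-(k + 1)]) * B_km1[:, :-1]
    right = (grid[k + 1:] - x) / (grid[k + 1:] - grid[1:-k]) * B_km1[:, 1:]
    return left + right                                       # (batch, grid-k-1)

x = torch.linspace(-1, 1, 100)
grid = torch.linspace(-1, 1, 11)
print(bspline_basis(x, grid, k=3).shape)   # torch.Size([100, 7])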


@@ -129,10 +130,11 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"):
     >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
     >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample))
     >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
     >>> curve2coef(x_eval, y_eval, grids, k=k).shape
     torch.Size([5, 13])
     '''
     # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
-    mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1).to(y_eval.dtype)
-    coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 0]  # sometimes 'cuda' version may diverge
+    mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1)
+    # coef = torch.linalg.lstsq(mat, y_eval.unsqueeze(dim=2)).solution[:, :, 0]
+    coef = torch.linalg.lstsq(mat.to(device), y_eval.unsqueeze(dim=2).to(device),
+                              driver='gelsy' if device == 'cpu' else 'gels').solution[:, :, 0]
     return coef.to(device)
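
The fix itself drops the unconditional fall-back to CPU (mat.to('cpu')) and instead picks an explicit LAPACK driver for torch.linalg.lstsq: 'gelsy' on CPU and 'gels' on CUDA, 'gels' being the only driver PyTorch accepts for CUDA inputs. A minimal sketch of that least-squares step, with a hypothetical fit_coef helper and tensor shapes assumed to match the docstring example above:

import torch

def fit_coef(mat, y, device='cpu'):
    # mat: (num_spline, num_sample, n_coef) B-spline design matrix,
    # y:   (num_spline, num_sample) values to fit at the sample points.
    # 'gelsy' is a rank-revealing CPU driver; CUDA only supports 'gels'.
    driver = 'gelsy' if device == 'cpu' else 'gels'
    sol = torch.linalg.lstsq(mat.to(device), y.unsqueeze(2).to(device), driver=driver)
    return sol.solution[:, :, 0]               # (num_spline, n_coef)

num_spline, num_sample, n_coef = 5, 100, 13
mat = torch.randn(num_spline, num_sample, n_coef)
y = torch.randn(num_spline, num_sample)
print(fit_coef(mat, y).shape)                  # torch.Size([5, 13])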
