Skip to content

Commit

Permalink
update Example tutorials
Browse files: browse the repository at this point in the history
  • Loading branch information
KindXiaoming committed Aug 11, 2024
1 parent 6ebd78b commit 31b2e37
Show file tree
Hide file tree
Showing 36 changed files with 5,126 additions and 774 deletions.
65 changes: 45 additions & 20 deletions kan/.ipynb_checkpoints/KANLayer-checkpoint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,7 +63,7 @@ class KANLayer(nn.Module):
unlock already locked activation functions
"""

def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
''''
initialize a KANLayer
Expand Down Expand Up @@ -119,7 +119,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1)
grid = extend_grid(grid, k_extend=k)
self.grid = torch.nn.Parameter(grid).requires_grad_(False)
noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num
# shape: (size, coef)
self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k))
#if isinstance(scale_base, float):
Expand All @@ -128,7 +128,9 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
else:
mask = 1.

self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_base * mask).requires_grad_(sb_trainable) # make scale trainable
self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim))
#self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_base * mask).requires_grad_(sb_trainable) # make scale trainable
#else:
#self.scale_base = torch.nn.Parameter(scale_base.to(device)).requires_grad_(sb_trainable)
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * mask).requires_grad_(sp_trainable) # make scale trainable
Expand Down Expand Up @@ -193,7 +195,7 @@ def forward(self, x):
y = torch.sum(y, dim=1) # shape (batch, out_dim)
return y, preacts, postacts, postspline

def update_grid_from_samples(self, x):
def update_grid_from_samples(self, x, mode='sample'):
'''
update grid from samples
Expand All @@ -216,21 +218,32 @@ def update_grid_from_samples(self, x):
tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
tensor([[-3.0002, -1.7882, -0.5763, 0.6357, 1.8476, 3.0002]])
'''

batch = x.shape[0]
#x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
margin = 0.01
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive

def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)

if mode == 'grid':
sample_grid = get_grid(2*num_interval)
x_pos = sample_grid.permute(1,0)
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

self.grid.data = extend_grid(grid, k_extend=self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

def initialize_grid_from_parent(self, parent, x):
def initialize_grid_from_parent(self, parent, x, mode='sample'):
'''
update grid from a parent KANLayer & samples
Expand Down Expand Up @@ -258,19 +271,31 @@ def initialize_grid_from_parent(self, parent, x):
tensor([[-1.0000, -0.8000, -0.6000, -0.4000, -0.2000, 0.0000, 0.2000, 0.4000,
0.6000, 0.8000, 1.0000]])
'''

batch = x.shape[0]
# preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
#x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_eval = x
pgrid = parent.grid # (in_dim, G+2*k+1)
pk = parent.k
y_eval = coef2curve(x_eval, pgrid, parent.coef, pk)

h = (pgrid[:,[-pk]] - pgrid[:,[pk]])/self.num
grid = pgrid[:,[pk]] + torch.arange(self.num+1,) * h
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k

def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)

if mode == 'grid':
sample_grid = get_grid(2*num_interval)
x_pos = sample_grid.permute(1,0)
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)

grid = extend_grid(grid, k_extend=self.k)
self.grid.data = grid
self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

def get_subset(self, in_id, out_id):
'''
Expand Down
13 changes: 7 additions & 6 deletions kan/.ipynb_checkpoints/MultKAN-checkpoint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,7 @@
class MultKAN(nn.Module):

# include mult_ops = []
def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=1.0, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=1.0, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):

super(MultKAN, self).__init__()

Expand Down Expand Up @@ -60,16 +60,16 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=1.0, sca
base_fun = torch.nn.SiLU()
elif base_fun == 'identity':
base_fun = torch.nn.Identity()
elif base_fun == 'zero':
base_fun = lambda x: x*0.

self.grid_eps = grid_eps
self.grid_range = grid_range


for l in range(self.depth):
# splines
scale_base = scale_base_mu * 1 / np.sqrt(width_in[l]) + \
scale_base_sigma * (torch.randn(width_in[l], width_out[l + 1]) * 2 - 1) * 1/np.sqrt(width_in[l])
sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base=scale_base, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
self.act_fun.append(sp_batch)

self.node_bias = []
Expand Down Expand Up @@ -185,7 +185,7 @@ def initialize_from_another_model(self, another_model, x):
for l in range(self.depth):
self.symbolic_fun[l] = another_model.symbolic_fun[l]

return self.to(device)
return self.to(self.device)

def log_history(self, method_name):

Expand Down Expand Up @@ -221,7 +221,8 @@ def refine(self, new_grid):
auto_save=True,
first_init=False,
state_id=self.state_id,
round=self.round)
round=self.round,
device=self.device)

model_new.initialize_from_another_model(self, self.cache_data)
model_new.cache_data = self.cache_data
Expand Down
18 changes: 14 additions & 4 deletions kan/.ipynb_checkpoints/spline-checkpoint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -120,7 +120,7 @@ def coef2curve(x_eval, grid, coef, k, device="cpu"):
return y_eval


def curve2coef(x_eval, y_eval, grid, k):
def curve2coef(x_eval, y_eval, grid, k, lamb=1e-8):
'''
converting B-spline curves to B-spline coefficients using least squares.
Expand Down Expand Up @@ -163,10 +163,20 @@ def curve2coef(x_eval, y_eval, grid, k):
mat = mat.permute(1,0,2)[:,None,:,:].expand(in_dim, out_dim, batch, n_coef) # (in_dim, out_dim, batch, n_coef)
# coef shape: (in_dim, outdim, G+k)
y_eval = y_eval.permute(1,2,0).unsqueeze(dim=3) # y_eval: (in_dim, out_dim, batch, 1)
#print(mat)
device = mat.device
coef = torch.linalg.lstsq(mat, y_eval,
driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]


#coef = torch.linalg.lstsq(mat, y_eval,
#driver='gelsy' if device == 'cpu' else 'gels').solution[:,:,:,0]

XtX = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), mat)
Xty = torch.einsum('ijmn,ijnp->ijmp', mat.permute(0,1,3,2), y_eval)
n1, n2, n = XtX.shape[0], XtX.shape[1], XtX.shape[2]
identity = torch.eye(n,n)[None, None, :, :].expand(n1, n2, n, n).to(device)
A = XtX + lamb * identity
B = Xty
coef = (A.pinverse() @ B)[:,:,:,0]

return coef


Expand Down
65 changes: 45 additions & 20 deletions kan/KANLayer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,7 +63,7 @@ class KANLayer(nn.Module):
unlock already locked activation functions
"""

def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
''''
initialize a KANLayer
Expand Down Expand Up @@ -119,7 +119,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1)
grid = extend_grid(grid, k_extend=k)
self.grid = torch.nn.Parameter(grid).requires_grad_(False)
noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1 / 2) * noise_scale / num
noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num
# shape: (size, coef)
self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k))
#if isinstance(scale_base, float):
Expand All @@ -128,7 +128,9 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
else:
mask = 1.

self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_base * mask).requires_grad_(sb_trainable) # make scale trainable
self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim))
#self.scale_base = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_base * mask).requires_grad_(sb_trainable) # make scale trainable
#else:
#self.scale_base = torch.nn.Parameter(scale_base.to(device)).requires_grad_(sb_trainable)
self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * mask).requires_grad_(sp_trainable) # make scale trainable
Expand Down Expand Up @@ -193,7 +195,7 @@ def forward(self, x):
y = torch.sum(y, dim=1) # shape (batch, out_dim)
return y, preacts, postacts, postspline

def update_grid_from_samples(self, x):
def update_grid_from_samples(self, x, mode='sample'):
'''
update grid from samples
Expand All @@ -216,21 +218,32 @@ def update_grid_from_samples(self, x):
tensor([[-1.0000, -0.6000, -0.2000, 0.2000, 0.6000, 1.0000]])
tensor([[-3.0002, -1.7882, -0.5763, 0.6357, 1.8476, 3.0002]])
'''

batch = x.shape[0]
#x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
margin = 0.01
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive

def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)

if mode == 'grid':
sample_grid = get_grid(2*num_interval)
x_pos = sample_grid.permute(1,0)
y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)

self.grid.data = extend_grid(grid, k_extend=self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

def initialize_grid_from_parent(self, parent, x):
def initialize_grid_from_parent(self, parent, x, mode='sample'):
'''
update grid from a parent KANLayer & samples
Expand Down Expand Up @@ -258,19 +271,31 @@ def initialize_grid_from_parent(self, parent, x):
tensor([[-1.0000, -0.8000, -0.6000, -0.4000, -0.2000, 0.0000, 0.2000, 0.4000,
0.6000, 0.8000, 1.0000]])
'''

batch = x.shape[0]
# preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
#x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
x_eval = x
pgrid = parent.grid # (in_dim, G+2*k+1)
pk = parent.k
y_eval = coef2curve(x_eval, pgrid, parent.coef, pk)

h = (pgrid[:,[-pk]] - pgrid[:,[pk]])/self.num
grid = pgrid[:,[pk]] + torch.arange(self.num+1,) * h
x_pos = torch.sort(x, dim=0)[0]
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)
num_interval = self.grid.shape[1] - 1 - 2*self.k

def get_grid(num_interval):
ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
grid_adaptive = x_pos[ids, :].permute(1,0)
h = (grid_adaptive[:,[-1]] - grid_adaptive[:,[0]])/num_interval
grid_uniform = grid_adaptive[:,[0]] + h * torch.arange(num_interval+1,)[None, :].to(x.device)
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
return grid

grid = get_grid(num_interval)

if mode == 'grid':
sample_grid = get_grid(2*num_interval)
x_pos = sample_grid.permute(1,0)
y_eval = coef2curve(x_pos, parent.grid, parent.coef, parent.k)

grid = extend_grid(grid, k_extend=self.k)
self.grid.data = grid
self.coef.data = curve2coef(x_eval, y_eval, self.grid, self.k)
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)

def get_subset(self, in_id, out_id):
'''
Expand Down
Loading

0 comments on commit 31b2e37

Please sign in to comment.