diff --git a/kan/KAN.py b/kan/KAN.py index 4e981d0e..e4310d95 100644 --- a/kan/KAN.py +++ b/kan/KAN.py @@ -143,7 +143,7 @@ def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0. # bias bias = nn.Linear(width[l + 1], 1, bias=False, device=device).requires_grad_(bias_trainable) - bias.weight.data *= 0. + torch.nn.init.zeros_(bias.weight) self.biases.append(bias) self.biases = nn.ModuleList(self.biases) @@ -154,12 +154,8 @@ def __init__(self, width=None, grid=3, k=3, noise_scale=0.1, noise_scale_base=0. self.base_fun = base_fun ### initializing the symbolic front ### - self.symbolic_fun = [] - for l in range(self.depth): - sb_batch = Symbolic_KANLayer(in_dim=width[l], out_dim=width[l + 1], device=device) - self.symbolic_fun.append(sb_batch) - - self.symbolic_fun = nn.ModuleList(self.symbolic_fun) + self.symbolic_fun = nn.ModuleList( + [Symbolic_KANLayer(in_dim=width[l], out_dim=width[l + 1], device=device) for l in range(self.depth)]) self.symbolic_enabled = symbolic_enabled self.device = device @@ -185,7 +181,7 @@ def initialize_from_another_model(self, another_model, x): >>> model_fine = KAN(width=[2,5,1], grid=10, k=3) >>> print(model_fine.act_fun[0].coef[0][0].data) >>> x = torch.normal(0,1,size=(100,2)) - >>> model_fine.initialize_from_another_model(model_coarse, x); + >>> model_fine.initialize_from_another_model(model_coarse, x) >>> print(model_fine.act_fun[0].coef[0][0].data) tensor(-0.0030) tensor(0.0506) @@ -356,10 +352,10 @@ def set_mode(self, l, i, j, mode, mask_n=None): None ''' if mode == "s": - mask_n = 0.; + mask_n = 0. mask_s = 1. elif mode == "n": - mask_n = 1.; + mask_n = 1. mask_s = 0. elif mode == "sn" or mode == "ns": if mask_n == None: @@ -368,7 +364,7 @@ def set_mode(self, l, i, j, mode, mask_n=None): mask_n = mask_n mask_s = 1. else: - mask_n = 0.; + mask_n = 0. mask_s = 0. self.act_fun[l].mask.data[j * self.act_fun[l].in_dim + i] = mask_n @@ -814,7 +810,7 @@ def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lam >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) >>> dataset = create_dataset(f, n_var=2) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01) >>> model.plot() ''' @@ -958,7 +954,7 @@ def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None): >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) >>> dataset = create_dataset(f, n_var=2) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01) >>> model.prune() >>> model.plot(mask=True) ''' @@ -1063,7 +1059,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) >>> dataset = create_dataset(f, n_var=2) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01) >>> model = model.prune() >>> model(dataset['train_input']) >>> model.suggest_symbolic(0,0,0) @@ -1124,7 +1120,7 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose= >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) >>> dataset = create_dataset(f, n_var=2) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01) >>> >>> model = model.prune() >>> model(dataset['train_input']) >>> model.auto_symbolic() @@ -1139,7 +1135,7 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose= >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0) >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) >>> dataset = create_dataset(f, n_var=2) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01) >>> >>> model = model.prune() >>> model(dataset['train_input']) >>> model.auto_symbolic(lib=['exp','sin','x^2']) @@ -1184,11 +1180,11 @@ def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02) >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2) >>> dataset = create_dataset(f, n_var=2) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01) >>> model = model.prune() >>> model(dataset['train_input']) >>> model.auto_symbolic(lib=['exp','sin','x^2']) - >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False); + >>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False) >>> model.symbolic_formula() ''' symbolic_acts = [] diff --git a/kan/KANLayer.py b/kan/KANLayer.py index fb35d1f1..57dafddf 100644 --- a/kan/KANLayer.py +++ b/kan/KANLayer.py @@ -116,8 +116,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base= # shape: (size, num) self.grid = torch.einsum('i,j->ij', torch.ones(size, device=device), torch.linspace(grid_range[0], grid_range[1], steps=num + 1, device=device)) self.grid = torch.nn.Parameter(self.grid).requires_grad_(False) - noises = (torch.rand(size, self.grid.shape[1]) - 1 / 2) * noise_scale / num - noises = noises.to(device) + noises = (torch.rand(size, self.grid.shape[1], device=device) - 1 / 2) * noise_scale / num # shape: (size, coef) self.coef = torch.nn.Parameter(curve2coef(self.grid, noises, self.grid, k, device)) if isinstance(scale_base, float):