Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/Examples/Example_4_symbolic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
],
"source": [
"# train the model\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.)"
]
},
{
Expand Down Expand Up @@ -286,7 +286,7 @@
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);\n",
"model.train(dataset, opt=\"LBFGS\", steps=20)\n",
"model.plot()"
]
},
Expand Down Expand Up @@ -465,7 +465,7 @@
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=20);\n",
"model.train(dataset, opt=\"LBFGS\", steps=20)\n",
"model.plot()"
]
},
Expand Down Expand Up @@ -612,7 +612,7 @@
],
"source": [
"# this loss is stuck at around 1e-3 RMSE, which is good, but not machine precision.\n",
"model.train(dataset, opt=\"LBFGS\", steps=20);\n",
"model.train(dataset, opt=\"LBFGS\", steps=20)\n",
"model.plot()"
]
},
Expand Down
14 changes: 7 additions & 7 deletions hellokan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
],
"source": [
"# plot KAN at initialization\n",
"model(dataset['train_input']);\n",
"model(dataset['train_input'])\n",
"model.plot(beta=100)"
]
},
Expand Down Expand Up @@ -208,7 +208,7 @@
],
"source": [
"# train the model\n",
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.);"
"model.train(dataset, opt=\"LBFGS\", steps=20, lamb=0.01, lamb_entropy=10.)"
]
},
{
Expand Down Expand Up @@ -324,7 +324,7 @@
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
"model.train(dataset, opt=\"LBFGS\", steps=50)"
]
},
{
Expand Down Expand Up @@ -377,9 +377,9 @@
"\n",
"if mode == \"manual\":\n",
" # manual mode\n",
" model.fix_symbolic(0,0,0,'sin');\n",
" model.fix_symbolic(0,1,0,'x^2');\n",
" model.fix_symbolic(1,0,0,'exp');\n",
" model.fix_symbolic(0,0,0,'sin')\n",
" model.fix_symbolic(0,1,0,'x^2')\n",
" model.fix_symbolic(1,0,0,'exp')\n",
"elif mode == \"auto\":\n",
" # automatic mode\n",
" lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']\n",
Expand Down Expand Up @@ -409,7 +409,7 @@
}
],
"source": [
"model.train(dataset, opt=\"LBFGS\", steps=50);"
"model.train(dataset, opt=\"LBFGS\", steps=50)"
]
},
{
Expand Down
38 changes: 19 additions & 19 deletions kan/KAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class KAN(nn.Module):
depth: int
depth of KAN
width: list
number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 5D outputs, with 2 layers of 5 hidden neurons.
number of neurons in each layer. e.g., [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
grid: int
the number of grid intervals
k: int
Expand Down Expand Up @@ -185,7 +185,7 @@ def initialize_from_another_model(self, another_model, x):
>>> model_fine = KAN(width=[2,5,1], grid=10, k=3)
>>> print(model_fine.act_fun[0].coef[0][0].data)
>>> x = torch.normal(0,1,size=(100,2))
>>> model_fine.initialize_from_another_model(model_coarse, x);
>>> model_fine.initialize_from_another_model(model_coarse, x)
>>> print(model_fine.act_fun[0].coef[0][0].data)
tensor(-0.0030)
tensor(0.0506)
Expand Down Expand Up @@ -348,27 +348,27 @@ def set_mode(self, l, i, j, mode, mask_n=None):
output neuron index
mode : str
'n' (numeric) or 's' (symbolic) or 'ns' (combined)
mask_n : None or float)
mask_n : None or float
magnitude of the numeric front

Returns:
--------
None
'''
if mode == "s":
mask_n = 0.;
mask_n = 0.
mask_s = 1.
elif mode == "n":
mask_n = 1.;
mask_n = 1.
mask_s = 0.
elif mode == "sn" or mode == "ns":
if mask_n == None:
if mask_n is None:
mask_n = 1.
else:
mask_n = mask_n
mask_s = 1.
else:
mask_n = 0.;
mask_n = 0.
mask_s = 0.

self.act_fun[l].mask.data[j * self.act_fun[l].in_dim + i] = mask_n
Expand Down Expand Up @@ -814,7 +814,7 @@ def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lam
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
>>> model.plot()
'''

Expand Down Expand Up @@ -842,7 +842,7 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor):

pbar = tqdm(range(steps), desc='description', ncols=100)

if loss_fn == None:
if loss_fn is None:
loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
else:
loss_fn = loss_fn_eval = loss_fn
Expand Down Expand Up @@ -958,7 +958,7 @@ def prune(self, threshold=1e-2, mode="auto", active_neurons_id=None):
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
>>> model.prune()
>>> model.plot(mask=True)
'''
Expand Down Expand Up @@ -1063,7 +1063,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.suggest_symbolic(0,0,0)
Expand All @@ -1076,7 +1076,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No
'''
r2s = []

if lib == None:
if lib is None:
symbolic_lib = SYMBOLIC_LIB
else:
symbolic_lib = {}
Expand Down Expand Up @@ -1124,8 +1124,8 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> >>> model = model.prune()
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.auto_symbolic()
fixing (0,0,0) with sin, r2=0.9994837045669556
Expand All @@ -1139,8 +1139,8 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> >>> model = model.prune()
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.auto_symbolic(lib=['exp','sin','x^2'])
fixing (0,0,0) with sin, r2=0.999411404132843
Expand Down Expand Up @@ -1184,11 +1184,11 @@ def symbolic_formula(self, floating_digit=2, var=None, normalizer=None, simplify
>>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.1, seed=0, grid_eps=0.02)
>>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
>>> dataset = create_dataset(f, n_var=2)
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01);
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
>>> model = model.prune()
>>> model(dataset['train_input'])
>>> model.auto_symbolic(lib=['exp','sin','x^2'])
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False);
>>> model.train(dataset, opt='LBFGS', steps=50, lamb=0.00, update_grid=False)
>>> model.symbolic_formula()
'''
symbolic_acts = []
Expand All @@ -1202,7 +1202,7 @@ def ex_round(ex1, floating_digit=floating_digit):
return ex2

# define variables
if var == None:
if var is None:
for ii in range(1, self.width[0] + 1):
exec(f"x{ii} = sympy.Symbol('x_{ii}')")
exec(f"x.append(x{ii})")
Expand Down