diff --git a/.gitignore b/.gitignore index ad8827435..c3f9b1934 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,10 @@ docs/_static/ docs/_templates folder test +tutorials/figures/ +tutorials/model_ckpt/model_oct0.ckpt +tutorials/continual_learning_structural.py +tutorials/model_ckpt/model_oct1.ckpt +tutorials/model_ckpt/model_oct2.ckpt +tutorials/continual_learning_step_by_stop.ipynb +tutorials/continual_learning_structural.ipynb diff --git a/kan/KAN.py b/kan/KAN.py index 76e7c321f..d3e1a70b3 100644 --- a/kan/KAN.py +++ b/kan/KAN.py @@ -760,7 +760,7 @@ def score2alpha(score): plt.gcf().get_axes()[0].text(0.5, y0 * (len(self.width) - 1) + 0.2, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center') def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, batch=-1, - small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu'): + small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu', custom_lr_param = []): ''' training @@ -800,6 +800,8 @@ def train(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lam device save_fig_freq : int save figure every (save_fig_freq) step + custom_lr_param: list + custom learning rate parameters for Adam optimizer Returns: -------- @@ -851,7 +853,10 @@ def nonlinear(x, th=small_mag_threshold, factor=small_reg_factor): grid_update_freq = int(stop_grid_update_step / grid_update_num) if opt == "Adam": - optimizer = torch.optim.Adam(self.parameters(), lr=lr) + if custom_lr_param == []: + optimizer = torch.optim.Adam(self.parameters(), lr=lr) + else: + optimizer = torch.optim.Adam(custom_lr_param, lr=lr) elif opt == "LBFGS": optimizer = LBFGS(self.parameters(), lr=lr, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32) diff --git a/kan/utils.py b/kan/utils.py index 7a029a825..1ec1d2791 100644 --- a/kan/utils.py +++ b/kan/utils.py @@ -2,6 +2,7 @@ import torch from sklearn.linear_model import LinearRegression import sympy +import pdb # sigmoid = sympy.Function('sigmoid') # name: (torch implementation, sympy implementation) @@ -39,6 +40,7 @@ def create_dataset(f, ranges = [-1,1], train_num=1000, test_num=1000, + step_size=None, normalize_input=False, normalize_label=False, device='cpu', @@ -56,6 +58,8 @@ def create_dataset(f, the number of training samples. Default: 1000. test_num : int the number of test samples. Default: 1000. + step_size : float + step_size within the specified range. Default: None. normalize_input : bool If True, apply normalization to inputs. Default: False. normalize_label : bool @@ -64,6 +68,7 @@ def create_dataset(f, device. Default: 'cpu'. seed : int random seed. Default: 0. + Returns: -------- @@ -87,13 +92,36 @@ def create_dataset(f, else: ranges = np.array(ranges) - train_input = torch.zeros(train_num, n_var) - test_input = torch.zeros(test_num, n_var) - for i in range(n_var): - train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] - test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] - - + + if step_size is None: + train_input = torch.zeros(train_num, n_var) + test_input = torch.zeros(test_num, n_var) + for i in range(n_var): + train_input[:,i] = torch.rand(train_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] + test_input[:,i] = torch.rand(test_num,)*(ranges[i,1]-ranges[i,0])+ranges[i,0] + else: + # generate warning that if step_size is provided, test_num is only the required argument + if test_num is not None: + print('Warning: if step_size is provided, test_num is only the required argument. train_num equals to all possible combinations minus test_num.') + def generate_grid(dimensions, ranges, step_size): + # Generate an array of values for each dimension + axisvalues = [] + for _ in range(dimensions): + axisvalues.append(np.arange(ranges[_][0], ranges[_][1], step_size)) + + # Create the grid using numpy's meshgrid function + grid = np.meshgrid(*axisvalues, indexing='ij') + + # Reshape the grid to list all possible combinations of coordinates + grid = np.stack(grid, axis=-1).reshape(-1, dimensions) + + return grid + + all_data = torch.from_numpy(generate_grid(n_var, ranges, step_size)).float() + train_input = all_data[np.random.choice(all_data.shape[0], all_data.shape[0]-test_num, replace=False)] + test_input = all_data[np.random.choice(all_data.shape[0], test_num, replace=False)] + + train_label = f(train_input) test_label = f(test_input) diff --git a/tutorials/continual_learning_structural.ipynb b/tutorials/continual_learning_structural.ipynb new file mode 100644 index 000000000..a56a71cba --- /dev/null +++ b/tutorials/continual_learning_structural.ipynb @@ -0,0 +1,467 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "\n", + "import site\n", + "import sys\n", + "\n", + "site.addsitedir('../') # Always appends to end\n", + "\n", + "from kan import *\n", + "import numpy as np\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import pdb\n", + "\n", + "from copy import deepcopy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "import math\n", + "from torch.utils.data import Dataset\n", + "\n", + "# f = lambda x: (x[:,[0]]**2 + x[:,[2]] + 3*torch.sin(x[:,[1]]))/(x[:,[0]] + 2*x[:,[1]] - x[:,[2]])\n", + "# sin(x) * cos(y) + exp(z/2) - x^2 * y\n", + "f = lambda x: torch.sin(x[:, [0]]) * torch.cos(x[:, [1]]) + torch.exp(x[:, [2]]/2) - x[:, [0]]**2 * x[:, [1]]\n", + "\n", + "test_num = int(((10/0.1)**3)*0.2)\n", + "test_num_sub = test_num // 8\n", + "seednum = 1\n", + "\n", + "datasetall = create_dataset(f, n_var=3, test_num=test_num, seed=seednum, step_size=0.1, ranges=[[-5,5], [-5,5], [-5,5]], device=\"cpu\")\n", + "datasetoct0 = create_dataset(f, n_var=3, test_num=test_num_sub, seed=seednum, step_size=0.1, ranges=[[-5,0],[-5,0],[-5,0]], device=\"cpu\")\n", + "# we really do need 3 ranges\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ## build KAN architecture" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Training KAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "# setting bias_trainable=False, sp_trainable=False, sb_trainable=False is important.\n", + "# otherwise KAN will have random scaling and shift for samples in previous stages\n", + "\n", + "model = KAN(width=[3,3,3,1], grid=100, k=3, sp_trainable=False, sb_trainable=False, noise_scale=0.1, device=\"cpu\")\n", + "\n", + "# make a 3D tensor with 5 points in each dimension\n", + "x = torch.linspace(-5, 5, steps=11).repeat(3, 1).T\n", + "\n", + "model.update_grid_from_samples(x.to(\"cpu\"))\n", + "# model.train(datasetoct0, opt=\"LBFGS\", update_grid=False, steps=200, device=\"cuda\", lr=0.5);\n", + "model.train(datasetoct0, opt=\"LBFGS\", update_grid=False, steps=80, device=\"cpu\");\n", + "# model.train(datasetoct0, opt = 'Adam', steps=5000, update_grid=False, device=\"cuda\", lr=1e-6);\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_ckpt(\"model_oct0.ckpt\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "f_pred = model(datasetoct0['test_input'].to('cpu')).to('cpu').detach().numpy()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "f_true = datasetoct0['test_label'].to('cpu').detach().numpy()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "# compute R^2 between prediction and true value\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " continue pruning until R^2 is close to 0.95" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "# for x in range (1,100):\n", + "# threshold_value = 0.01 * x\n", + "# model = model.prune(threshold=threshold_value)\n", + "# model(datasetoct0['train_input'].to(\"cpu\"))\n", + "# f_pred = model(datasetoct0['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "# f_true = datasetoct0['test_label'].to('cpu').detach().numpy()\n", + "# r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "# print(f\"R^2 of datasetoct0: {r2}\")\n", + "# if r2 < 0.96:\n", + "# break\n", + "\n", + "# model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct1 = create_dataset(f, n_var=3, test_num=test_num_sub, step_size=0.1, seed=seednum, ranges=[[-5, 0], [-5, 0], [0, 5]])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " ### freeze/slow down/regularize/stop early network and determine nodes that is the opposite of pruning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "# model.train(datasetoct1, opt = 'Adam', steps=5000, update_grid=False, device=\"cuda\", lr=lr_half);\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "# model.train(datasetoct1, opt=\"Adam\", steps=int(80/2));\n", + "model.train(datasetoct1, opt=\"LBFGS\", update_grid=False, steps=80, device=\"cpu\");\n", + "f_pred = model(datasetoct1['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct1['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetoct0['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct0['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of datasetoct0: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_ckpt(\"model_oct1.ckpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Also total dataset has bad r^2. Lack of data?\n", + "\n", + " Less catastrophic forgetting PROVIDED that the network is well pruned. Occam's Razor?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "# for x in range (1,100):\n", + "# threshold_value = 0.01 * x\n", + "# model = model.prune(threshold=threshold_value)\n", + "# model(datasetoct1['train_input'].to(\"cpu\"))\n", + "# f_pred = model(datasetoct1['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "# f_true = datasetoct1['test_label'].to('cpu').detach().numpy()\n", + "# r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "# print(f\"R^2 of datasetoct1: {r2}\")\n", + "# if r2 < 0.96:\n", + "# break\n", + "\n", + "# model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct2 = create_dataset(f, n_var=3, test_num=test_num_sub, step_size=0.1, seed=0, ranges=[[-5, 0], [0, 5], [-5, 0]])\n", + "model.train(datasetoct2, opt=\"LBFGS\", steps=80);\n", + "# model = model_bk\n", + "f_pred = model(datasetoct2['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct2['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetoct1['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct1['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of datasetoct1: {r2}\")\n", + "\n", + "f_pred = model(datasetoct0['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct0['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of datasetoct0: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.save_ckpt(\"model_oct2.ckpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " A single iteration of LBFGS causes immediate catastrophic forgetting. Also, repeated dataset creation greatly varies r^2. Dataset quality is the problem." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "model = model.prune(threshold=pruning_threshold)\n", + "model(datasetoct2['train_input'].to(\"cpu\"))\n", + "model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct3 = create_dataset(f, n_var=3, train_num=trainpoints_subtask, test_num=testpoints_subtask, seed=0, ranges=[[-5, 0], [0, 5], [0, 5]])\n", + "model.train(datasetoct3, opt=\"LBFGS\", steps=int(80/4));\n", + "f_pred = model(datasetoct3['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct3['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "model = model.prune(threshold=pruning_threshold)\n", + "model(datasetoct3['train_input'].to(\"cpu\"))\n", + "model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct4 = create_dataset(f, n_var=3, train_num=trainpoints_subtask, test_num=testpoints_subtask, seed=0, ranges=[[0, 5], [-5, 0], [-5, 0]])\n", + "model.train(datasetoct4, opt=\"LBFGS\", steps=int(80/5));\n", + "f_pred = model(datasetoct4['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct4['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "model = model.prune(threshold=pruning_threshold)\n", + "model(datasetoct4['train_input'].to(\"cpu\"))\n", + "model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct5 = create_dataset(f, n_var=3, train_num=trainpoints_subtask, test_num=testpoints_subtask, seed=0, ranges=[[0, 5], [-5, 0], [0, 5]])\n", + "model.train(datasetoct5, opt=\"LBFGS\", steps=int(80/6));\n", + "f_pred = model(datasetoct5['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct5['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n", + "\n", + "model = model.prune(threshold=pruning_threshold)\n", + "model(datasetoct5['train_input'].to(\"cpu\"))\n", + "model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct6 = create_dataset(f, n_var=3, train_num=trainpoints_subtask, test_num=testpoints_subtask, seed=0, ranges=[[0, 5], [0, 5], [-5, 0]])\n", + "model.train(datasetoct6, opt=\"LBFGS\", steps=int(80/7));\n", + "f_pred = model(datasetoct6['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct6['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n", + "\n", + "model = model.prune(threshold=pruning_threshold)\n", + "model(datasetoct6['train_input'].to(\"cpu\"))\n", + "model.plot()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %%\n", + "datasetoct7 = create_dataset(f, n_var=3, train_num=trainpoints_subtask, test_num=testpoints_subtask, seed=0, ranges=[[0, 5], [0, 5], [0, 5]])\n", + "model.train(datasetoct7, opt=\"LBFGS\", steps=int(80/8));\n", + "f_pred = model(datasetoct7['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetoct7['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2: {r2}\")\n", + "\n", + "f_pred = model(datasetall['test_input'].to('cpu')).to('cpu').detach().numpy()\n", + "f_true = datasetall['test_label'].to('cpu').detach().numpy()\n", + "r2 = 1 - np.sum((f_pred - f_true)**2)/np.sum(f_true**2)\n", + "print(f\"R^2 of total dataset: {r2}\")\n", + "\n", + "model = model.prune(threshold=pruning_threshold)\n", + "model(datasetoct7['train_input'].to(\"cpu\"))\n", + "model.plot()\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pykan-env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}