From c68c872d8a25d3f8afc618710ed8e7b160ac7e8e Mon Sep 17 00:00:00 2001 From: CadmusFLux Date: Sat, 25 May 2024 14:29:46 +0100 Subject: [PATCH] Adding relevant line to explain the torch shape output in curve2coeff function --- kan/spline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/kan/spline.py b/kan/spline.py index 48e225087..bf960e82a 100644 --- a/kan/spline.py +++ b/kan/spline.py @@ -130,6 +130,7 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"): >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample)) >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) + >>> curve2coef(x_eval, y_eval, grid, k, device="cpu").shape torch.Size([5, 13]) ''' # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar