diff --git a/kan/spline.py b/kan/spline.py index 48e225087..bf960e82a 100644 --- a/kan/spline.py +++ b/kan/spline.py @@ -130,6 +130,7 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"): >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample)) >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample)) >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1)) + >>> curve2coef(x_eval, y_eval, grid, k, device="cpu").shape torch.Size([5, 13]) ''' # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar