diff --git a/kan/MultKAN.py b/kan/MultKAN.py index 6074675ff..268d51a0e 100644 --- a/kan/MultKAN.py +++ b/kan/MultKAN.py @@ -1680,7 +1680,7 @@ def prune_node(self, threshold=1e-2, mode="auto", active_neurons_id=None, log_hi if i not in active_neurons_down[l]: self.remove_node(l + 1, i, mode='down',log_history=False) - model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name, mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=True, first_init=False, state_id=self.state_id, round=self.round).to(self.device) + model2 = MultKAN(copy.deepcopy(self.width), grid=self.grid, k=self.k, base_fun=self.base_fun_name, mult_arity=self.mult_arity, ckpt_path=self.ckpt_path, auto_save=self.auto_save, first_init=False, state_id=self.state_id, round=self.round).to(self.device) model2.load_state_dict(self.state_dict()) width_new = [self.width[0]]