From 5ab076ac0cc1c09c8382379f19b9ce16b7338bbf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 5 Dec 2024 13:39:27 +0100 Subject: [PATCH 1/2] no wd on bias and mamba A,D params --- mad/model/pl_model_wrapper.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/mad/model/pl_model_wrapper.py b/mad/model/pl_model_wrapper.py index d754046..533d07b 100644 --- a/mad/model/pl_model_wrapper.py +++ b/mad/model/pl_model_wrapper.py @@ -93,18 +93,29 @@ def test_step(self, return self.phase_step(batch, batch_idx, phase='test') def configure_optimizers(self) -> tp.Union[torch.optim.Optimizer, tp.Dict[str, tp.Any]]: + # param groups + decay_params, no_decay_params = [], [] + for n, p in self.model.named_parameters(): + if p.requires_grad: + if not getattr(p, '_no_weight_decay', False) and ("bias" not in n) and ("norm" not in n): + decay_params.append(p) + else: + no_decay_params.append(p) + param_groups = [ + {"params": decay_params, "weight_decay": self.mad_config.weight_decay}, + {"params": no_decay_params, "weight_decay": 0.0}, + ] + # optimizer: if self.mad_config.optimizer == 'adamw': optimizer = torch.optim.AdamW( - self.parameters(), - lr=self.mad_config.lr, - weight_decay=self.mad_config.weight_decay + param_groups, + lr=self.mad_config.lr ) elif self.mad_config.optimizer == 'sgd': optimizer = torch.optim.SGD( - self.parameters(), - lr=self.mad_config.lr, - weight_decay=self.mad_config.weight_decay + param_groups, + lr=self.mad_config.lr ) else: raise ValueError(f"invalid optimizer: {self.mad_config.optimizer}") From 23592731b2c257169586a8e4d75f476df3de6578 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 5 Dec 2024 13:42:15 +0100 Subject: [PATCH 2/2] add names to normaliz layers, allows to exclude them from wd params --- mad/model/language_model.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mad/model/language_model.py b/mad/model/language_model.py index 8aa029b..e7b7107 100644 --- a/mad/model/language_model.py +++ b/mad/model/language_model.py @@ -1,3 +1,4 @@ +from collections import OrderedDict import torch import typing as tp from torch import nn @@ -46,9 +47,16 @@ def __init__(self, self.model = nn.ModuleList([]) for layer, layer_cfg in zip(layers, layer_cfgs): - self.model.append(nn.Sequential(norm(layer_cfg['dim']), layer(**layer_cfg))) - - self.unembed = nn.Sequential(norm(layer_cfg['dim']), nn.Linear(dim, vocab_size)) + self.model.append(nn.Sequential(OrderedDict([ + ('norm', norm(layer_cfg['dim'])), + ('layer', layer(**layer_cfg)) + ]))) + + self.unembed = nn.Sequential(OrderedDict([ + ('norm', norm(layer_cfg['dim'])), + ('lm_head', nn.Linear(dim, vocab_size)) + ])) + self.apply(self._init_weights) def embed(self,