diff --git a/roll/distributed/executor/cluster.py b/roll/distributed/executor/cluster.py index 9c18b7c0c..84e624ddc 100644 --- a/roll/distributed/executor/cluster.py +++ b/roll/distributed/executor/cluster.py @@ -140,6 +140,9 @@ def _create_workers(self): if "ROLL_LOG_DIR" in os.environ: env_vars["ROLL_LOG_DIR"] = os.environ["ROLL_LOG_DIR"] env_vars.update(self.worker_config.system_envs) + if current_platform.is_npu(): + env_vars["HCCL_HOST_SOCKET_PORT_RANGE"] = "auto" + env_vars["HCCL_NPU_SOCKET_PORT_RANGE"] = "auto" runtime_env = RuntimeEnv(env_vars=env_vars) self.worker_config.resource_placement_groups = pgs diff --git a/roll/third_party/deepspeed/model_update.py b/roll/third_party/deepspeed/model_update.py index b6452902c..53e8784ce 100644 --- a/roll/third_party/deepspeed/model_update.py +++ b/roll/third_party/deepspeed/model_update.py @@ -29,9 +29,26 @@ def _gather_weights(is_zero3, named_params): return [(n, p.data) for n, p in named_params] -def gather_deepspeed_weights(model, ds_config, buffer_size): +def gather_deepspeed_weights(model, ds_config, buffer_size, is_lora=False): is_zero3 = ds_config.is_zero3() - named_params = [(name, param) for name, param in model.named_parameters()] + if is_lora: + if not is_zero3: + from peft.utils import get_peft_model_state_dict + lora_state_dict = get_peft_model_state_dict(model) + named_params = [(name, param) for name, param in lora_state_dict.items()] + else: + adapter_name = "default" + state_dict = model.state_dict() + lora_state_dict = {k: state_dict[k] for k in state_dict if ("lora_" in k and adapter_name in k)} + named_params = [] + for name, param in lora_state_dict.items(): + clean_name = name.replace(f".{adapter_name}", "") + if clean_name.startswith("base_model.model."): + clean_name = clean_name[len("base_model.model."):] + named_params.append((clean_name, model.get_parameter(name))) + del lora_state_dict + else: + named_params = [(name, param) for name, param in model.named_parameters()] waiting_params, waiting_params_size = [], 0 for name, param in named_params: @@ -150,7 +167,7 @@ def _setup_broadcast_group(self): def _colocated_model_update(self): refs = [] for named_weights in gather_deepspeed_weights( - self.model, self.ds_config, buffer_size=self._model_update_buffer_size + self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora ): serialized_tensors = serialize_named_weights( named_weights, infer_strategy=self.infer_worker_config.strategy_args.strategy_name @@ -167,7 +184,7 @@ def _colocated_model_update(self): ray.get(refs) refs = [] if co_infer_rank == 0 and self._co_infer_worker is not None: - refs.append(self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors)) + refs.append(self._co_infer_worker.update_parameter_in_bucket.remote(infer_parallel_tensors, is_lora=self.is_lora)) if self._broadcast_workers: refs.extend(self._broadcast_to_infer_workers(named_weights)) if refs: @@ -183,6 +200,7 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]: names=[n for n, _ in named_weights], dtypes=[w.dtype for _, w in named_weights], shapes=[w.shape for _, w in named_weights], + is_lora=self.is_lora, ) for worker in self._broadcast_workers ] @@ -198,7 +216,7 @@ def _broadcast_to_infer_workers(self, named_weights) -> list[ray.ObjectRef]: def _separated_model_update(self): logger.info(f"start broadcast model update {self.model_update_group_name}") for named_weights in gather_deepspeed_weights( - self.model, self.ds_config, buffer_size=self._model_update_buffer_size + self.model, self.ds_config, buffer_size=self._model_update_buffer_size, is_lora=self.is_lora ): refs = self._broadcast_to_infer_workers(named_weights) ray.get(refs)