From 205cd615bca7ea39e624a16be3a467727bfc350a Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 6 Feb 2026 03:37:50 +0800 Subject: [PATCH 1/3] feature(pu): adapt to npu --- ding/policy/base_policy.py | 44 +++-- ding/torch_utils/__init__.py | 2 + ding/torch_utils/device_helper.py | 183 ++++++++++++++++++ .../cartpole/config/cartpole_ppo_config.py | 2 +- 4 files changed, 219 insertions(+), 12 deletions(-) create mode 100644 ding/torch_utils/device_helper.py diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 1c5f32d1db..34218422f3 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -9,6 +9,7 @@ from ding.model import create_model from ding.utils import import_module, allreduce, allreduce_with_indicator, broadcast, get_rank, allreduce_async, \ synchronize, deep_merge_dicts, POLICY_REGISTRY +from ding.torch_utils import auto_device_init, move_to_device class Policy(ABC): @@ -83,8 +84,12 @@ def default_config(cls: type) -> EasyDict: config = dict( # (bool) Whether the learning policy is the same as the collecting data policy (on-policy). on_policy=False, - # (bool) Whether to use cuda in policy. + # (bool) Whether to use cuda in policy (deprecated, use 'device' instead). cuda=False, + # (str) Device to use for policy. Can be 'auto', 'cuda', 'npu', or 'cpu'. + # 'auto' will automatically detect NPU > GPU > CPU. + # If not specified, will use 'cuda' config for backward compatibility. + device='auto', # (bool) Whether to use data parallel multi-gpu mode in policy. multi_gpu=False, # (bool) Whether to synchronize update the model parameters after allreduce the gradients of model parameters. 
@@ -136,25 +141,42 @@ def __init__( if len(set(self._enable_field).intersection(set(['learn', 'collect', 'eval']))) > 0: model = self._create_model(cfg, model) - self._cuda = cfg.cuda and torch.cuda.is_available() + + # Device initialization with auto-detection support for NPU/GPU/CPU + # Backward compatibility: if 'device' not in cfg, use 'cuda' config + if hasattr(cfg, 'device') and cfg.device is not None: + # New way: use 'device' config for auto-detection or explicit setting + cfg_device = cfg.device + else: + # Legacy way: convert 'cuda' boolean to device string + cfg_device = 'cuda' if (hasattr(cfg, 'cuda') and cfg.cuda) else 'cpu' + # now only support multi-gpu for only enable learn mode if len(set(self._enable_field).intersection(set(['learn']))) > 0: multi_gpu = self._cfg.multi_gpu self._rank = get_rank() if multi_gpu else 0 - if self._cuda: - # model.cuda() is an in-place operation. - model.cuda() + else: + self._rank = 0 + + # Auto-detect or set device + self._device_type, self._use_accelerator, self._device = auto_device_init(cfg_device, self._rank) + + # Keep backward compatibility with _cuda attribute + self._cuda = self._use_accelerator and self._device_type == 'cuda' + + # Move model to the detected/configured device + if self._use_accelerator: + move_to_device(model, self._device_type, self._rank) + + # Multi-GPU initialization + if len(set(self._enable_field).intersection(set(['learn']))) > 0: + multi_gpu = self._cfg.multi_gpu if multi_gpu: bp_update_sync = self._cfg.bp_update_sync self._bp_update_sync = bp_update_sync self._init_multi_gpu_setting(model, bp_update_sync) - else: - self._rank = 0 - if self._cuda: - # model.cuda() is an in-place operation. 
- model.cuda() + self._model = model - self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu' else: self._cuda = False self._rank = 0 diff --git a/ding/torch_utils/__init__.py b/ding/torch_utils/__init__.py index 151b4da7e1..71adf29fe5 100755 --- a/ding/torch_utils/__init__.py +++ b/ding/torch_utils/__init__.py @@ -12,3 +12,5 @@ from .dataparallel import DataParallel from .reshape_helper import fold_batch, unfold_batch, unsqueeze_repeat from .parameter import NonegativeParameter, TanhParameter +from .device_helper import get_available_device, get_device_count, move_to_device, get_device_string, \ + auto_device_init, is_npu_available, is_cuda_available diff --git a/ding/torch_utils/device_helper.py b/ding/torch_utils/device_helper.py new file mode 100644 index 0000000000..0f21b656c4 --- /dev/null +++ b/ding/torch_utils/device_helper.py @@ -0,0 +1,183 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved. + +Device helper utilities for automatic detection of NPU and GPU devices. +Supports Huawei Ascend NPU (torch_npu) and NVIDIA GPU (torch.cuda). +""" + +import torch +from typing import Tuple, Optional +import logging + +# Try to import torch_npu for Huawei NPU support +try: + import torch_npu + TORCH_NPU_AVAILABLE = True +except ImportError: + TORCH_NPU_AVAILABLE = False + +logger = logging.getLogger(__name__) + + +def get_available_device() -> Tuple[str, bool]: + """ + Overview: + Automatically detect the available device (NPU or GPU or CPU). 
+ Priority: NPU > GPU > CPU + Returns: + - device_type (:obj:`str`): Device type string, one of 'npu', 'cuda', 'cpu' + - is_accelerator (:obj:`bool`): Whether an accelerator (NPU/GPU) is available + Examples: + >>> device_type, is_accelerator = get_available_device() + >>> print(f"Using device: {device_type}") + """ + # Check for NPU first (Huawei Ascend) + if TORCH_NPU_AVAILABLE and torch.npu.is_available(): + npu_count = torch.npu.device_count() + logger.info(f"Detected {npu_count} NPU device(s), using NPU") + return 'npu', True + + # Check for CUDA GPU + if torch.cuda.is_available(): + gpu_count = torch.cuda.device_count() + logger.info(f"Detected {gpu_count} CUDA GPU device(s), using GPU") + return 'cuda', True + + # Fallback to CPU + logger.info("No NPU or GPU detected, using CPU") + return 'cpu', False + + +def get_device_count(device_type: str) -> int: + """ + Overview: + Get the number of available devices for the specified device type. + Arguments: + - device_type (:obj:`str`): Device type, one of 'npu', 'cuda', 'cpu' + Returns: + - count (:obj:`int`): Number of available devices + """ + if device_type == 'npu' and TORCH_NPU_AVAILABLE: + return torch.npu.device_count() + elif device_type == 'cuda': + return torch.cuda.device_count() + else: + return 1 # CPU always has 1 "device" + + +def move_to_device(model: torch.nn.Module, device_type: str, rank: int = 0) -> torch.nn.Module: + """ + Overview: + Move a PyTorch model to the specified device. + Supports NPU, CUDA, and CPU devices. 
+ Arguments: + - model (:obj:`torch.nn.Module`): The model to move + - device_type (:obj:`str`): Device type, one of 'npu', 'cuda', 'cpu' + - rank (:obj:`int`): Device rank for multi-device setups + Returns: + - model (:obj:`torch.nn.Module`): The model moved to the device (in-place operation) + """ + if device_type == 'npu' and TORCH_NPU_AVAILABLE: + device_count = torch.npu.device_count() + device_id = rank % device_count if device_count > 0 else 0 + model.npu(device_id) + logger.debug(f"Moved model to NPU device {device_id}") + elif device_type == 'cuda': + device_count = torch.cuda.device_count() + device_id = rank % device_count if device_count > 0 else 0 + model.cuda(device_id) + logger.debug(f"Moved model to CUDA device {device_id}") + # CPU case: no need to move + return model + + +def get_device_string(device_type: str, rank: int = 0) -> str: + """ + Overview: + Get the device string for PyTorch tensor operations. + Arguments: + - device_type (:obj:`str`): Device type, one of 'npu', 'cuda', 'cpu' + - rank (:obj:`int`): Device rank for multi-device setups + Returns: + - device_str (:obj:`str`): Device string like 'npu:0', 'cuda:0', or 'cpu' + """ + if device_type in ['npu', 'cuda']: + device_count = get_device_count(device_type) + device_id = rank % device_count if device_count > 0 else 0 + return f'{device_type}:{device_id}' + else: + return 'cpu' + + +def auto_device_init(cfg_device: Optional[str], rank: int = 0) -> Tuple[str, bool, str]: + """ + Overview: + Initialize device settings based on config. + Supports automatic detection, explicit device type, or legacy 'cuda' boolean. + Arguments: + - cfg_device (:obj:`Optional[str]`): Device configuration from config. 
+ Can be 'auto', 'npu', 'cuda', 'cpu', or None (defaults to 'auto') + - rank (:obj:`int`): Device rank for multi-device setups + Returns: + - device_type (:obj:`str`): Detected device type ('npu', 'cuda', or 'cpu') + - use_accelerator (:obj:`bool`): Whether an accelerator is being used + - device_str (:obj:`str`): Full device string for PyTorch operations + Examples: + >>> device_type, use_accelerator, device_str = auto_device_init('auto') + >>> # Returns ('npu', True, 'npu:0') if NPU available + >>> # Returns ('cuda', True, 'cuda:0') if GPU available + >>> # Returns ('cpu', False, 'cpu') otherwise + """ + # Default to auto detection if not specified + if cfg_device is None or cfg_device == 'auto': + device_type, use_accelerator = get_available_device() + else: + # Explicit device type specified + device_type = cfg_device.lower() + + # Validate the device type is available + if device_type == 'npu': + if TORCH_NPU_AVAILABLE and torch.npu.is_available(): + use_accelerator = True + logger.info("Using NPU as explicitly configured") + else: + logger.warning("NPU requested but not available, falling back to CPU") + device_type = 'cpu' + use_accelerator = False + elif device_type == 'cuda': + if torch.cuda.is_available(): + use_accelerator = True + logger.info("Using CUDA GPU as explicitly configured") + else: + logger.warning("CUDA requested but not available, falling back to CPU") + device_type = 'cpu' + use_accelerator = False + else: + # CPU or any other value + device_type = 'cpu' + use_accelerator = False + logger.info("Using CPU as configured") + + device_str = get_device_string(device_type, rank) + + return device_type, use_accelerator, device_str + + +def is_npu_available() -> bool: + """ + Overview: + Check if Huawei NPU is available. + Returns: + - available (:obj:`bool`): True if NPU is available + """ + return TORCH_NPU_AVAILABLE and torch.npu.is_available() + + +def is_cuda_available() -> bool: + """ + Overview: + Check if NVIDIA CUDA GPU is available. 
+ Returns: + - available (:obj:`bool`): True if CUDA is available + """ + return torch.cuda.is_available() diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py index 4c1333bbc8..21385d2f26 100644 --- a/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py +++ b/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py @@ -9,7 +9,7 @@ stop_value=195, ), policy=dict( - cuda=False, + device='auto', # Auto-detect NPU > GPU > CPU action_space='discrete', model=dict( obs_shape=4, From fbab0b917b28587f145f3d286df1659a37b2f81e Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 6 Feb 2026 04:00:25 +0800 Subject: [PATCH 2/3] feature(pu): adapt to npu --- ding/policy/base_policy.py | 3 ++- ding/utils/default_helper.py | 20 +++++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 34218422f3..85c124baf7 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -162,7 +162,8 @@ def __init__( self._device_type, self._use_accelerator, self._device = auto_device_init(cfg_device, self._rank) # Keep backward compatibility with _cuda attribute - self._cuda = self._use_accelerator and self._device_type == 'cuda' + # Set _cuda=True for ANY accelerator (GPU or NPU) to ensure data transfer logic works + self._cuda = self._use_accelerator # Move model to the detected/configured device if self._use_accelerator: diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py index 1881ca6cc0..bf52f71feb 100644 --- a/ding/utils/default_helper.py +++ b/ding/utils/default_helper.py @@ -7,6 +7,13 @@ import torch import treetensor.torch as ttorch +# Try to import torch_npu for Huawei NPU support +try: + import torch_npu + TORCH_NPU_AVAILABLE = True +except ImportError: + TORCH_NPU_AVAILABLE = False + def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int: """ @@ 
-418,7 +425,7 @@ def set_pkg_seed(seed: int, use_cuda: bool = True) -> None: This is usaually used in entry scipt in the section of setting random seed for all package and instance Argument: - seed(:obj:`int`): Set seed - - use_cuda(:obj:`bool`) Whether use cude + - use_cuda(:obj:`bool`) Whether use cuda or other accelerators (NPU/GPU) Examples: >>> # ../entry/xxxenv_xxxpolicy_main.py >>> ... @@ -434,8 +441,15 @@ def set_pkg_seed(seed: int, use_cuda: bool = True) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - if use_cuda and torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + + # Set seed for accelerators (GPU or NPU) + if use_cuda: + # Set CUDA seed if available + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + # Set NPU seed if available + if TORCH_NPU_AVAILABLE and torch.npu.is_available(): + torch.npu.manual_seed(seed) @lru_cache() From 46ca005877f1ec971ddc00e2a95ba6d5096ee133 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 6 Feb 2026 04:06:41 +0800 Subject: [PATCH 3/3] polish(pu): add device logs --- ding/policy/base_policy.py | 13 ++++++ ding/torch_utils/device_helper.py | 67 +++++++++++++++++++++++++------ ding/utils/default_helper.py | 5 +++ 3 files changed, 72 insertions(+), 13 deletions(-) diff --git a/ding/policy/base_policy.py b/ding/policy/base_policy.py index 85c124baf7..5b4f26c007 100644 --- a/ding/policy/base_policy.py +++ b/ding/policy/base_policy.py @@ -169,6 +169,19 @@ def __init__( if self._use_accelerator: move_to_device(model, self._device_type, self._rank) + # Print final device configuration summary + print(f"\n{'='*70}") + print(f"🎉 [DI-engine Policy] Device Setup Complete") + print(f"{'='*70}") + print(f" Policy Type: {self.__class__.__name__}") + print(f" Device Type: {self._device_type.upper()}") + print(f" Device String: {self._device}") + print(f" Using Accelerator: {self._use_accelerator}") + print(f" Rank: {self._rank}") + print(f" Multi-GPU: {self._cfg.multi_gpu if 
hasattr(self._cfg, 'multi_gpu') else False}") + print(f" Legacy _cuda flag: {self._cuda}") + print(f"{'='*70}\n") + # Multi-GPU initialization if len(set(self._enable_field).intersection(set(['learn']))) > 0: multi_gpu = self._cfg.multi_gpu diff --git a/ding/torch_utils/device_helper.py b/ding/torch_utils/device_helper.py index 0f21b656c4..f124987ed2 100644 --- a/ding/torch_utils/device_helper.py +++ b/ding/torch_utils/device_helper.py @@ -31,20 +31,42 @@ def get_available_device() -> Tuple[str, bool]: >>> device_type, is_accelerator = get_available_device() >>> print(f"Using device: {device_type}") """ + print("\n" + "="*70) + print("🔍 [DI-engine] Device Detection") + print("="*70) + # Check for NPU first (Huawei Ascend) - if TORCH_NPU_AVAILABLE and torch.npu.is_available(): - npu_count = torch.npu.device_count() - logger.info(f"Detected {npu_count} NPU device(s), using NPU") - return 'npu', True + if TORCH_NPU_AVAILABLE: + print("✓ torch_npu module is installed") + if torch.npu.is_available(): + npu_count = torch.npu.device_count() + print(f"✓ NPU is available: {npu_count} device(s) detected") + print(f"✓ NPU device names: {[torch.npu.get_device_name(i) for i in range(npu_count)]}") + print(f"🎯 Selected device: NPU") + print("="*70 + "\n") + logger.info(f"[Device] Using NPU with {npu_count} device(s)") + return 'npu', True + else: + print("✗ NPU is not available") + else: + print("✗ torch_npu module is not installed") # Check for CUDA GPU if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() - logger.info(f"Detected {gpu_count} CUDA GPU device(s), using GPU") + print(f"✓ CUDA is available: {gpu_count} device(s) detected") + print(f"✓ GPU device names: {[torch.cuda.get_device_name(i) for i in range(gpu_count)]}") + print(f"🎯 Selected device: CUDA GPU") + print("="*70 + "\n") + logger.info(f"[Device] Using CUDA GPU with {gpu_count} device(s)") return 'cuda', True + else: + print("✗ CUDA is not available") # Fallback to CPU - 
logger.info("No NPU or GPU detected, using CPU") + print("🎯 Selected device: CPU (no accelerator detected)") + print("="*70 + "\n") + logger.info("[Device] Using CPU (no accelerator available)") return 'cpu', False @@ -80,13 +102,18 @@ def move_to_device(model: torch.nn.Module, device_type: str, rank: int = 0) -> t if device_type == 'npu' and TORCH_NPU_AVAILABLE: device_count = torch.npu.device_count() device_id = rank % device_count if device_count > 0 else 0 + print(f"📦 [DI-engine] Moving model to NPU device {device_id} (rank={rank})") model.npu(device_id) - logger.debug(f"Moved model to NPU device {device_id}") + logger.info(f"[Device] Model moved to NPU device {device_id}") elif device_type == 'cuda': device_count = torch.cuda.device_count() device_id = rank % device_count if device_count > 0 else 0 + print(f"📦 [DI-engine] Moving model to CUDA device {device_id} (rank={rank})") model.cuda(device_id) - logger.debug(f"Moved model to CUDA device {device_id}") + logger.info(f"[Device] Model moved to CUDA device {device_id}") + else: + print(f"📦 [DI-engine] Model will stay on CPU") + logger.info("[Device] Model stays on CPU") # CPU case: no need to move return model @@ -128,38 +155,52 @@ def auto_device_init(cfg_device: Optional[str], rank: int = 0) -> Tuple[str, boo >>> # Returns ('cuda', True, 'cuda:0') if GPU available >>> # Returns ('cpu', False, 'cpu') otherwise """ + print(f"\n⚙️ [DI-engine] Device Configuration: cfg_device='{cfg_device}', rank={rank}") + # Default to auto detection if not specified if cfg_device is None or cfg_device == 'auto': + print(f"🔧 [DI-engine] Using auto-detection mode") device_type, use_accelerator = get_available_device() else: # Explicit device type specified device_type = cfg_device.lower() + print(f"🔧 [DI-engine] Explicit device type requested: '{device_type}'") # Validate the device type is available if device_type == 'npu': if TORCH_NPU_AVAILABLE and torch.npu.is_available(): use_accelerator = True - 
logger.info("Using NPU as explicitly configured") + npu_count = torch.npu.device_count() + print(f"✓ NPU requested and available: {npu_count} device(s)") + logger.info(f"[Device] Using NPU as explicitly configured ({npu_count} device(s))") else: - logger.warning("NPU requested but not available, falling back to CPU") + print(f"⚠️ NPU requested but not available, falling back to CPU") + logger.warning("[Device] NPU requested but not available, falling back to CPU") device_type = 'cpu' use_accelerator = False elif device_type == 'cuda': if torch.cuda.is_available(): use_accelerator = True - logger.info("Using CUDA GPU as explicitly configured") + gpu_count = torch.cuda.device_count() + print(f"✓ CUDA requested and available: {gpu_count} device(s)") + logger.info(f"[Device] Using CUDA GPU as explicitly configured ({gpu_count} device(s))") else: - logger.warning("CUDA requested but not available, falling back to CPU") + print(f"⚠️ CUDA requested but not available, falling back to CPU") + logger.warning("[Device] CUDA requested but not available, falling back to CPU") device_type = 'cpu' use_accelerator = False else: # CPU or any other value device_type = 'cpu' use_accelerator = False - logger.info("Using CPU as configured") + print(f"✓ Using CPU as configured") + logger.info("[Device] Using CPU as configured") device_str = get_device_string(device_type, rank) + print(f"✅ [DI-engine] Device initialized: type={device_type}, accelerator={use_accelerator}, device_string='{device_str}'") + print("="*70 + "\n") + return device_type, use_accelerator, device_str diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py index bf52f71feb..4fed0cf60f 100644 --- a/ding/utils/default_helper.py +++ b/ding/utils/default_helper.py @@ -438,18 +438,23 @@ def set_pkg_seed(seed: int, use_cuda: bool = True) -> None: >>> ... 
""" + print(f"\n🌱 [DI-engine] Setting random seed: {seed}") random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + print(f" ✓ Set seed for: random, numpy, torch") # Set seed for accelerators (GPU or NPU) if use_cuda: # Set CUDA seed if available if torch.cuda.is_available(): torch.cuda.manual_seed(seed) + print(f" ✓ Set CUDA seed: {seed}") # Set NPU seed if available if TORCH_NPU_AVAILABLE and torch.npu.is_available(): torch.npu.manual_seed(seed) + print(f" ✓ Set NPU seed: {seed}") + print() @lru_cache()