diff --git a/firedrake/mesh.py b/firedrake/mesh.py index eb4be68b18..cea52bbcd6 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -3772,7 +3772,7 @@ def __init__(self, coordinates): self._bounding_box_coords = None self._spatial_index = None - self._saved_coordinate_dat_version = coordinates.dat.buffer.state + self._saved_coordinate_dat_version = coordinates.dat.buffer.state.copy() self._cache = {} @@ -4009,7 +4009,7 @@ def spatial_index(self): # Build spatial index with PETSc.Log.Event("spatial_index_build"): self._spatial_index = spatialindex.from_regions(coords_min, coords_max) - self._saved_coordinate_dat_version = self.coordinates.dat.buffer.state + self._saved_coordinate_dat_version = self.coordinates.dat.buffer.state.copy() return self._spatial_index @PETSc.Log.EventDecorator() diff --git a/pyop3/__init__.py b/pyop3/__init__.py index 1cfc849fba..e53a7360e7 100644 --- a/pyop3/__init__.py +++ b/pyop3/__init__.py @@ -87,6 +87,11 @@ def _init_likwid(): exscan, AssignmentType, ) +from pyop3.device import ( # noqa: F401 + HOST_DEVICE, + CUDAGPU, + offloading +) from pyop3.sf import StarForest, single_star_sf, local_sf import pyop3.sf from pyop3.tree.index_tree.parse import as_index_forest diff --git a/pyop3/buffer.py b/pyop3/buffer.py index e03d964f32..6a89f6f8ad 100644 --- a/pyop3/buffer.py +++ b/pyop3/buffer.py @@ -20,6 +20,11 @@ from pyop3.dtypes import IntType, ScalarType, DTypeT from pyop3.sf import DistributedObject, NullStarForest, StarForest, local_sf from pyop3.utils import UniqueNameGenerator, as_tuple, deprecated, maybe_generate_name, readonly +from pyop3.device import ( + Device, + get_current_device, + on_host +) from ._buffer_cy import set_petsc_mat_diagonal @@ -56,8 +61,10 @@ def wrapper(self, *args, **kwargs): def record_modified(func): def wrapper(self, *args, **kwargs): assert not self.constant - self.inc_state() - return func(self, *args, **kwargs) + try: + return func(self, *args, **kwargs) + finally: + self.inc_state() return wrapper @@ -217,26 +224,23 @@ def handle(self, *, nest_indices: tuple[tuple[int, ...], ...] = ()) -> Any: """The underlying data structure.""" - -# NOTE: When GPU support is added, the host-device awareness and -# copies should live in this class. 
 @pyop3.record.record()
 class ArrayBuffer(AbstractArrayBuffer, ConcreteBuffer):
-    """A buffer whose underlying data structure is a numpy array."""
+    """A buffer whose underlying data structure is a lazily-evaluated NumPy/CuPy array."""
 
     # {{{ Instance attrs
 
-    _lazy_data: np.ndarray = dataclasses.field(repr=False)
+    _lazy_data: dict[Device, np.ndarray | cp.ndarray] = dataclasses.field(repr=False)
     sf: StarForest
     _name: str
     _constant: bool
     _rank_equal: bool
     _ordered: bool
+    # TODO: Connor and I both dislike defaultdict but I can't think of an alternative atm
+    _state: collections.defaultdict[Device, int]
     _max_value: np.number | None = None
-    _state: int = 0
-
     # flags for tracking parallel correctness
     _leaves_valid: bool = True
     _pending_reduction: Callable | None = None
@@ -247,8 +251,11 @@ def instruction_executor_cache_key(self, buffer_counter: Mapping[AbstractBuffer,
             type(self), self._constant, self._rank_equal, self._ordered, self.dtype, buffer_counter[self])
 
-    def __init__(self, data: np.ndarray, sf: StarForest | None = None, *, name: str|None=None,prefix:str|None=None,constant:bool=False, rank_equal: bool = False, max_value: numbers.Number | None=None, ordered:bool=False):
+    def __init__(self, data: np.ndarray | cp.ndarray | None, sf: StarForest | None = None, *, name: str | None = None, prefix: str | None = None, constant: bool = False, rank_equal: bool = False, max_value: numbers.Number | None = None, ordered: bool = False):
+        data = data.flatten()
+        curr_dev = get_current_device()
+
         if sf is None:
             sf = NullStarForest(data.size)
         name = utils.maybe_generate_name(name, prefix, self.DEFAULT_PREFIX)
@@ -258,16 +265,15 @@ def __init__(self, data: np.ndarray, sf: StarForest | None = None, *, name: str|
         if rank_equal and not constant:
             raise ValueError
 
-        if constant:
-            data.flags.writeable = False
-
-        self._lazy_data = data
         self.sf = sf
         self._name = name
         self._constant = constant
         self._rank_equal = rank_equal
         self._max_value = max_value
         self._ordered = ordered
+        self._lazy_data = {curr_dev: curr_dev.asarray(data, constant=self._constant)}
+        self._state = collections.defaultdict(lambda: -1, [(curr_dev, 0)])
+
         self.__post_init__()
 
     def __post_init__(self) -> None:
@@ -276,7 +282,7 @@ def __post_init__(self) -> None:
             assert self.constant
         if self.ordered:
-            utils.debug_assert(lambda: utils.is_sorted(self._lazy_data))
-        if self.constant:
+            utils.debug_assert(lambda: utils.is_sorted(self._data))
+        if self.constant and isinstance(self._data, np.ndarray):
             assert not self._data.flags.writeable
 
     # }}}
@@ -304,21 +310,27 @@ def size(self) -> int:
     def dtype(self) -> np.dtype:
         return self._data.dtype
 
+    @property
+    def _last_updated_device(self) -> Device:
+        return max(self.state, key=self.state.get)
+
     def inc_state(self) -> None:
-        self._state += 1
+        curr_dev = get_current_device()
+        self.state[curr_dev] = self.state.get(curr_dev, 0) + 1
 
     def duplicate(self, *, copy: bool = False) -> ArrayBuffer:
         # make sure that there are no pending transfers before we copy
         self.assemble()
         name = f"{self.name}_copy"
+        curr_dev = get_current_device()
         if copy:
-            data = self._lazy_data.copy()
+            # go via self._data so the array exists on the current device, and copy it
+            # so the duplicate does not alias the original storage
+            data = {curr_dev: self._data.copy()}
         else:
-            data = np.zeros_like(self._lazy_data)
+            data = {curr_dev: curr_dev.zeros_like(self._data)}
         return self.__record_init__(_name=name, _lazy_data=data)
 
     is_nested: ClassVar[bool] = False
 
     @property
     def handle(self) -> np.ndarray:
         return self._data
@@ -466,11 +478,16 @@ def leaves_valid(self) -> bool:
 
     @property
     def _data(self):
-        if self._lazy_data is None:
-            self._lazy_data = np.zeros(self.shape, dtype=self.dtype)
-        if self.name == "array_247_buffer":
-            breakpoint()
-        return self._lazy_data
+        curr_dev = get_current_device()
+
+        if not self._is_data_available(curr_dev) or not self._is_data_synced(curr_dev):
+            self.sync_devices(curr_dev)
+
+        # NOTE: If data is None, set to zeros?
+        # if self._lazy_data is None:
+        #     self._lazy_data = np.zeros(self.shape, dtype=self.dtype)
+
+        return self._lazy_data[curr_dev]
 
     # TODO: I think the halo bits should only be handled at the Dat level via the
     # axis tree. Here we can just consider the array.
@@ -500,6 +517,7 @@ def _reduction_ops(self):
         }
 
     @not_in_flight
+    @on_host
     def reduce_leaves_to_roots(self):
         self.reduce_leaves_to_roots_begin()
         self.reduce_leaves_to_roots_end()
@@ -528,6 +546,7 @@ def reduce_leaves_to_roots_end(self):
         self._finalizer = None
 
     @not_in_flight
+    @on_host
     def broadcast_roots_to_leaves(self):
         self.broadcast_roots_to_leaves_begin()
         self.broadcast_roots_to_leaves_end()
@@ -583,7 +602,18 @@ def localize(self) -> ArrayBuffer:
     @cached_property
     def _localized(self) -> ArrayBuffer:
         return self.__record_init__(sf=None)
+
+    def sync_devices(self, current_device: Device) -> None:
+        """Copy the most recently modified data onto ``current_device``."""
+        last_updated_device = self._last_updated_device
+
+        self._lazy_data[current_device] = current_device.asarray(self._lazy_data[last_updated_device], constant=self.constant)
+        self._state[current_device] = self._state[last_updated_device]
+
+    def _is_data_available(self, device: Device) -> bool:
+        return device in self._lazy_data
+
+    def _is_data_synced(self, device: Device) -> bool:
+        return self.state[device] == max(self.state.values())
 
 
 class MatBufferSpec(abc.ABC):
     pass
diff --git a/pyop3/device.py b/pyop3/device.py
new file mode 100644
index 0000000000..d4b78b5879
--- /dev/null
+++ b/pyop3/device.py
@@ -0,0 +1,99 @@
+# File to handle the op3.device context manager
+from abc import ABCMeta, abstractmethod
+import contextlib
+import contextvars
+import functools
+
+import numpy as np
+
+
+class Device(metaclass=ABCMeta):
+    """A compute device on which pyop3 data may live."""
+
+    name: str
+
+    @abstractmethod
+    def asarray(self, arr, *, constant=False):
+        pass
+
+    @abstractmethod
+    def zeros_like(self, arr):
+        pass
+
+    def __repr__(self):
+        return self.name
+
+    def __str__(self):
+        return self.name
+
+
+class CPU(Device):
+    name = "CPU"
+
+    def asarray(self, arr, *, constant=False):
+        # NOTE: Better logic needed if we switch from just NumPy/CuPy
+        if not isinstance(arr, np.ndarray):
+            import cupy as cp
+            output = cp.asnumpy(arr)
+        else:
+            output = np.array(arr)
+
+        if constant:
+            output.flags.writeable = False
+        return output
+
+    def zeros_like(self, arr):
+        return np.zeros_like(arr)
+
+
+class CUDAGPU(Device):
+    name = "CudaGPU"
+
+    def __init__(self):
+        super().__init__()
+
+        try:
+            import cupy as cp
+            assert cp.is_available()
+        except (ImportError, AssertionError) as exc:
+            # TODO: Raise a dedicated "no GPU available" exception
+            raise NotImplementedError("CUDA offloading requires a working CuPy installation and a visible GPU") from exc
+
+    def asarray(self, arr, *, constant=False):
+        # NOTE: CuPy arrays do not support the writeable flag, so `constant` is not enforced on device
+        import cupy as cp
+        return cp.asarray(arr)
+
+    def zeros_like(self, arr):
+        import cupy as cp
+        return cp.zeros_like(arr)
+
+
+HOST_DEVICE = CPU()
+_current_device = contextvars.ContextVar("current_device", default=HOST_DEVICE)
+
+
+@contextlib.contextmanager
+def offloading(dev: Device):
+    """Context manager making ``dev`` the current device."""
+    if not isinstance(dev, Device):
+        raise TypeError(f"Expected a Device, got {type(dev).__name__}")
+
+    token = _current_device.set(dev)
+    try:
+        yield
+    finally:
+        _current_device.reset(token)
+
+
+def on_host(func):
+    """Decorator forcing ``func`` to execute with the host as the current device."""
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        token = _current_device.set(HOST_DEVICE)
+        try:
+            return func(*args, **kwargs)
+        finally:
+            _current_device.reset(token)
+
+    return wrapper
+
+
+def get_current_device():
+    return _current_device.get()
diff --git a/pyop3/exceptions.py b/pyop3/exceptions.py
index 22fb6f8bf7..4a8cc5f1dc 100644
--- a/pyop3/exceptions.py
+++ b/pyop3/exceptions.py
@@ -24,6 +24,9 @@ class ValueMismatchException(Pyop3Exception):
 class UnhashableObjectException(Pyop3Exception, TypeError):
     pass
 
+class UnsupportedArrayException(Pyop3Exception, TypeError):
+    pass
+
 
 # {{{ caching
diff --git a/pyop3/insn/exec.py b/pyop3/insn/exec.py
index e7193dadb1..64ef1d1d93 100644
--- a/pyop3/insn/exec.py
+++ b/pyop3/insn/exec.py
@@ -515,6 +515,14 @@ def _(self, handle: int):  # assumes an address
     def _(self, handle: np.ndarray) -> int:
         return handle.ctypes.data
 
+    try:
+        import cupy as cp
+        @_as_exec_argument.register(cp.ndarray)
+        def _(self, handle: cp.ndarray) -> int:
+            return handle.data.ptr
+    except ImportError:
+        pass
+
     @_as_exec_argument.register
     def _(self, mat: PETSc.Mat) -> int:
         # Sometime the matrix is in an invalid state and we cannot return a handle.
diff --git a/pyop3/utils.py b/pyop3/utils.py
index 3fbe8778ad..b8008763af 100644
--- a/pyop3/utils.py
+++ b/pyop3/utils.py
@@ -23,9 +23,16 @@
 from pyop3.config import config
 from pyop3.constants import PYOP3_DECIDE, _nothing
 from pyop3.dtypes import DTypeT, IntType
-from pyop3.exceptions import CommMismatchException, CommNotFoundException, Pyop3Exception, UnhashableObjectException
+from pyop3.exceptions import CommMismatchException, CommNotFoundException, Pyop3Exception, UnhashableObjectException, UnsupportedArrayException
 from pyop3.mpi import collective
 
+ndarray_types = (np.ndarray,)
+try:
+    import cupy as cp
+    ndarray_types = (np.ndarray, cp.ndarray)
+except ImportError:
+    pass
+
 
 class UniqueNameGenerator(pytools.UniqueNameGenerator):
     """Class for generating unique names."""
@@ -169,7 +176,7 @@ def is_sequence(item):
 
 def flatten(iterable):
     """Recursively flatten a nested iterable."""
-    if isinstance(iterable, np.ndarray):
+    if isinstance(iterable, ndarray_types):
         return iterable.flatten()
     if not isinstance(iterable, (list, tuple)):
         return (iterable,)
@@ -330,14 +337,13 @@ def map_when(func, when_func, iterable):
         else:
             yield item
 
-
-def readonly(array):
-    """Return a readonly view of a numpy array."""
+def readonly(array: np.ndarray | cp.ndarray) -> np.ndarray | cp.ndarray:
+    """Return a readonly view of a numpy/cupy array."""
     view = array.view()
-    view.setflags(write=False)
+    if isinstance(array, np.ndarray):  # CuPy arrays do not support the writeable flag
+        view.setflags(write=False)
     return view
 
-
 def debug_assert(predicate, msg=None):
     if config.debug_checks:
         if msg:
@@ -576,8 +582,12 @@ def pretty_type(obj: Any) -> str:
 
 def safe_equals(a, b, /) -> bool:
-    if any(isinstance(x, np.ndarray) for x in [a, b]):
+    if any(isinstance(x, ndarray_types) for x in [a, b]):
         return (a == b).all()
+    if all(isinstance(x, dict) for x in [a, b]):
+        if a.keys() != b.keys():
+            return False
+        return all(safe_equals(a[k], b[k]) for k in a)
     else:
         return bool(a == b)
diff --git a/pyop3_gpu_demo.py b/pyop3_gpu_demo.py
new file mode 100644
index 0000000000..6414718467
--- /dev/null
+++ b/pyop3_gpu_demo.py
@@ -0,0 +1,80 @@
+"""
+Useful links:
+
+  * https://github.com/firedrakeproject/firedrake/blob/main/.github/workflows/core.yml#L476
+
+    How to build a GPU-enabled Firedrake.
+
+  * https://github.com/firedrakeproject/firedrake/blob/connorjward/pyop3-gpu/pyop3/device.py
+
+    An implementation of the 'device' context manager. It needs a big refactor.
+ + * https://github.com/OP2/PyOP2/pull/691/changes#diff-f8765d963b5adb1788f453e259d8cd45f29cee9670563ddb99b9fe2bba115a12 + + Using a wrapper type to track changes between host and device. In pyop3 + this would be the 'ArrayBuffer' object and link into existing + state tracking. +""" + +import numpy as np + +from firedrake import * +import pyop3 as op3 + +from pyop3.device import on_host + + +# made up API, we need some way to identify the device +host = op3.HOST_DEVICE # or similar +gpu = op3.CUDAGPU() + +mesh = UnitSquareMesh(3, 3) +V = FunctionSpace(mesh, "P", 2) + +f = Function(V).assign(10) +g = Function(V) + +assert isinstance(f.dat.data_ro, np.ndarray) +assert isinstance(g.dat.data_ro, np.ndarray) + +# state tracking checks, .buffer.state is now device-specific +assert f.dat.buffer.state[host] == 1 # modified once +assert f.dat.buffer.state[gpu] == -1 # not created +assert g.dat.buffer.state[host] == 0 # untouched +assert g.dat.buffer.state[gpu] == -1 # not created + +with op3.offloading(gpu): + # Getting the .data attribute on the GPU should give us back a GPU array type + assert not isinstance(f.dat.data_ro, np.ndarray) + assert not isinstance(g.dat.data_ro, np.ndarray) + + # Do the assignment using array operations + g.dat.assign(2*f.dat + 3, eager=True, eager_strategy="array") + + # Do the assignment using MLIR (this is a later step) + # g.dat.assign(2*f.dat + 3, eager=True, eager_strategy="compile") + k = Function(V) + k.dat.buffer.duplicate() + k.dat.buffer.duplicate(copy=True) + + k.dat.data_rw[...] = 3 + + # state tracking checks + assert f.dat.buffer.state[host] == 1 # modified once + assert f.dat.buffer.state[gpu] == 1 # matches host + assert g.dat.buffer.state[host] == 0 # untouched + assert g.dat.buffer.state[gpu] == 1 # modified once + assert k.dat.buffer.state[host] == -1 # not created + assert k.dat.buffer.state[gpu] == 1 # modified + +assert isinstance(f.dat.data_ro, np.ndarray) +assert isinstance(g.dat.data_ro, np.ndarray) +assert (g.dat.data_ro == 23).all() + +# state tracking checks +assert f.dat.buffer.state[host] == 1 # modified once +assert f.dat.buffer.state[gpu] == 1 # matches host +assert g.dat.buffer.state[host] == 1 # matches device +assert g.dat.buffer.state[gpu] == 1 # modified once +assert k.dat.buffer.state[host] == -1 # not created +assert k.dat.buffer.state[gpu] == 1 # modified once diff --git a/pyproject.toml b/pyproject.toml index 3b2a1003e4..7b4d9ea1e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "numpy", "packaging", # TODO RELEASE - # "petsc4py==3.24.5", + # "petsc4py==3.25.0", # UNDO ME "petsctools @ git+https://github.com/firedrakeproject/petsctools.git@connorjward/cpetsc", "pkgconfig", @@ -160,7 +160,7 @@ requires = [ "mpi4py; python_version < '3.13'", "numpy", # TODO RELEASE - # "petsc4py==3.24.5", + # "petsc4py==3.25.0", "petsctools", "pkgconfig", "pybind11", diff --git a/scripts/firedrake-configure b/scripts/firedrake-configure index 0c1030808d..2199dbaa87 100755 --- a/scripts/firedrake-configure +++ b/scripts/firedrake-configure @@ -39,7 +39,7 @@ ARCH_DEFAULT = FiredrakeArch.DEFAULT ARCH_COMPLEX = FiredrakeArch.COMPLEX -SUPPORTED_PETSC_VERSION = "v3.24.5" +SUPPORTED_PETSC_VERSION = "v3.25.0" def main(): diff --git a/tests/pyop3/unit/test_gpu_context.py b/tests/pyop3/unit/test_gpu_context.py new file mode 100644 index 0000000000..69ab4f9752 --- /dev/null +++ b/tests/pyop3/unit/test_gpu_context.py @@ -0,0 +1,123 @@ +import pytest +import numpy as np + +try: + import cupy as cp +except ImportError as 
err:
+    pytest.skip("CuPy is not available", allow_module_level=True)
+
+if not cp.is_available():
+    pytest.skip("No CUDA-capable GPU is available", allow_module_level=True)
+
+
+import pyop3 as op3
+from firedrake import Function, FunctionSpace, UnitSquareMesh
+
+
+HOST = op3.HOST_DEVICE
+CUDAGPU = op3.CUDAGPU()
+
+STATE_NOT_CREATED = -1
+STATE_UNTOUCHED = 0
+STATE_MODIFIED = 1
+
+@pytest.fixture()
+def mesh():
+    return UnitSquareMesh(3, 3)
+
+@pytest.fixture()
+def V(mesh):
+    return FunctionSpace(mesh, "P", 2)
+
+@pytest.fixture()
+def f(V):
+    return Function(V)
+
+@pytest.fixture()
+def g(V):
+    return Function(V)
+
+def state(func, device):
+    """Shorthand for reading the buffer state on a given device."""
+    return func.dat.buffer.state[device]
+
+class TestInitialState:
+    def test_host_data_is_numpy(self, f):
+        assert isinstance(f.dat.data_ro, np.ndarray)
+
+    def test_host_state_modified(self, f):
+        """Assigning on the host should bump the host state counter."""
+        f.dat.assign(10, eager=True, eager_strategy="array")
+        assert state(f, HOST) == STATE_MODIFIED
+
+    def test_gpu_state_not_created(self, f):
+        """The CUDAGPU buffer should not exist before any offloading."""
+        assert state(f, CUDAGPU) == STATE_NOT_CREATED
+
+# NOTE: `pytest.fixture`s are not used for the offloading GPU tests because they trigger
+# a segfault. We are unsure what causes it, so we leave the duplication in place for now.
+class TestOffloadingArrayTypes:
+    """Inside `op3.offloading`, data arrays should be GPU array types."""
+
+    def test_buffer_evaluates_cupy_on_cudagpu(self):
+        mesh = UnitSquareMesh(3, 3)
+        V = FunctionSpace(mesh, "P", 2)
+
+        f = Function(V).assign(10)
+        g = Function(V)
+        with op3.offloading(CUDAGPU):
+            assert not isinstance(f.dat.data_ro, np.ndarray)
+
+    def test_buffer_creation_on_cudagpu(self):
+        mesh = UnitSquareMesh(3, 3)
+        V = FunctionSpace(mesh, "P", 2)
+
+        f = Function(V).assign(10)
+        g = Function(V)
+        with op3.offloading(CUDAGPU):
+            k = Function(V)
+            assert not isinstance(k.dat.data_ro, np.ndarray)
+
+class TestOffloadingAssignmentState:
+
+    def test_host_state_untouched_after_gpu_assign(self):
+        """g should not have been modified on the host."""
+        mesh = UnitSquareMesh(3, 3)
+        V = FunctionSpace(mesh, "P", 2)
+
+        f = Function(V).assign(10)
+        g = Function(V)
+        with op3.offloading(CUDAGPU):
+            g.dat.assign(2 * f.dat + 3, eager=True, eager_strategy="array")
+            assert state(g, HOST) == STATE_UNTOUCHED
+
+    def test_gpu_state_modified_after_assign(self):
+        mesh = UnitSquareMesh(3, 3)
+        V = FunctionSpace(mesh, "P", 2)
+
+        f = Function(V).assign(10)
+        g = Function(V)
+        with op3.offloading(CUDAGPU):
+            g.dat.assign(2 * f.dat + 3, eager=True, eager_strategy="array")
+            assert state(g, CUDAGPU) == STATE_MODIFIED
+
+class TestOffloadingArraysUpdated:
+
+    def test_gpu_array_modified(self):
+        """Data on the GPU is updated while inside the GPU context."""
+        mesh = UnitSquareMesh(3, 3)
+        V = FunctionSpace(mesh, "P", 2)
+
+        f = Function(V).assign(10)
+        g = Function(V)
+        with op3.offloading(CUDAGPU):
+            g.dat.assign(2 * f.dat + 3, eager=True, eager_strategy="array")
+            assert (g.dat.data_ro == 23).all()
+
+    def test_cpu_array_updated_after_offloading(self):
+        """Data on the CPU is updated once back in the CPU context."""
+        mesh = UnitSquareMesh(3, 3)
+        V = FunctionSpace(mesh, "P", 2)
+
+        f = Function(V).assign(10)
+        g = Function(V)
+        with op3.offloading(CUDAGPU):
+            g.dat.assign(2 * f.dat + 3, eager=True, eager_strategy="array")
+        assert (g.dat.data_ro == 23).all()
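
The per-device bookkeeping that ArrayBuffer gains in this patch is easiest to see in isolation. The following standalone sketch is illustrative only and not part of the diff: ToyHost and ToyBuffer are made-up names, and MPI, lazy evaluation and the `constant` flag are all omitted. It mirrors the idea behind `_state`, `_last_updated_device` and `sync_devices`: keep one array per device, keep a per-device version counter defaulting to -1 ("not created"), and copy from whichever device holds the highest counter before handing data back.

# Illustrative sketch only -- a simplified mirror of the device tracking in pyop3/buffer.py.
import collections

import numpy as np


class ToyHost:
    """Stand-in for a pyop3 Device that stores data as NumPy arrays."""

    @staticmethod
    def asarray(arr):
        return np.array(arr)  # always copy so devices never share storage


class ToyBuffer:
    def __init__(self, data, device):
        self._arrays = {device: device.asarray(data)}
        # -1 means "no copy on this device yet"
        self._state = collections.defaultdict(lambda: -1, {device: 0})

    def _last_updated(self):
        # the device with the largest counter holds the freshest copy
        return max(self._state, key=self._state.get)

    def data(self, device):
        # copy from the freshest device if this device is missing or stale
        if self._state[device] < max(self._state.values()):
            src = self._last_updated()
            self._arrays[device] = device.asarray(self._arrays[src])
            self._state[device] = self._state[src]
        return self._arrays[device]

    def record_write(self, device):
        # bump the counter so every other device is now considered stale
        self._state[device] += 1


host = ToyHost()
buf = ToyBuffer([1.0, 2.0, 3.0], host)
buf.data(host)[0] = 5.0
buf.record_write(host)
assert buf.data(host)[0] == 5.0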