Skip to content
Draft
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8ceb830
introduce device.py and set up branch
Apr 27, 2026
ba6de9b
const parameter for host device
Apr 28, 2026
6a9fd21
include gpu demo and update petsc version as 3.24.5 misaligned for PC…
Apr 28, 2026
fb3d0b7
noting areas for change
Apr 28, 2026
e39f215
introducing context variable and lazy cupy
Apr 28, 2026
2a54e84
lazy evaluation of arrays
Apr 29, 2026
98d6010
passes basic script functionality
Apr 29, 2026
b044c7a
tofix: revised approach ensuring explicit choice of GPU device
Apr 29, 2026
1de6320
implicit transfer and defaultdict implementation (pub/sub eager copy …
Apr 30, 2026
6a26334
explicit check if GPU available on init
Apr 30, 2026
8d1f967
cudagpu and fix incoming re: remove eager copying/register & dev syncing
Apr 30, 2026
9a729e9
move conversion logic to device.py
May 1, 2026
21bac65
managing buffer duplicate
May 1, 2026
67fe76a
cleanup unnecessary todos/notes
May 1, 2026
1568315
removing notes and cleaning
May 4, 2026
2517994
fix: added copy to avoid weak reference
May 4, 2026
9e7c6ad
test: data_wo access works in context
May 4, 2026
cb5f28e
add flatten from prev logic
May 4, 2026
36d2b07
context function as global function and def state
May 4, 2026
e5a0107
fix: maintaining constant array property
May 4, 2026
39d0acd
pr review: removing unused variables
May 4, 2026
e48d698
remove dispatch to allow no-import cupy
May 4, 2026
7d582d6
pr: fix property, duplicate, init
May 4, 2026
0528e76
fix: change petsc config version to v3.25.0
May 4, 2026
e07e6c6
include cupy callable pointer
May 6, 2026
0aea6d0
maintain CPU SF comms and cond exec with cupy
May 6, 2026
6f64d83
basic GPU unit tests covering context
May 6, 2026
c8db3e2
test cases for gpu - no fixtures due to bug
May 6, 2026
7aa8425
fixed bug by amending record_modified state
May 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -3772,7 +3772,7 @@ def __init__(self, coordinates):

self._bounding_box_coords = None
self._spatial_index = None
self._saved_coordinate_dat_version = coordinates.dat.buffer.state
self._saved_coordinate_dat_version = coordinates.dat.buffer.state.copy()

self._cache = {}

Expand Down Expand Up @@ -4009,7 +4009,7 @@ def spatial_index(self):
# Build spatial index
with PETSc.Log.Event("spatial_index_build"):
self._spatial_index = spatialindex.from_regions(coords_min, coords_max)
self._saved_coordinate_dat_version = self.coordinates.dat.buffer.state
self._saved_coordinate_dat_version = self.coordinates.dat.buffer.state.copy()
return self._spatial_index

@PETSc.Log.EventDecorator()
Expand Down
5 changes: 5 additions & 0 deletions pyop3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def _init_likwid():
exscan,
AssignmentType,
)
from pyop3.device import ( # noqa: F401
HOST_DEVICE,
CUDAGPU,
offloading
)
from pyop3.sf import StarForest, single_star_sf, local_sf
import pyop3.sf
from pyop3.tree.index_tree.parse import as_index_forest
Expand Down
76 changes: 61 additions & 15 deletions pyop3/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from pyop3.dtypes import IntType, ScalarType, DTypeT
from pyop3.sf import DistributedObject, NullStarForest, StarForest, local_sf
from pyop3.utils import UniqueNameGenerator, as_tuple, deprecated, maybe_generate_name, readonly
from pyop3.device import (
Device,
CUDAGPU,
CPU,
HOST_DEVICE,
_current_device
)

from ._buffer_cy import set_petsc_mat_diagonal

Expand Down Expand Up @@ -222,20 +229,23 @@ def handle(self, *, nest_indices: tuple[tuple[int, ...], ...] = ()) -> Any:
# copies should live in this class.
@pyop3.record.record()
class ArrayBuffer(AbstractArrayBuffer, ConcreteBuffer):
"""A buffer whose underlying data structure is a numpy array."""
"""A buffer whose underlying data structure is a lazily-evaluated NumPy/CuPy array."""

# {{{ Instance attrs

_lazy_data: np.ndarray = dataclasses.field(repr=False)
_lazy_data: dict[Device, np.ndarray | cp.ndarray] = dataclasses.field(repr=False)
sf: StarForest
_name: str
_constant: bool
_rank_equal: bool
_ordered: bool

# TODO: Connor and I both dislike defaultdict but I can't think of an alternative atm
_state: collections.defaultdict[Device, int]
_last_updated_device: Device
Comment thread
SamSJackson marked this conversation as resolved.
Outdated

_max_value: np.number | None = None

_state: int = 0

# flags for tracking parallel correctness
_leaves_valid: bool = True
Expand All @@ -247,8 +257,8 @@ def instruction_executor_cache_key(self, buffer_counter: Mapping[AbstractBuffer,
type(self), self._constant, self._rank_equal, self._ordered,
self.dtype, buffer_counter[self])

def __init__(self, data: np.ndarray, sf: StarForest | None = None, *, name: str|None=None,prefix:str|None=None,constant:bool=False, rank_equal: bool = False, max_value: numbers.Number | None=None, ordered:bool=False):
data = data.flatten()
def __init__(self, data: np.ndarray | cp.ndarray, sf: StarForest | None = None, *, name: str|None=None,prefix:str|None=None,constant:bool=False, rank_equal: bool = False, max_value: numbers.Number | None=None, ordered:bool=False):
Comment thread
connorjward marked this conversation as resolved.
Outdated

if sf is None:
sf = NullStarForest(data.size)
name = utils.maybe_generate_name(name, prefix, self.DEFAULT_PREFIX)
Expand All @@ -258,16 +268,25 @@ def __init__(self, data: np.ndarray, sf: StarForest | None = None, *, name: str|
if rank_equal and not constant:
raise ValueError

if constant:
data.flags.writeable = False
Comment thread
SamSJackson marked this conversation as resolved.
data = data.flatten()
ctx = self.get_context()
data_mapping = {ctx: ctx.asarray(data)}

self._lazy_data = data
self._lazy_data = data_mapping
self.sf = sf
self._name = name
self._constant = constant
self._rank_equal = rank_equal
self._max_value = max_value
self._ordered = ordered
self._last_updated_device = ctx

self._state = collections.defaultdict(int, [(ctx, 0)])
Comment thread
SamSJackson marked this conversation as resolved.
Outdated

# TODO: CuPy has no support for `writeable` flag
if constant and isinstance(self._data, np.ndarray):
self._data.flags.writeable = False

self.__post_init__()

def __post_init__(self) -> None:
Expand All @@ -276,7 +295,7 @@ def __post_init__(self) -> None:
assert self.constant
if self.ordered:
utils.debug_assert(lambda: utils.is_sorted(self._lazy_data))
if self.constant:
if self.constant and isinstance(self._data, np.ndarray):
assert not self._data.flags.writeable

# }}}
Expand Down Expand Up @@ -305,19 +324,24 @@ def dtype(self) -> np.dtype:
return self._data.dtype

def inc_state(self) -> None:
self._state += 1
ctx = self.get_context()
self._state[ctx] += 1
self._last_updated_device = ctx

def duplicate(self, *, copy: bool = False) -> ArrayBuffer:
# make sure that there are no pending transfers before we copy
self.assemble()
name = f"{self.name}_copy"
if copy:
data = self._lazy_data.copy()
data = {obj: arr.copy() for obj, arr in self._lazy_data.items()}
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
else:
data = np.zeros_like(self._lazy_data)
data = {obj: obj.zeros_like(arr) for obj, arr in self._lazy_data.items()}
Comment thread
connorjward marked this conversation as resolved.
Outdated
return self.__record_init__(_name=name, _lazy_data=data)

is_nested: ClassVar[bool] = False

def get_context(self) -> Device:
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
return _current_device.get()

@property
def handle(self) -> np.ndarray:
Expand Down Expand Up @@ -466,11 +490,18 @@ def leaves_valid(self) -> bool:

@property
def _data(self):
if self._lazy_data is None:
self._lazy_data = np.zeros(self.shape, dtype=self.dtype)
ctx = self.get_context()

if not self._is_data_available(ctx) or not self._is_data_synced(ctx):
self.sync_devices(ctx)

# NOTE: If data is None, set to zeros?
# if self._lazy_data is None:
# self._lazy_data = np.zeros(self.shape, dtype=self.dtype)

if self.name == "array_247_buffer":
breakpoint()
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
return self._lazy_data
return self._lazy_data[ctx]

# TODO: I think the halo bits should only be handled at the Dat level via the
# axis tree. Here we can just consider the array.
Expand Down Expand Up @@ -583,7 +614,22 @@ def localize(self) -> ArrayBuffer:
@cached_property
def _localized(self) -> ArrayBuffer:
return self.__record_init__(sf=None)

def sync_devices(self, current_device: Device):
last_updated_device = self._last_updated_device

self._lazy_data[current_device] = current_device.asarray(self._lazy_data[last_updated_device])
self._state[current_device] = self._state[last_updated_device]

# NOTE: Current fix for CuPy having no `writeable support` or maintaining flags
if self.constant and isinstance(self._lazy_data[current_device], np.ndarray):
self._lazy_data[current_device].flags.writeable = False

def _is_data_available(self, device: Device) -> bool:
return device in self._lazy_data

def _is_data_synced(self, device: Device) -> bool:
return self.state[device] == max(self.state.values())

class MatBufferSpec(abc.ABC):
pass
Expand Down
97 changes: 97 additions & 0 deletions pyop3/device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# File to handle op3.device context manager
from abc import ABCMeta, abstractmethod
import contextlib
import contextvars
import warnings

import numpy as np

class Device(metaclass=ABCMeta):
_name: str
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
_device_index: int | None
Comment thread
SamSJackson marked this conversation as resolved.
Outdated

def __init__(self, device_index: int | None = None):
pass

@property
def name(self):
return self._name

@property
def device_index(self):
return self._device_index

@abstractmethod
def asarray(self, arr):
pass

@abstractmethod
def zeros_like(self, arr):
pass

def __repr__(self):
return self._name

def __str__(self):
return self._name

class CPU(Device):

def __init__(self, device_index: int | None = None):
super().__init__()
self._name = "cpu"
self._registered_arrays = set()
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
self._device_index = device_index

def asarray(self, arr):
# NOTE: Better logic needed if we switch from just NumPy/CuPy
if not isinstance(arr, np.ndarray):
import cupy as cp
return cp.asnumpy(arr)

return np.array(arr)

def zeros_like(self, arr):
return np.zeros_like(arr)

class CUDAGPU(Device):

def __init__(self, device_index: int | None = None):
super().__init__()
self._name = "CudaGPU"
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
self._registered_arrays = set()
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
self._token = None
self._device_index = device_index

try:
import cupy as cp
assert cp.is_available()
except:
# TODO: Raise No GPU exception
raise NotImplementedError

def asarray(self, arr):
import cupy as cp
return cp.asarray(arr)

def zeros_like(self, arr):
import cupy as cp
return cp.zeros_like(arr)

@contextlib.contextmanager
def offloading(dev: Device):
# TODO: Not Device exception
if not isinstance(dev, Device):
raise NotImplementedError

token = _current_device.set(dev)
try:
yield
finally:
_current_device.reset(token)

# NOTE: Should this const variable be here?
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
HOST_DEVICE = CPU()

# NOTE: Use contextvars to act as a bridge between buffer and manager
_current_device = contextvars.ContextVar("current_device", default=HOST_DEVICE)
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
3 changes: 3 additions & 0 deletions pyop3/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class ValueMismatchException(Pyop3Exception):
class UnhashableObjectException(Pyop3Exception, TypeError):
pass

class UnsupportedArrayException(Pyop3Exception, TypeError):
pass


# {{{ caching

Expand Down
29 changes: 25 additions & 4 deletions pyop3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@
from pyop3.config import config
from pyop3.constants import PYOP3_DECIDE, _nothing
from pyop3.dtypes import DTypeT, IntType
from pyop3.exceptions import CommMismatchException, CommNotFoundException, Pyop3Exception, UnhashableObjectException
from pyop3.exceptions import CommMismatchException, CommNotFoundException, Pyop3Exception, UnhashableObjectException, UnsupportedArrayException
from pyop3.mpi import collective

ndarray_types = [np.ndarray,]
try:
import cupy as cp
ndarray_types = [np.ndarray, cp.ndarray]
except ImportError:
pass


class UniqueNameGenerator(pytools.UniqueNameGenerator):
"""Class for generating unique names."""
Expand Down Expand Up @@ -169,7 +176,7 @@ def is_sequence(item):

def flatten(iterable):
"""Recursively flatten a nested iterable."""
if isinstance(iterable, np.ndarray):
if isinstance(iterable, tuple(ndarray_types)):
return iterable.flatten()
if not isinstance(iterable, (list, tuple)):
return (iterable,)
Expand Down Expand Up @@ -331,12 +338,22 @@ def map_when(func, when_func, iterable):
yield item


def readonly(array):
@functools.singledispatch
def readonly(array: Any) -> Any:
raise UnsupportedArrayException

@readonly.register
def _(array: np.ndarray) -> np.ndarray:
"""Return a readonly view of a numpy array."""
view = array.view()
view.setflags(write=False)
return view

@readonly.register
def _(array: cp.ndarray) -> cp.ndarray:
Comment thread
SamSJackson marked this conversation as resolved.
Outdated
""" Return a view of a CuPy array."""
view = array.view()
return view

def debug_assert(predicate, msg=None):
if config.debug_checks:
Expand Down Expand Up @@ -576,8 +593,12 @@ def pretty_type(obj: Any) -> str:


def safe_equals(a, b, /) -> bool:
if any(isinstance(x, np.ndarray) for x in [a, b]):
if any(isinstance(x, tuple(ndarray_types)) for x in [a, b]):
return (a == b).all()
if any(isinstance(x, dict) for x in [a, b]):
if a.keys() != b.keys():
return False
return all(safe_equals(a[k], b[k]) for k in a)
else:
return bool(a == b)

Expand Down
Loading
Loading