Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3826.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a `subchunk_write_order` option to `ShardingCodec` to allow for `morton`, `unordered`, `lexicographic`, and `colexicographic` subchunk orderings.
7 changes: 7 additions & 0 deletions docs/user-guide/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ bytes within chunks of an array may improve the compression ratio, depending on
the structure of the data, the compression algorithm used, and which compression
filters (e.g., byte-shuffle) have been applied.

### Subchunk memory layout

The order of chunks **within each shard** can be changed via the `subchunk_write_order` parameter of the `ShardingCodec`. That parameter is a string which must be one of `["morton", "lexicographic", "colexicographic", "unordered"]`.

By default, [`morton`](https://en.wikipedia.org/wiki/Z-order_curve) order is used, which provides good spatial locality; however, [`lexicographic` (i.e., row-major)](https://en.wikipedia.org/wiki/Row-_and_column-major_order), for example, may be better suited to "batched" workflows where some form of sequential reading through a fixed number of outer dimensions is desired. The options are `lexicographic`, `morton`, `unordered` (i.e., random), and `colexicographic`.


### Empty chunks

It is possible to configure how Zarr handles the storage of chunks that are "empty"
Expand Down
3 changes: 2 additions & 1 deletion src/zarr/codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
Zlib,
Zstd,
)
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation, SubchunkWriteOrder
from zarr.codecs.transpose import TransposeCodec
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
from zarr.codecs.zstd import ZstdCodec
Expand All @@ -43,6 +43,7 @@
"GzipCodec",
"ShardingCodec",
"ShardingCodecIndexLocation",
"SubchunkWriteOrder",
"TransposeCodec",
"VLenBytesCodec",
"VLenUTF8Codec",
Expand Down
52 changes: 43 additions & 9 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from enum import Enum
from functools import lru_cache
from operator import itemgetter
from typing import TYPE_CHECKING, Any, NamedTuple, cast
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, cast

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -47,7 +47,6 @@
BasicIndexer,
ChunkProjection,
SelectorTuple,
_morton_order,
_morton_order_keys,
c_order_iter,
get_indexer,
Expand All @@ -59,7 +58,7 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from typing import Self
from typing import Final, Self

from zarr.core.common import JSON
from zarr.core.dtype.wrapper import TBaseDType, TBaseScalar, ZDType
Expand All @@ -78,6 +77,15 @@ class ShardingCodecIndexLocation(Enum):
end = "end"


SubchunkWriteOrder = Literal["morton", "unordered", "lexicographic", "colexicographic"]
SUBCHUNK_WRITE_ORDER: Final[tuple[str, str, str, str]] = (
Comment thread
ilan-gold marked this conversation as resolved.
"morton",
"unordered",
"lexicographic",
"colexicographic",
)


def parse_index_location(data: object) -> ShardingCodecIndexLocation:
return parse_enum(data, ShardingCodecIndexLocation)

Expand Down Expand Up @@ -305,7 +313,9 @@ class ShardingCodec(
chunk_shape: tuple[int, ...]
codecs: tuple[Codec, ...]
index_codecs: tuple[Codec, ...]
rng: np.random.Generator | None
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end
subchunk_write_order: SubchunkWriteOrder = "morton"

def __init__(
self,
Expand All @@ -314,16 +324,24 @@ def __init__(
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
subchunk_write_order: SubchunkWriteOrder = "morton",
rng: np.random.Generator | None = None,
) -> None:
chunk_shape_parsed = parse_shapelike(chunk_shape)
codecs_parsed = parse_codecs(codecs)
index_codecs_parsed = parse_codecs(index_codecs)
index_location_parsed = parse_index_location(index_location)
if subchunk_write_order not in SUBCHUNK_WRITE_ORDER:
raise ValueError(
f"Unrecognized subchunk write order: {subchunk_write_order}. Only {SUBCHUNK_WRITE_ORDER} are allowed."
)

object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
object.__setattr__(self, "codecs", codecs_parsed)
object.__setattr__(self, "index_codecs", index_codecs_parsed)
object.__setattr__(self, "index_location", index_location_parsed)
object.__setattr__(self, "subchunk_write_order", subchunk_write_order)
object.__setattr__(self, "rng", rng)

# Use instance-local lru_cache to avoid memory leaks

Expand All @@ -336,14 +354,15 @@ def __init__(

# todo: typedict return type
def __getstate__(self) -> dict[str, Any]:
return self.to_dict()
return {"rng": self.rng, **self.to_dict()}

def __setstate__(self, state: dict[str, Any]) -> None:
config = state["configuration"]
object.__setattr__(self, "chunk_shape", parse_shapelike(config["chunk_shape"]))
object.__setattr__(self, "codecs", parse_codecs(config["codecs"]))
object.__setattr__(self, "index_codecs", parse_codecs(config["index_codecs"]))
object.__setattr__(self, "index_location", parse_index_location(config["index_location"]))
object.__setattr__(self, "rng", state["rng"])

# Use instance-local lru_cache to avoid memory leaks
# object.__setattr__(self, "_get_chunk_spec", lru_cache()(self._get_chunk_spec))
Expand Down Expand Up @@ -523,6 +542,22 @@ async def _decode_partial_single(
else:
return out

def _subchunk_order_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[int, ...]]:
match self.subchunk_write_order:
case "morton":
subchunk_iter = morton_order_iter(chunks_per_shard)
case "lexicographic":
subchunk_iter = np.ndindex(chunks_per_shard)
case "colexicographic":
subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1]))
case "unordered":
subchunk_list = list(np.ndindex(chunks_per_shard))
(self.rng if self.rng is not None else np.random.default_rng()).shuffle(
subchunk_list
)
subchunk_iter = iter(subchunk_list)
return subchunk_iter

async def _encode_single(
self,
shard_array: NDBuffer,
Expand All @@ -540,8 +575,7 @@ async def _encode_single(
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
)
)

shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))
shard_builder = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard))

await self.codec_pipeline.write(
[
Expand Down Expand Up @@ -582,7 +616,7 @@ async def _encode_partial_single(
)

if self._is_complete_shard_write(indexer, chunks_per_shard):
shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard))
shard_dict = dict.fromkeys(np.ndindex(chunks_per_shard))
Copy link
Copy Markdown
Contributor Author

@ilan-gold ilan-gold Mar 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @mkitti

Here and below, I don't think there is any need to construct the dict in morton order, right? There should be no correctness or performance hit here?

@d-v-b This now ensures we only shuffle in the unordered case once so the test is nice and clean - write once + get order, create a new codec with the same seed + create the iterator from that codec, match orders

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Python, dicts are ordered and I think the optimal iteration order may need to be encoded in the dict the last time I examined the situation. I was just trying to preserve the situation before my edits.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

027c469

So this wasn't about dictionary order, but instead in the vectorized case, the order to ShardReader.to_dict_vectorized had to match that of what ShardReader was internally generating, as it turned out morton order. So I'm glad I caught this because I think it means the data was being corrupted for the other orders (which weren't getting hypothesis-tested).

So I'm going to add something to the hypothesis tests for this.

I had the same feeling initially that the dictionary order mattered, but it turns out the final call to _encode_shard_dict actually handles the ordering for us to the output buffer while writing to the intermediate shard_dict can be done in any order, as long as the final buffer is done in the correct order

else:
shard_reader = await self._load_full_shard_maybe(
byte_getter=byte_setter,
Expand All @@ -592,7 +626,7 @@ async def _encode_partial_single(
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
# Use vectorized lookup for better performance
shard_dict = shard_reader.to_dict_vectorized(
np.asarray(_morton_order(chunks_per_shard))
np.array(list(np.ndindex(chunks_per_shard)))
)

await self.codec_pipeline.write(
Expand Down Expand Up @@ -631,7 +665,7 @@ async def _encode_shard_dict(

template = buffer_prototype.buffer.create_zero_length()
chunk_start = 0
for chunk_coords in morton_order_iter(chunks_per_shard):
for chunk_coords in self._subchunk_order_iter(chunks_per_shard):
value = map.get(chunk_coords)
if value is None:
continue
Expand Down
137 changes: 135 additions & 2 deletions tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
import re
from typing import Any
from typing import Any, get_args

import numpy as np
import numpy.typing as npt
Expand All @@ -13,12 +13,15 @@
from zarr.abc.store import Store
from zarr.codecs import (
BloscCodec,
BytesCodec,
Crc32cCodec,
ShardingCodec,
ShardingCodecIndexLocation,
TransposeCodec,
)
from zarr.codecs.sharding import SubchunkWriteOrder, _ShardReader
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
from zarr.storage import StorePath, ZipStore
from zarr.storage import MemoryStore, StorePath, ZipStore

from ..conftest import ArrayRequest
from .test_codecs import _AsyncArrayProxy, order_from_dim
Expand Down Expand Up @@ -555,3 +558,133 @@ def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
s3 = sharded[0:5, 1, 0:3]
assert c3.shape == s3.shape == (5, 3) # type: ignore[union-attr]
np.testing.assert_array_equal(c3, s3)


async def stored_data_and_get_order(
    codec: ShardingCodec, chunks_per_shard: tuple[int, ...]
) -> list[tuple[int, ...]]:
    """Write one full shard through ``codec`` and return the physical order in
    which the subchunks were laid out, recovered from the shard index."""
    shard_shape = tuple(
        n_chunks * extent
        for n_chunks, extent in zip(chunks_per_shard, codec.chunk_shape, strict=True)
    )
    store = MemoryStore()
    arr = zarr.create_array(
        StorePath(store),
        shape=shard_shape,
        dtype="uint8",
        chunks=shard_shape,
        serializer=codec,
        filters=None,
        compressors=None,
        fill_value=0,
    )

    arr[:] = np.arange(np.prod(shard_shape), dtype="uint8").reshape(shard_shape)

    shard_buf = await store.get("c/0/0", prototype=default_buffer_prototype())
    if shard_buf is None:
        raise RuntimeError("data write failed")
    reader = await _ShardReader.from_bytes(shard_buf, codec, chunks_per_shard)
    all_coords = list(np.ndindex(chunks_per_shard))
    # First element of the vectorized lookup result is the start offsets.
    starts = reader.index.get_chunk_slices_vectorized(np.array(all_coords))[0]

    # The physical write order is recovered by sorting coordinates by start offset.
    return [coord for _, coord in sorted(zip(starts, all_coords, strict=True))]


@pytest.mark.parametrize(
    "subchunk_write_order",
    get_args(SubchunkWriteOrder),
)
async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteOrder) -> None:
    """Subchunks must be physically laid out in the shard in the order specified by
    ``subchunk_write_order``. We verify this by decoding the shard index and sorting
    the chunk coordinates by their byte offset."""
    # Use a non-square chunks_per_shard so all three orderings are distinguishable.
    chunks_per_shard = (3, 2)
    chunk_shape = (4, 4)
    seed = 0

    def make_codec() -> ShardingCodec:
        # Fresh codec with a fixed seed so "unordered" shuffles reproducibly.
        return ShardingCodec(
            chunk_shape=chunk_shape,
            codecs=[BytesCodec()],
            index_codecs=[BytesCodec(), Crc32cCodec()],
            index_location=ShardingCodecIndexLocation.end,
            subchunk_write_order=subchunk_write_order,
            rng=np.random.default_rng(seed=seed),
        )

    codec = make_codec()
    actual_order = await stored_data_and_get_order(codec, chunks_per_shard)
    if subchunk_write_order == "unordered":
        # The write consumed this codec's rng state, so compare against a
        # second codec seeded identically.
        assert actual_order == list(make_codec()._subchunk_order_iter(chunks_per_shard))
    else:
        # Deterministic orderings can be re-derived from the same codec.
        assert actual_order == list(codec._subchunk_order_iter(chunks_per_shard))


async def test_unordered_can_be_seeded() -> None:
    """Writing "unordered" shards with identically-seeded rngs must produce
    the same physical subchunk layout every time."""
    chunks_per_shard = (3, 2)
    chunk_shape = (4, 4)
    seed = 0
    orders: list[list[tuple[int, ...]]] = []
    for _ in range(4):
        codec = ShardingCodec(
            chunk_shape=chunk_shape,
            codecs=[BytesCodec()],
            index_codecs=[BytesCodec(), Crc32cCodec()],
            index_location=ShardingCodecIndexLocation.end,
            subchunk_write_order="unordered",
            rng=np.random.default_rng(seed=seed),
        )
        # The physical write order is recovered by sorting coordinates by start offset.
        orders.append(await stored_data_and_get_order(codec, chunks_per_shard))
    first, *rest = orders
    assert all(first == other for other in rest)


@pytest.mark.parametrize(
    "subchunk_write_order",
    get_args(SubchunkWriteOrder),
)
@pytest.mark.parametrize("do_partial", [True, False], ids=["partial", "complete"])
def test_subchunk_write_order_roundtrip(
    subchunk_write_order: SubchunkWriteOrder, do_partial: bool
) -> None:
    """Data written with any ``subchunk_write_order`` must round-trip correctly."""
    chunks_per_shard = (3, 2)
    chunk_shape = (4, 4)
    shard_shape = tuple(
        n_chunks * extent
        for n_chunks, extent in zip(chunks_per_shard, chunk_shape, strict=True)
    )
    data = np.arange(np.prod(shard_shape), dtype="uint16").reshape(shard_shape)
    serializer = ShardingCodec(
        chunk_shape=chunk_shape,
        codecs=[BytesCodec()],
        subchunk_write_order=subchunk_write_order,
    )
    arr = zarr.create_array(
        StorePath(MemoryStore()),
        shape=shard_shape,
        dtype=data.dtype,
        chunks=shard_shape,
        serializer=serializer,
        filters=None,
        compressors=None,
        fill_value=0,
    )
    if do_partial:
        # Write only the top half of the shard; the bottom half stays at fill_value.
        half = shard_shape[0] // 2
        sub_data = data[:half]
        arr[:half] = sub_data
        data = np.vstack([sub_data, np.zeros_like(sub_data)])
    else:
        arr[:] = data
    np.testing.assert_array_equal(arr[:], data)
Loading