Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3826.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a `subchunk_write_order` option to `ShardingCodec` to allow for `morton`, `unordered`, `lexicographic`, and `colexicographic` subchunk orderings.
7 changes: 7 additions & 0 deletions docs/user-guide/performance.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ bytes within chunks of an array may improve the compression ratio, depending on
the structure of the data, the compression algorithm used, and which compression
filters (e.g., byte-shuffle) have been applied.

### Subchunk memory layout

The order of chunks **within each shard** can be changed via the `subchunk_write_order` parameter of the `ShardingCodec` using the `SubchunkWriteOrder` enum or a corresponding string.
Comment thread
ilan-gold marked this conversation as resolved.
Outdated

By default, [`morton`](https://en.wikipedia.org/wiki/Z-order_curve) order provides good spatial locality; however, [`lexicographic` (i.e., row-major)](https://en.wikipedia.org/wiki/Row-_and_column-major_order) order, for example, may be better suited to "batched" workflows where some form of sequential reading through a fixed number of outer dimensions is desired. The options are `lexicographic`, `morton`, `unordered` (i.e., random), and `colexicographic`.


### Empty chunks

It is possible to configure how Zarr handles the storage of chunks that are "empty"
Expand Down
3 changes: 2 additions & 1 deletion src/zarr/codecs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
Zlib,
Zstd,
)
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation
from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation, SubchunkWriteOrder
from zarr.codecs.transpose import TransposeCodec
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
from zarr.codecs.zstd import ZstdCodec
Expand All @@ -43,6 +43,7 @@
"GzipCodec",
"ShardingCodec",
"ShardingCodecIndexLocation",
"SubchunkWriteOrder",
"TransposeCodec",
"VLenBytesCodec",
"VLenUTF8Codec",
Expand Down
56 changes: 51 additions & 5 deletions src/zarr/codecs/sharding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import random
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from dataclasses import dataclass, replace
from enum import Enum
Expand Down Expand Up @@ -78,10 +79,27 @@ class ShardingCodecIndexLocation(Enum):
end = "end"


class SubchunkWriteOrder(Enum):
    """Physical write order of the subchunks within a shard.

    ``unordered`` is produced by applying ``random.shuffle`` to the
    lexicographic order, so it is a random permutation of the subchunks.
    """

    # Z-order / Morton curve — the default write order.
    morton = "morton"
    # Random permutation of the lexicographic order.
    unordered = "unordered"
    # Row-major (C-style) order.
    lexicographic = "lexicographic"
    # Column-major (Fortran-style) order.
    colexicographic = "colexicographic"


def parse_index_location(data: object) -> ShardingCodecIndexLocation:
    """Coerce *data* into a ``ShardingCodecIndexLocation`` member via ``parse_enum``."""
    location: ShardingCodecIndexLocation = parse_enum(data, ShardingCodecIndexLocation)
    return location


def parse_subchunk_write_order(data: object) -> SubchunkWriteOrder:
    """Coerce *data* into a ``SubchunkWriteOrder`` member via ``parse_enum``."""
    order: SubchunkWriteOrder = parse_enum(data, SubchunkWriteOrder)
    return order


@dataclass(frozen=True)
class _ShardingByteGetter(ByteGetter):
shard_dict: ShardMapping
Expand Down Expand Up @@ -306,6 +324,7 @@ class ShardingCodec(
codecs: tuple[Codec, ...]
index_codecs: tuple[Codec, ...]
index_location: ShardingCodecIndexLocation = ShardingCodecIndexLocation.end
subchunk_write_order: SubchunkWriteOrder = SubchunkWriteOrder.morton

def __init__(
self,
Expand All @@ -314,16 +333,19 @@ def __init__(
codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(),),
index_codecs: Iterable[Codec | dict[str, JSON]] = (BytesCodec(), Crc32cCodec()),
index_location: ShardingCodecIndexLocation | str = ShardingCodecIndexLocation.end,
subchunk_write_order: SubchunkWriteOrder | str = SubchunkWriteOrder.morton,
) -> None:
chunk_shape_parsed = parse_shapelike(chunk_shape)
codecs_parsed = parse_codecs(codecs)
index_codecs_parsed = parse_codecs(index_codecs)
index_location_parsed = parse_index_location(index_location)
subchunk_write_order_parsed = parse_subchunk_write_order(subchunk_write_order)

object.__setattr__(self, "chunk_shape", chunk_shape_parsed)
object.__setattr__(self, "codecs", codecs_parsed)
object.__setattr__(self, "index_codecs", index_codecs_parsed)
object.__setattr__(self, "index_location", index_location_parsed)
object.__setattr__(self, "subchunk_write_order", subchunk_write_order_parsed)

# Use instance-local lru_cache to avoid memory leaks

Expand Down Expand Up @@ -523,6 +545,31 @@ async def _decode_partial_single(
else:
return out

def _subchunk_order_iter(self, chunks_per_shard: tuple[int, ...]) -> Iterable[tuple[int, ...]]:
match self.subchunk_write_order:
case SubchunkWriteOrder.morton:
subchunk_iter = morton_order_iter(chunks_per_shard)
case SubchunkWriteOrder.lexicographic:
subchunk_iter = np.ndindex(chunks_per_shard)
case SubchunkWriteOrder.colexicographic:
subchunk_iter = (c[::-1] for c in np.ndindex(chunks_per_shard[::-1]))
case SubchunkWriteOrder.unordered:
subchunk_list = list(np.ndindex(chunks_per_shard))
random.shuffle(subchunk_list)
Comment thread
ilan-gold marked this conversation as resolved.
Outdated
subchunk_iter = iter(subchunk_list)
return subchunk_iter

def _subchunk_order_vectorized(self, chunks_per_shard: tuple[int, ...]) -> npt.NDArray[np.intp]:
    """Return the configured subchunk write order as an integer coordinate array.

    One row per subchunk, one column per shard dimension.
    """
    # Morton order has a dedicated vectorized implementation.
    if self.subchunk_write_order == SubchunkWriteOrder.morton:
        return _morton_order(chunks_per_shard)
    # Otherwise materialize the generic iterator into a 2-D array,
    # one fixed-length integer row per coordinate tuple.
    row_dtype = np.dtype((int, len(chunks_per_shard)))
    return np.fromiter(self._subchunk_order_iter(chunks_per_shard), dtype=row_dtype)

async def _encode_single(
self,
shard_array: NDBuffer,
Expand All @@ -540,8 +587,7 @@ async def _encode_single(
chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape),
)
)

shard_builder = dict.fromkeys(morton_order_iter(chunks_per_shard))
shard_builder = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard))

await self.codec_pipeline.write(
[
Expand Down Expand Up @@ -582,7 +628,7 @@ async def _encode_partial_single(
)

if self._is_complete_shard_write(indexer, chunks_per_shard):
shard_dict = dict.fromkeys(morton_order_iter(chunks_per_shard))
shard_dict = dict.fromkeys(self._subchunk_order_iter(chunks_per_shard))
else:
shard_reader = await self._load_full_shard_maybe(
byte_getter=byte_setter,
Expand All @@ -592,7 +638,7 @@ async def _encode_partial_single(
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
# Use vectorized lookup for better performance
shard_dict = shard_reader.to_dict_vectorized(
np.asarray(_morton_order(chunks_per_shard))
self._subchunk_order_vectorized(chunks_per_shard)
)

await self.codec_pipeline.write(
Expand Down Expand Up @@ -631,7 +677,7 @@ async def _encode_shard_dict(

template = buffer_prototype.buffer.create_zero_length()
chunk_start = 0
for chunk_coords in morton_order_iter(chunks_per_shard):
for chunk_coords in self._subchunk_order_iter(chunks_per_shard):
value = map.get(chunk_coords)
if value is None:
continue
Expand Down
97 changes: 96 additions & 1 deletion tests/test_codecs/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
from zarr.abc.store import Store
from zarr.codecs import (
BloscCodec,
BytesCodec,
Crc32cCodec,
ShardingCodec,
ShardingCodecIndexLocation,
TransposeCodec,
)
from zarr.codecs.sharding import SubchunkWriteOrder, _ShardReader
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
from zarr.storage import StorePath, ZipStore
from zarr.storage import MemoryStore, StorePath, ZipStore

from ..conftest import ArrayRequest
from .test_codecs import _AsyncArrayProxy, order_from_dim
Expand Down Expand Up @@ -555,3 +558,95 @@ def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
s3 = sharded[0:5, 1, 0:3]
assert c3.shape == s3.shape == (5, 3) # type: ignore[union-attr]
np.testing.assert_array_equal(c3, s3)


@pytest.mark.parametrize(
    "subchunk_write_order",
    list(SubchunkWriteOrder),
)
async def test_encoded_subchunk_write_order(subchunk_write_order: SubchunkWriteOrder) -> None:
    """Subchunks must be physically laid out in the shard in the order specified by
    ``subchunk_write_order``. We verify this by decoding the shard index and sorting
    the chunk coordinates by their byte offset."""
    # Use a non-square chunks_per_shard so all three orderings are distinguishable.
    chunks_per_shard = (3, 2)
    chunk_shape = (4, 4)
    shard_shape = tuple(c * s for c, s in zip(chunks_per_shard, chunk_shape, strict=True))

    codec = ShardingCodec(
        chunk_shape=chunk_shape,
        codecs=[BytesCodec()],
        index_codecs=[BytesCodec(), Crc32cCodec()],
        index_location=ShardingCodecIndexLocation.end,
        subchunk_write_order=subchunk_write_order,
    )
    store = MemoryStore()
    arr = zarr.create_array(
        StorePath(store),
        shape=shard_shape,
        dtype="uint8",
        chunks=shard_shape,
        serializer=codec,
        filters=None,
        compressors=None,
        fill_value=0,
    )

    arr[:] = np.arange(np.prod(shard_shape), dtype="uint8").reshape(shard_shape)

    shard_buf = await store.get("c/0/0", prototype=default_buffer_prototype())
    if shard_buf is None:
        raise RuntimeError("data write failed")
    index = (await _ShardReader.from_bytes(shard_buf, codec, chunks_per_shard)).index
    offset_to_coord: dict[int, tuple[int, ...]] = dict(
        zip(
            index.get_chunk_slices_vectorized(np.array(list(np.ndindex(chunks_per_shard))))[
                0
            ],  # start
            list(np.ndindex(chunks_per_shard)),  # coord
            strict=True,
        )
    )

    # The physical write order is recovered by sorting coordinates by start offset.
    actual_order = [coord for _, coord in sorted(offset_to_coord.items())]
    expected_order = list(codec._subchunk_order_iter(chunks_per_shard))
    if subchunk_write_order == SubchunkWriteOrder.unordered:
        # ``unordered`` shuffles randomly, so comparing against a second,
        # independent shuffle is flaky (two shuffles of 6 coordinates coincide
        # with probability 1/720) and "differs from a fresh shuffle" is not a
        # guaranteed property anyway. Instead require that the physical layout
        # is a permutation containing every chunk coordinate exactly once.
        assert sorted(actual_order) == sorted(expected_order)
    else:
        assert actual_order == expected_order


@pytest.mark.parametrize(
    "subchunk_write_order",
    list(SubchunkWriteOrder),
)
@pytest.mark.parametrize("do_partial", [True, False], ids=["partial", "complete"])
def test_subchunk_write_order_roundtrip(
    subchunk_write_order: SubchunkWriteOrder, do_partial: bool
) -> None:
    """Data written with any ``subchunk_write_order`` must round-trip correctly."""
    chunks_per_shard = (3, 2)
    chunk_shape = (4, 4)
    shard_shape = tuple(
        n_chunks * extent
        for n_chunks, extent in zip(chunks_per_shard, chunk_shape, strict=True)
    )
    data = np.arange(np.prod(shard_shape), dtype="uint16").reshape(shard_shape)
    arr = zarr.create_array(
        StorePath(MemoryStore()),
        shape=shard_shape,
        dtype=data.dtype,
        chunks=shard_shape,
        serializer=ShardingCodec(
            chunk_shape=chunk_shape,
            codecs=[BytesCodec()],
            subchunk_write_order=subchunk_write_order,
        ),
        filters=None,
        compressors=None,
        fill_value=0,
    )
    if do_partial:
        # Write only the top half of the shard; the bottom half stays at the
        # fill value (0), so the expected array is the written half over zeros.
        half = shard_shape[0] // 2
        written = data[:half]
        arr[:half] = written
        data = np.vstack([written, np.zeros_like(written)])
    else:
        arr[:] = data
    np.testing.assert_array_equal(arr[:], data)