From ce302866656593bdc497831f998853db1cff70f5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 20 Mar 2026 11:16:07 +0100 Subject: [PATCH] fix: bring back zarr --- Cargo.toml | 2 + anndata-hdf5/Cargo.toml | 2 +- anndata-test-utils/Cargo.toml | 1 + anndata-test-utils/tests/tests.rs | 16 +++- anndata-zarr/Cargo.toml | 12 ++- anndata-zarr/src/lib.rs | 118 +++++++++++++++++++----------- 6 files changed, 102 insertions(+), 49 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 71ce30e..b596d06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "anndata", "anndata-hdf5", + "anndata-zarr", "pyanndata", "anndata-test-utils", "python", @@ -11,4 +12,5 @@ resolver = "2" [workspace.dependencies] anndata = { path = "anndata" } anndata-hdf5 = { path = "anndata-hdf5" } +anndata-zarr = { path = "anndata-zarr" } pyanndata = { path = "pyanndata" } \ No newline at end of file diff --git a/anndata-hdf5/Cargo.toml b/anndata-hdf5/Cargo.toml index 8f3f2e0..e044377 100644 --- a/anndata-hdf5/Cargo.toml +++ b/anndata-hdf5/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "anndata-hdf5" -version = "0.5.2" +version = "0.5.1" edition = "2021" rust-version = "1.65" authors = ["Kai Zhang "] diff --git a/anndata-test-utils/Cargo.toml b/anndata-test-utils/Cargo.toml index fea6361..69ebdf7 100644 --- a/anndata-test-utils/Cargo.toml +++ b/anndata-test-utils/Cargo.toml @@ -18,6 +18,7 @@ itertools = "0.14" [dev-dependencies] anndata-hdf5 = { workspace = true } +anndata-zarr = { workspace = true } tempfile = "3.2" proptest = "1" bincode = { version = "2", features = ["serde"] } diff --git a/anndata-test-utils/tests/tests.rs b/anndata-test-utils/tests/tests.rs index ec15444..4b861b9 100644 --- a/anndata-test-utils/tests/tests.rs +++ b/anndata-test-utils/tests/tests.rs @@ -1,11 +1,13 @@ use anndata_test_utils as utils; use anndata_test_utils::with_tmp_dir; use anndata_hdf5::H5; +use anndata_zarr::Zarr; use anndata::{AnnData, Backend}; #[test] fn test_basic() { utils::test_basic::
(); + utils::test_basic::(); } #[test] @@ -15,6 +17,9 @@ fn test_complex_dataframe() { let file = dir.join("test.h5"); let adata = AnnData::
::open(H5::open(&input).unwrap()).unwrap(); adata.write::(file, None, None).unwrap(); + + let file = dir.join("test.zarr"); + adata.write::(file, None, None).unwrap(); }) } @@ -31,11 +36,10 @@ fn test_speacial_cases() { let adata_gen = || AnnData::
::new(&file).unwrap(); utils::test_speacial_cases(|| adata_gen()); - /* let file = dir.join("test.zarr"); let adata_gen = || AnnData::::new(&file).unwrap(); utils::test_speacial_cases(|| adata_gen()); - */ + }) } @@ -45,6 +49,10 @@ fn test_noncanonical() { let file = dir.join("test.h5"); let adata_gen = || AnnData::
::new(&file).unwrap(); utils::test_noncanonical(|| adata_gen()); + + let file = dir.join("test.zarr"); + let adata_gen = || AnnData::::new(&file).unwrap(); + utils::test_noncanonical(|| adata_gen()); }) } @@ -80,5 +88,9 @@ fn test_iterator() { let file = dir.join("test.h5"); let adata_gen = || AnnData::
::new(&file).unwrap(); utils::test_iterator(|| adata_gen()); + + let file = dir.join("test.zarr"); + let adata_gen = || AnnData::::new(&file).unwrap(); + utils::test_iterator(|| adata_gen()); }) } \ No newline at end of file diff --git a/anndata-zarr/Cargo.toml b/anndata-zarr/Cargo.toml index ce2fc6f..5c6f1e1 100644 --- a/anndata-zarr/Cargo.toml +++ b/anndata-zarr/Cargo.toml @@ -14,6 +14,12 @@ homepage = "https://github.com/kaizhang/anndata-rs" anndata = { workspace = true } serde_json = "1.0" anyhow = "1.0" -ndarray = { version = "0.16", features = ["serde"] } -zarrs = "0.21" -smallvec = "1.15" \ No newline at end of file +ndarray = { version = "0.17", features = ["serde"] } +zarrs = "0.23" +smallvec = "1.15" + +[dev-dependencies] +tempfile = "3.2" +proptest = "1" +rand = "0.9" +ndarray-rand = "0.16" \ No newline at end of file diff --git a/anndata-zarr/src/lib.rs b/anndata-zarr/src/lib.rs index f7e1086..c2e5f68 100644 --- a/anndata-zarr/src/lib.rs +++ b/anndata-zarr/src/lib.rs @@ -7,17 +7,27 @@ use anyhow::{bail, Context, Result}; use ndarray::{Array, ArrayD, ArrayView, CowArray, Dimension, IxDyn, SliceInfoElem}; use std::{ borrow::Cow, + num::NonZeroU64, ops::{Deref, Index}, path::{Path, PathBuf}, }; use std::{sync::Arc, vec}; -use zarrs::array::codec::bytes_to_bytes::zstd::ZstdCodec; +use zarrs::array::{ + codec::bytes_to_bytes::zstd::ZstdCodec, + data_type::{ + BoolDataType, Float32DataType, Float64DataType, Int16DataType, Int32DataType, + Int64DataType, Int8DataType, StringDataType, UInt16DataType, UInt32DataType, + UInt64DataType, UInt8DataType, + }, +}; use zarrs::filesystem::FilesystemStore; use zarrs::group::Group; use zarrs::{array::ElementOwned, storage::ReadableWritableListableStorageTraits}; use zarrs::{ - array::{codec::ShardingCodecBuilder, data_type::DataType, ArrayShardedReadableExt, Element}, - array_subset::ArraySubset, + array::{ + codec::ShardingCodecBuilder, data_type, ArrayShardedReadableExt, ArraySubset, Element, + FillValue, + }, storage::StorePrefix, }; @@ -344,20 +354,32 @@ impl AttributeOp for ZarrDataset { impl DatasetOp for ZarrDataset { fn dtype(&self) -> Result { - match self.dataset.data_type() { - DataType::UInt8 => Ok(ScalarType::U8), - DataType::UInt16 => Ok(ScalarType::U16), - DataType::UInt32 => Ok(ScalarType::U32), - DataType::UInt64 => Ok(ScalarType::U64), - DataType::Int8 => Ok(ScalarType::I8), - DataType::Int16 => Ok(ScalarType::I16), - DataType::Int32 => Ok(ScalarType::I32), - DataType::Int64 => Ok(ScalarType::I64), - DataType::Float32 => Ok(ScalarType::F32), - DataType::Float64 => Ok(ScalarType::F64), - DataType::Bool => Ok(ScalarType::Bool), - DataType::String => Ok(ScalarType::String), - ty => bail!("Unsupported type: {:?}", ty), + if self.dataset.data_type().is::() { + Ok(ScalarType::U8) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::U16) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::U32) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::U64) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::I8) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::I16) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::I32) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::I64) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::F32) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::F64) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::Bool) + } else if self.dataset.data_type().is::() { + Ok(ScalarType::String) + } else { + bail!("Unsupported type: {:?}", self.dataset.data_type()) } } @@ -371,7 +393,7 @@ impl DatasetOp for ZarrDataset { fn reshape(&mut self, shape: &Shape) -> Result<()> { self.dataset - .set_shape(shape.as_ref().iter().map(|x| *x as u64).collect()); + .set_shape(shape.as_ref().iter().map(|x| *x as u64).collect())?; self.dataset.store_metadata()?; Ok(()) } @@ -395,7 +417,7 @@ impl DatasetOp for ZarrDataset { .retrieve_array_subset_ndarray_sharded_opt( &dataset.cache, &subset, - &zarrs::array::codec::CodecOptions::default(), + &zarrs::array::CodecOptions::default(), )? .into_dimensionality::()?; Ok(arr) @@ -406,7 +428,7 @@ impl DatasetOp for ZarrDataset { .retrieve_array_subset_ndarray_sharded_opt( &dataset.cache, &dataset.dataset.subset_all(), - &zarrs::array::codec::CodecOptions::default(), + &zarrs::array::CodecOptions::default(), )? .into_dimensionality::()?; Ok(select(arr.view(), selection)) @@ -463,7 +485,7 @@ impl DatasetOp for ZarrDataset { if starts.len() == selection.ndim() { container .dataset - .store_array_subset_ndarray(starts.as_slice(), arr.into_owned())?; + .store_array_subset_ndarray(starts.as_slice(), &arr)?; } else { panic!("Not implemented"); } @@ -567,23 +589,27 @@ fn new_empty_dataset_helper( config: WriteConfig, ) -> Result> { let (datatype, fill) = match T::DTYPE { - ScalarType::U8 => (DataType::UInt8, 0u8.into()), - ScalarType::U16 => (DataType::UInt16, 0u16.into()), - ScalarType::U32 => (DataType::UInt32, 0u32.into()), - ScalarType::U64 => (DataType::UInt64, 0u64.into()), - ScalarType::I8 => (DataType::Int8, 0i8.into()), - ScalarType::I16 => (DataType::Int16, 0i16.into()), - ScalarType::I32 => (DataType::Int32, 0i32.into()), - ScalarType::I64 => (DataType::Int64, 0i64.into()), - ScalarType::F32 => (DataType::Float32, zarrs::array::ZARR_NAN_F32.into()), - ScalarType::F64 => (DataType::Float64, zarrs::array::ZARR_NAN_F64.into()), - ScalarType::Bool => (DataType::Bool, false.into()), - ScalarType::String => (DataType::String, "".into()), + ScalarType::U8 => (data_type::uint8(), FillValue::from(0u8)), + ScalarType::U16 => (data_type::uint16(), FillValue::from(0u16)), + ScalarType::U32 => (data_type::uint32(), FillValue::from(0u32)), + ScalarType::U64 => (data_type::uint64(), FillValue::from(0u64)), + ScalarType::I8 => (data_type::int8(), FillValue::from(0i8)), + ScalarType::I16 => (data_type::int16(), FillValue::from(0i16)), + ScalarType::I32 => (data_type::int32(), FillValue::from(0i32)), + ScalarType::I64 => (data_type::int64(), FillValue::from(0i64)), + ScalarType::F32 => (data_type::float32(), FillValue::from(0f32)), + ScalarType::F64 => (data_type::float64(), FillValue::from(0f64)), + ScalarType::Bool => (data_type::bool(), FillValue::from(false)), + ScalarType::String => (data_type::string(), FillValue::from("")), }; let shape = shape.as_ref(); let chunk_size: Vec = match config.block_size { - Some(s) => s.as_ref().into_iter().map(|x| (*x).max(1) as u64).collect(), + Some(s) => s + .as_ref() + .into_iter() + .map(|x| (*x).max(1) as u64) + .collect::>(), _ => { if shape.len() == 1 { vec![shape[0].min(16384).max(1) as u64] @@ -594,29 +620,35 @@ fn new_empty_dataset_helper( }; let mut use_sharding = true; - if matches!(datatype, DataType::String) {//|| shape.iter().sum::() == 0 { + if datatype == data_type::string() { + //|| shape.iter().sum::() == 0 { // Strings are not sharded, they are stored as a single chunk. use_sharding = false; } let array = if use_sharding { let shard_shape = chunk_size.iter().map(|&x| x * 8).collect::>(); - let mut sharding_codec_builder = - ShardingCodecBuilder::new(chunk_size.try_into()?); + let mut sharding_codec_builder = ShardingCodecBuilder::new( + chunk_size + .iter() + .map(|e| NonZeroU64::try_from(*e)) + .collect::, _>>()?, + &datatype, + ); sharding_codec_builder.bytes_to_bytes_codecs(vec![Arc::new(ZstdCodec::new(7, false))]); zarrs::array::ArrayBuilder::new( - shape.iter().map(|x| *x as u64).collect(), + shape.iter().map(|x| *x as u64).collect::>(), + shard_shape.as_slice(), datatype, - shard_shape.try_into()?, fill, ) .array_to_bytes_codec(sharding_codec_builder.build_arc()) .build(store, path)? } else { zarrs::array::ArrayBuilder::new( - shape.iter().map(|x| *x as u64).collect(), + shape.iter().map(|x| *x as u64).collect::>(), + chunk_size.as_slice(), datatype, - chunk_size.try_into()?, fill, ) .bytes_to_bytes_codecs(vec![Arc::new(ZstdCodec::new(7, false))]) @@ -710,7 +742,7 @@ mod tests { let mut dataset = group.new_empty_dataset::("test", &[20, 50].as_slice().into(), config)?; - let arr = Array::random((10, 10), Uniform::new(0, 100)); + let arr = Array::random((10, 10), Uniform::new(0, 100).unwrap()); dataset.write_array_slice(arr.view().into(), s![5..15, 10..20].as_ref())?; assert_eq!( arr, @@ -718,7 +750,7 @@ mod tests { ); // Repeatitive writes - let arr = Array::random((20, 50), Uniform::new(0, 100)); + let arr = Array::random((20, 50), Uniform::new(0, 100).unwrap()); dataset.write_array_slice(arr.view().into(), s![.., ..].as_ref())?; dataset.write_array_slice(arr.view().into(), s![.., ..].as_ref())?;