Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/matchbox/client/locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def execute(
rename: dict[str, str] | Callable | None = None,
return_type: QueryReturnType = QueryReturnType.POLARS,
keys: tuple[str, list[str]] | None = None,
limit: int | None = None,
) -> Iterator[QueryReturnClass]:
"""Execute ET logic against location and return batches.

Expand All @@ -168,6 +169,7 @@ def execute(
keys: Rule to only retrieve rows by specific keys.
The key of the dictionary is a field name on which to filter.
Filters source entries where the key field is in the dict values.
limit: Maximum number of rows to return. If None, returns all rows.

Raises:
            AttributeError: If the client is not set.
Expand Down Expand Up @@ -328,6 +330,7 @@ def execute(
return_type: Literal[QueryReturnType.POLARS] = ...,
keys: tuple[str, list[str]] | None = None,
schema_overrides: dict[str, pl.DataType] | None = None,
limit: int | None = None,
) -> Generator[PolarsDataFrame, None, None]: ...

@overload
Expand All @@ -339,6 +342,7 @@ def execute(
return_type: Literal[QueryReturnType.PANDAS] = ...,
keys: tuple[str, list[str]] | None = None,
schema_overrides: dict[str, pl.DataType] | None = None,
limit: int | None = None,
) -> Generator[PandasDataFrame, None, None]: ...

@overload
Expand All @@ -350,6 +354,7 @@ def execute(
return_type: Literal[QueryReturnType.ARROW] = ...,
keys: tuple[str, list[str]] | None = None,
schema_overrides: dict[str, pl.DataType] | None = None,
limit: int | None = None,
) -> Generator[ArrowTable, None, None]: ...

@requires_client
Expand All @@ -361,6 +366,7 @@ def execute( # noqa: D102
return_type: QueryReturnType = QueryReturnType.POLARS,
keys: tuple[str, list[str]] | None = None,
schema_overrides: dict[str, pl.DataType] | None = None,
limit: int | None = None,
) -> Generator[QueryReturnClass, None, None]:
# Strip semicolon, as it can block extended query protocol
# and slow some engines down
Expand All @@ -382,6 +388,11 @@ def execute( # noqa: D102
f"select * from ({extract_transform}) as sub "
f"where {key_field} in ({comma_separated_values})"
)
# Add LIMIT clause if limit is specified
if limit is not None:
extract_transform = (
f"select * from ({extract_transform}) as sub_limit limit {limit}"
)
yield from sql_to_df(
stmt=extract_transform,
schema_overrides=schema_overrides,
Expand Down
29 changes: 26 additions & 3 deletions src/matchbox/client/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def fetch(
batch_size: int | None = None,
return_type: Literal[QueryReturnType.POLARS] = ...,
keys: list[str] | None = None,
limit: int | None = None,
) -> Generator[PolarsDataFrame, None, None]: ...

@overload
Expand All @@ -248,6 +249,7 @@ def fetch(
batch_size: int | None = None,
return_type: Literal[QueryReturnType.PANDAS] = ...,
keys: list[str] | None = None,
limit: int | None = None,
) -> Generator[PandasDataFrame, None, None]: ...

@overload
Expand All @@ -257,6 +259,7 @@ def fetch(
batch_size: int | None = None,
return_type: Literal[QueryReturnType.ARROW] = ...,
keys: list[str] | None = None,
limit: int | None = None,
) -> Generator[ArrowTable, None, None]: ...

def fetch(
Expand All @@ -265,6 +268,7 @@ def fetch(
batch_size: int | None = None,
return_type: QueryReturnType = QueryReturnType.POLARS,
keys: list[str] | None = None,
limit: int | None = None,
) -> Generator[QueryReturnClass, None, None]:
"""Applies the extract/transform logic to the source and returns batches lazily.

Expand All @@ -274,6 +278,7 @@ def fetch(
batch_size: Indicate the size of each batch when fetching data in batches.
return_type: The type of data to return. Defaults to "polars".
keys: List of keys to select a subset of all source entries.
limit: Maximum number of rows to return. If None, returns all rows.

Returns:
The requested data in the specified format, as an iterator of tables.
Expand All @@ -295,6 +300,7 @@ def _rename(c: str) -> str:
batch_size=batch_size,
return_type=return_type,
keys=(self.config.key_field.name, keys),
limit=limit,
)
else:
yield from self.location.execute(
Expand All @@ -303,13 +309,30 @@ def _rename(c: str) -> str:
rename=_rename,
batch_size=batch_size,
return_type=return_type,
limit=limit,
)

def sample(
self, n: int = 100, return_type: QueryReturnType = QueryReturnType.POLARS
) -> None:
"""Peek at the top n entries in a source."""
return next(self.fetch(batch_size=n, return_type=return_type))
) -> QueryReturnClass:
"""Peek at the top n entries in a source.

Args:
n: The number of rows to return. Must be a positive integer.
return_type: The type of data to return. Defaults to "polars".

Returns:
A dataframe containing the first n rows of the source.

Raises:
TypeError: If n is not an integer.
ValueError: If n is not a positive integer.
"""
if not isinstance(n, int):
raise TypeError(f"n must be an integer, got {type(n).__name__}")
if n <= 0:
raise ValueError(f"n must be a positive integer, got {n}")
return next(self.fetch(limit=n, return_type=return_type))

@profile_time(attr="name")
def run(self, batch_size: int | None = None) -> ArrowTable:
Expand Down
4 changes: 4 additions & 0 deletions test/client/test_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ def test_relational_db_execute(
unfiltered_results = pl.concat(location.execute(sql, batch_size, keys=("key", [])))
assert_frame_equal(unfiltered_results, combined_df)

# Try using limit to restrict results
limited_results = pl.concat(location.execute(sql, batch_size, limit=3))
assert len(limited_results) == 3


@pytest.mark.parametrize(
"warehouse",
Expand Down
45 changes: 45 additions & 0 deletions test/client/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,51 @@ def test_source_fetch_and_sample(sqla_sqlite_warehouse: Engine) -> None:
assert len(result) == 5


def test_source_sample_validation(sqla_sqlite_warehouse: Engine) -> None:
"""Test that source.sample validates input correctly."""
# Create test data
source_testkit = source_factory(
n_true_entities=5,
features=[
{"name": "name", "base_generator": "word", "datatype": DataTypes.STRING},
],
engine=sqla_sqlite_warehouse,
).write_to_location()

# Create location and source
location = RelationalDBLocation(name="dbname").set_client(sqla_sqlite_warehouse)
source = Source(
dag=source_testkit.source.dag,
location=location,
name="test_source",
extract_transform=source_testkit.source_config.extract_transform,
infer_types=True,
key_field="key",
index_fields=["name"],
)

# Test that string n raises TypeError
with pytest.raises(TypeError, match="n must be an integer"):
source.sample(n="asdf")

# Test that negative n raises ValueError
with pytest.raises(ValueError, match="n must be a positive integer"):
source.sample(n=-1)

# Test that zero n raises ValueError
with pytest.raises(ValueError, match="n must be a positive integer"):
source.sample(n=0)

# Test that float n raises TypeError
with pytest.raises(TypeError, match="n must be an integer"):
source.sample(n=10.5)

# Test that valid positive integer works
result = source.sample(n=3)
assert isinstance(result, pl.DataFrame)
assert len(result) == 3


@pytest.mark.parametrize(
"qualify_names",
[
Expand Down
Loading