diff --git a/src/matchbox/client/locations.py b/src/matchbox/client/locations.py
index dcd28931..11b3c9b2 100644
--- a/src/matchbox/client/locations.py
+++ b/src/matchbox/client/locations.py
@@ -152,6 +152,7 @@ def execute(
         rename: dict[str, str] | Callable | None = None,
         return_type: QueryReturnType = QueryReturnType.POLARS,
         keys: tuple[str, list[str]] | None = None,
+        limit: int | None = None,
     ) -> Iterator[QueryReturnClass]:
         """Execute ET logic against location and return batches.
 
@@ -168,6 +169,7 @@ def execute(
             keys: Rule to only retrieve rows by specific keys.
                 The key of the dictionary is a field name on which to filter.
                 Filters source entries where the key field is in the dict values.
+            limit: Maximum number of rows to return. If None, returns all rows.
 
         Raises:
             AttributeError: If the cliet is not set.
@@ -328,6 +330,7 @@ def execute(
         return_type: Literal[QueryReturnType.POLARS] = ...,
         keys: tuple[str, list[str]] | None = None,
         schema_overrides: dict[str, pl.DataType] | None = None,
+        limit: int | None = None,
     ) -> Generator[PolarsDataFrame, None, None]: ...
 
     @overload
@@ -339,6 +342,7 @@ def execute(
         return_type: Literal[QueryReturnType.PANDAS] = ...,
         keys: tuple[str, list[str]] | None = None,
         schema_overrides: dict[str, pl.DataType] | None = None,
+        limit: int | None = None,
     ) -> Generator[PandasDataFrame, None, None]: ...
 
     @overload
@@ -350,6 +354,7 @@ def execute(
         return_type: Literal[QueryReturnType.ARROW] = ...,
         keys: tuple[str, list[str]] | None = None,
         schema_overrides: dict[str, pl.DataType] | None = None,
+        limit: int | None = None,
     ) -> Generator[ArrowTable, None, None]: ...
 
     @requires_client
@@ -361,6 +366,7 @@ def execute(  # noqa: D102
         return_type: QueryReturnType = QueryReturnType.POLARS,
         keys: tuple[str, list[str]] | None = None,
         schema_overrides: dict[str, pl.DataType] | None = None,
+        limit: int | None = None,
     ) -> Generator[QueryReturnClass, None, None]:
         # Strip semicolon, as it can block extended query protocol
         # and slow some engines down
@@ -382,6 +388,20 @@ def execute(  # noqa: D102
                     f"select * from ({extract_transform}) as sub "
                     f"where {key_field} in ({comma_separated_values})"
                 )
+        # Add LIMIT clause if limit is specified. The value is interpolated
+        # into the SQL text, so reject anything but a non-negative integer.
+        if limit is not None:
+            if isinstance(limit, bool) or not isinstance(limit, int):
+                raise TypeError(
+                    f"limit must be an integer, got {type(limit).__name__}"
+                )
+            if limit < 0:
+                raise ValueError(
+                    f"limit must be a non-negative integer, got {limit}"
+                )
+            extract_transform = (
+                f"select * from ({extract_transform}) as sub_limit limit {limit}"
+            )
         yield from sql_to_df(
             stmt=extract_transform,
             schema_overrides=schema_overrides,
diff --git a/src/matchbox/client/sources.py b/src/matchbox/client/sources.py
index 4b4656c0..4a0dc68c 100644
--- a/src/matchbox/client/sources.py
+++ b/src/matchbox/client/sources.py
@@ -239,6 +239,7 @@ def fetch(
         batch_size: int | None = None,
         return_type: Literal[QueryReturnType.POLARS] = ...,
         keys: list[str] | None = None,
+        limit: int | None = None,
     ) -> Generator[PolarsDataFrame, None, None]: ...
 
     @overload
@@ -248,6 +249,7 @@ def fetch(
         batch_size: int | None = None,
         return_type: Literal[QueryReturnType.PANDAS] = ...,
         keys: list[str] | None = None,
+        limit: int | None = None,
     ) -> Generator[PandasDataFrame, None, None]: ...
 
     @overload
@@ -257,6 +259,7 @@ def fetch(
         batch_size: int | None = None,
         return_type: Literal[QueryReturnType.ARROW] = ...,
         keys: list[str] | None = None,
+        limit: int | None = None,
     ) -> Generator[ArrowTable, None, None]: ...
 
     def fetch(
@@ -265,6 +268,7 @@ def fetch(
         batch_size: int | None = None,
         return_type: QueryReturnType = QueryReturnType.POLARS,
         keys: list[str] | None = None,
+        limit: int | None = None,
     ) -> Generator[QueryReturnClass, None, None]:
         """Applies the extract/transform logic to the source and returns batches lazily.
 
@@ -274,6 +278,7 @@ def fetch(
             batch_size: Indicate the size of each batch when fetching data in batches.
             return_type: The type of data to return. Defaults to "polars".
             keys: List of keys to select a subset of all source entries.
+            limit: Maximum number of rows to return. If None, returns all rows.
 
         Returns:
             The requested data in the specified format, as an iterator of tables.
@@ -295,6 +300,7 @@ def _rename(c: str) -> str:
                 batch_size=batch_size,
                 return_type=return_type,
                 keys=(self.config.key_field.name, keys),
+                limit=limit,
             )
         else:
             yield from self.location.execute(
@@ -303,13 +309,30 @@ def _rename(c: str) -> str:
                 rename=_rename,
                 batch_size=batch_size,
                 return_type=return_type,
+                limit=limit,
             )
 
     def sample(
         self, n: int = 100, return_type: QueryReturnType = QueryReturnType.POLARS
-    ) -> None:
-        """Peek at the top n entries in a source."""
-        return next(self.fetch(batch_size=n, return_type=return_type))
+    ) -> QueryReturnClass:
+        """Peek at the top n entries in a source.
+
+        Args:
+            n: The number of rows to return. Must be a positive integer.
+            return_type: The type of data to return. Defaults to "polars".
+
+        Returns:
+            A dataframe containing up to n rows of the source.
+
+        Raises:
+            TypeError: If n is not an integer (booleans are rejected too).
+            ValueError: If n is not a positive integer.
+        """
+        if isinstance(n, bool) or not isinstance(n, int):
+            raise TypeError(f"n must be an integer, got {type(n).__name__}")
+        if n <= 0:
+            raise ValueError(f"n must be a positive integer, got {n}")
+        return next(self.fetch(limit=n, return_type=return_type))
 
     @profile_time(attr="name")
     def run(self, batch_size: int | None = None) -> ArrowTable:
diff --git a/test/client/test_locations.py b/test/client/test_locations.py
index a2e191a6..07e57c12 100644
--- a/test/client/test_locations.py
+++ b/test/client/test_locations.py
@@ -303,6 +303,10 @@ def test_relational_db_execute(
     unfiltered_results = pl.concat(location.execute(sql, batch_size, keys=("key", [])))
     assert_frame_equal(unfiltered_results, combined_df)
 
+    # Try using limit to restrict results
+    limited_results = pl.concat(location.execute(sql, batch_size, limit=3))
+    assert len(limited_results) == 3
+
 
 @pytest.mark.parametrize(
     "warehouse",
diff --git a/test/client/test_sources.py b/test/client/test_sources.py
index 0e02a5fa..d6ee4b56 100644
--- a/test/client/test_sources.py
+++ b/test/client/test_sources.py
@@ -155,6 +155,55 @@ def test_source_fetch_and_sample(sqla_sqlite_warehouse: Engine) -> None:
     assert len(result) == 5
 
 
+def test_source_sample_validation(sqla_sqlite_warehouse: Engine) -> None:
+    """Test that source.sample validates input correctly."""
+    # Create test data
+    source_testkit = source_factory(
+        n_true_entities=5,
+        features=[
+            {"name": "name", "base_generator": "word", "datatype": DataTypes.STRING},
+        ],
+        engine=sqla_sqlite_warehouse,
+    ).write_to_location()
+
+    # Create location and source
+    location = RelationalDBLocation(name="dbname").set_client(sqla_sqlite_warehouse)
+    source = Source(
+        dag=source_testkit.source.dag,
+        location=location,
+        name="test_source",
+        extract_transform=source_testkit.source_config.extract_transform,
+        infer_types=True,
+        key_field="key",
+        index_fields=["name"],
+    )
+
+    # Test that string n raises TypeError
+    with pytest.raises(TypeError, match="n must be an integer"):
+        source.sample(n="asdf")
+
+    # Test that negative n raises ValueError
+    with pytest.raises(ValueError, match="n must be a positive integer"):
+        source.sample(n=-1)
+
+    # Test that zero n raises ValueError
+    with pytest.raises(ValueError, match="n must be a positive integer"):
+        source.sample(n=0)
+
+    # Test that float n raises TypeError
+    with pytest.raises(TypeError, match="n must be an integer"):
+        source.sample(n=10.5)
+
+    # Test that bool n raises TypeError, even though bool subclasses int
+    with pytest.raises(TypeError, match="n must be an integer"):
+        source.sample(n=True)
+
+    # Test that valid positive integer works
+    result = source.sample(n=3)
+    assert isinstance(result, pl.DataFrame)
+    assert len(result) == 3
+
+
 @pytest.mark.parametrize(
     "qualify_names",
     [