Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,20 @@ To run the server, see the [server installation documentation](https://uktrade.g
## Development

See our full development guide and coding standards on our [contribution guide](https://uktrade.github.io/matchbox/contributing/).

## Local development with Datadog

When iterating on the Datadog configuration, environment variables can be set in several ways:

1. **Datadog configuration**: Create a `.datadog.env` file with your Datadog API key and other agent settings
2. **Compose override**: Use `docker-compose.override.yml` for local-specific variable overrides

Variables in `.datadog.env` will override any defaults set in the compose file.

Example `.datadog.env`:

```
DD_API_KEY=your_api_key_here
```

The Docker Compose file will automatically set `DD_ENV=local-{username}` for local development isolation.
4 changes: 2 additions & 2 deletions environments/development.env
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ MB__DEV__API_PORT=8000
MB__DEV__DATASTORE_CONSOLE_PORT=9003
MB__DEV__DATASTORE_PORT=9002
MB__DEV__WAREHOUSE_PORT=7654
MB__DEV__POSTGRES_BACKEND_PORT=5432 # Change to 9876 here and server.env to avoid conflict with other services
MB__DEV__POSTGRES_BACKEND_PORT=9876 # Set to 9876 here and in server.env to avoid conflict with other services

MB__SERVER__API_KEY=matchbox-api-key
MB__SERVER__BACKEND_TYPE=postgres
Expand All @@ -17,7 +17,7 @@ MB__SERVER__DATASTORE__DEFAULT_REGION=eu-west-2
MB__SERVER__DATASTORE__CACHE_BUCKET_NAME=cache

MB__SERVER__POSTGRES__HOST=localhost
MB__SERVER__POSTGRES__PORT=5432 # Change to 9876 here and server.env to avoid conflict with other services
MB__SERVER__POSTGRES__PORT=9876 # Set to 9876 here and in server.env to avoid conflict with other services
MB__SERVER__POSTGRES__USER=matchbox_user
MB__SERVER__POSTGRES__PASSWORD=matchbox_password
MB__SERVER__POSTGRES__DATABASE=matchbox
Expand Down
2 changes: 1 addition & 1 deletion environments/server.env
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ MB__DEV__API_PORT=8000
MB__DEV__DATASTORE_CONSOLE_PORT=9003
MB__DEV__DATASTORE_PORT=9002
MB__DEV__WAREHOUSE_PORT=7654
MB__DEV__POSTGRES_BACKEND_PORT=5432 # Change to 9876 here and development.env to avoid conflict with other services
MB__DEV__POSTGRES_BACKEND_PORT=9876 # Set to 9876 here and in development.env to avoid conflict with other services

MB__SERVER__API_KEY=matchbox-api-key
MB__SERVER__BACKEND_TYPE=postgres
Expand Down
11 changes: 6 additions & 5 deletions src/matchbox/client/_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,21 @@ def login(user_name: str) -> int:


def query(
source: SourceResolutionName,
sources: list[SourceResolutionName],
return_leaf_id: bool,
resolution: ResolutionName | None = None,
threshold: int | None = None,
limit: int | None = None,
) -> Table:
"""Query a source in Matchbox."""
log_prefix = f"Query {source}"
"""Query multiple sources in Matchbox."""
log_prefix = f"Query {', '.join(sources)}"
logger.debug(f"Using {resolution}", prefix=log_prefix)

res = CLIENT.get(
"/query",
params=url_params(
{
"source": source,
"sources": sources,
"resolution": resolution,
"return_leaf_id": return_leaf_id,
"threshold": threshold,
Expand Down Expand Up @@ -438,7 +438,8 @@ def sample_for_eval(n: int, resolution: ModelResolutionName, user_id: int) -> Ta
params=url_params({"n": n, "resolution": resolution, "user_id": user_id}),
)

return read_table(BytesIO(res.content))
buffer = BytesIO(res.content)
return read_table(buffer)


def compare_models(resolutions: list[ModelResolutionName]) -> ModelComparison:
Expand Down
29 changes: 24 additions & 5 deletions src/matchbox/client/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,44 @@ def key_field_map(
source_mb_ids: list[ArrowTable] = []
source_to_key_field: dict[str, str] = {}

# Store source names and key field mappings
source_names = [s.name for s in sources]
for s in sources:
# Get Matchbox IDs from backend
source_to_key_field[s.name] = s.key_field.name

if len(sources) == 1:
# Single source - make individual call
source_mb_ids.append(
_handler.query(
source=s.name,
sources=[sources[0].name],
resolution=resolution,
return_leaf_id=False,
)
)
else:
# Multiple sources - make single multi-source call
combined_result = _handler.query(
sources=source_names,
resolution=resolution,
return_leaf_id=False,
)

source_to_key_field[s.name] = s.key_field.name
# Split the combined result by source
import polars as pl

combined_df = pl.from_arrow(combined_result)
for source_name in source_names:
source_data = combined_df.filter(pl.col("source") == source_name).to_arrow()
source_mb_ids.append(source_data)

# Join Matchbox IDs to form mapping table
mapping = source_mb_ids[0]
mapping = source_mb_ids[0].select(["id", "key"])
mapping = mapping.rename_columns({"key": sources[0].qualified_key})
if len(sources) > 1:
for s, mb_ids in zip(sources[1:], source_mb_ids[1:], strict=True):
mb_ids_selected = mb_ids.select(["id", "key"])
mapping = mapping.join(
right_table=mb_ids, keys="id", join_type="full outer"
right_table=mb_ids_selected, keys="id", join_type="full outer"
)
mapping = mapping.rename_columns({"key": s.qualified_key})

Expand Down
37 changes: 25 additions & 12 deletions src/matchbox/client/helpers/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,27 @@ def _process_selectors(

For batched queries, yield from it.
"""
selector_results: list[PolarsDataFrame] = []
for selector in selectors:
mb_ids = pl.from_arrow(
_handler.query(
source=selector.source.name,
resolution=resolution,
threshold=threshold,
return_leaf_id=return_leaf_id,
)
# Group selectors by resolution to make efficient multi-source queries
if not selectors:
return

# Make single multi-source query with all selectors
source_names = [selector.source.name for selector in selectors]

# Make single multi-source API call
mb_ids = pl.from_arrow(
_handler.query(
sources=source_names,
resolution=resolution,
threshold=threshold,
return_leaf_id=return_leaf_id,
)
)

# Process each selector with the multi-source result
for selector in selectors:
# Filter the multi-source results to this selector's source
source_filtered_ids = mb_ids.filter(pl.col("source") == selector.source.name)

raw_batches = selector.source.query(
qualify_names=True,
Expand All @@ -218,15 +229,17 @@ def _process_selectors(
processed_batches = [
_process_query_result(
data=b,
mb_ids=source_filtered_ids,
selector=selector,
mb_ids=mb_ids,
return_leaf_id=return_leaf_id,
)
for b in raw_batches
]
selector_results.append(pl.concat(processed_batches, how="vertical"))

return selector_results
# Concatenate all batches for this selector and yield
if processed_batches:
selector_result = pl.concat(processed_batches, how="vertical")
yield selector_result


def query(
Expand Down
21 changes: 15 additions & 6 deletions src/matchbox/common/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@
from typing import Final

import pyarrow as pa
import pyarrow.parquet as pq
from pyarrow import Schema

from matchbox.common.exceptions import MatchboxArrowSchemaMismatch

SCHEMA_QUERY: Final[pa.Schema] = pa.schema(
[("id", pa.int64()), ("key", pa.large_string())]
[
("id", pa.int64()),
("key", pa.large_string()),
("source", pa.dictionary(pa.int32(), pa.string())),
]
)
"""Data transfer schema for root cluster IDs keyed to primary keys."""
"""Data transfer schema for root cluster IDs keyed to primary keys with source ID."""

SCHEMA_QUERY_WITH_LEAVES = SCHEMA_QUERY.append(pa.field("leaf_id", pa.int64()))
"""Data transfer schema for root cluster IDs keyed to primary keys and leaf IDs."""
"""Data transfer schema for cluster IDs with primary keys, source ID, and leaf IDs."""


SCHEMA_INDEX: Final[pa.Schema] = pa.schema(
Expand Down Expand Up @@ -70,9 +73,15 @@ class JudgementsZipFilenames(StrEnum):


def table_to_buffer(table: pa.Table) -> BytesIO:
"""Converts an Arrow table to a BytesIO buffer."""
"""Converts an Arrow table to a BytesIO buffer using Arrow IPC format.

Uses Arrow IPC format instead of parquet to preserve exact schema fidelity,
including uint32 dictionary indices and large_string values.
"""
sink = BytesIO()
pq.write_table(table, sink)
writer = pa.ipc.new_file(sink, table.schema)
writer.write_table(table)
writer.close()
sink.seek(0)
return sink

Expand Down
6 changes: 5 additions & 1 deletion src/matchbox/common/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,11 @@ def process_judgements(
# if missing expansion, assume we're dealing with singleton leaves
.with_columns(
pl.when(pl.col("endorsed_leaves").is_null())
.then(pl.col("endorsed").map_elements(lambda x: [x]))
.then(
pl.col("endorsed").map_elements(
lambda x: [x], return_dtype=pl.List(pl.UInt64)
)
)
.otherwise(pl.col("endorsed_leaves"))
.alias("endorsed_leaves")
)
Expand Down
48 changes: 29 additions & 19 deletions src/matchbox/common/factories/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,12 @@ def get_values(
raise ValueError(f"SourceConfig not found: {source_name}")

# Get rows for this entity in this source
df = source.data.to_pandas()
entity_rows = df[df["key"].isin(keys)]
df = pl.from_arrow(source.data)
entity_rows = df.filter(pl.col("key").is_in(keys))

# Get unique values for each feature in this source
values[source_name] = {
feature.name: sorted(entity_rows[feature.name].unique())
feature.name: sorted(entity_rows[feature.name].unique().to_list())
for feature in source.features
}

Expand Down Expand Up @@ -473,33 +473,43 @@ def query_to_cluster_entities(
Returns:
A set of ClusterEntity objects
"""
# Convert polars to pandas for compatibility with existing logic
if isinstance(query, pl.DataFrame):
query = query.to_pandas()
elif isinstance(query, pa.Table):
query = query.to_pandas()
# Convert to polars for efficient processing (avoids pandas uint32 issues)
if isinstance(query, pa.Table):
query_df = pl.from_arrow(query)
elif isinstance(query, pd.DataFrame):
query_df = pl.from_pandas(query)
else:
query_df = query

must_have_fields = set(["id"] + list(keys.values()))
if not must_have_fields.issubset(query.columns):
if not must_have_fields.issubset(query_df.columns):
raise ValueError(
f"Fields {must_have_fields.difference(query.columns)} must be included "
f"Fields {must_have_fields.difference(query_df.columns)} must be included "
"in the query and are missing."
)

def _create_cluster_entity(group: pd.DataFrame) -> ClusterEntity:
entity_refs = {
source: frozenset(group[key_field].dropna().values)
for source, key_field in keys.items()
if not group[key_field].dropna().empty
}
def _create_cluster_entity(group_df: pl.DataFrame) -> ClusterEntity:
# Get the cluster ID (should be the same for all rows in the group)
cluster_id = group_df["id"][0]

entity_refs = {}
for source, key_field in keys.items():
# Get non-null values for this key field
values = group_df.filter(pl.col(key_field).is_not_null())[key_field]
if len(values) > 0:
entity_refs[source] = frozenset(values.to_list())

return ClusterEntity(
id=group.name,
id=cluster_id,
keys=EntityReference(entity_refs),
)

result = query.groupby("id").apply(_create_cluster_entity, include_groups=False)
return set(result.tolist())
# Group by cluster ID and create ClusterEntity for each group
result = []
for _cluster_id, group_df in query_df.group_by("id"):
result.append(_create_cluster_entity(group_df))

return set(result)


@cache
Expand Down
6 changes: 3 additions & 3 deletions src/matchbox/server/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,16 @@ async def get_upload_status(
)
def query(
backend: BackendDependency,
source: SourceResolutionName,
sources: Annotated[list[SourceResolutionName], Query()],
return_leaf_id: bool,
resolution: ResolutionName | None = None,
threshold: int | None = None,
limit: int | None = None,
) -> ParquetResponse:
"""Query Matchbox for matches based on a source resolution name."""
"""Query Matchbox for matches based on multiple source resolution names."""
try:
res = backend.query(
source=source,
sources=sources,
resolution=resolution,
threshold=threshold,
return_leaf_id=return_leaf_id,
Expand Down
6 changes: 3 additions & 3 deletions src/matchbox/server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class MatchboxDBAdapter(ABC):
@abstractmethod
def query(
self,
source: SourceResolutionName,
sources: list[SourceResolutionName],
resolution: ResolutionName | None = None,
threshold: int | None = None,
return_leaf_id: bool = False,
Expand All @@ -237,9 +237,9 @@ def query(
"""Queries the database from an optional point of truth.

Args:
source: the `SourceResolutionName` string identifying the source to query
sources: list of `SourceResolutionName` strings identifying sources to query
resolution (optional): the resolution to use for filtering results
If not specified, will use the source resolution for the queried source
If not specified, will use the source resolution for the first source
threshold (optional): the threshold to use for creating clusters
If None, uses the models' default threshold
If an integer, uses that threshold for the specified model, and the
Expand Down
4 changes: 2 additions & 2 deletions src/matchbox/server/postgresql/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,14 @@ def __init__(self, settings: MatchboxPostgresSettings):

def query( # noqa: D102
self,
source: SourceResolutionName,
sources: list[SourceResolutionName],
resolution: ResolutionName | None = None,
threshold: int | None = None,
return_leaf_id: bool = False,
limit: int | None = None,
) -> ArrowTable:
return query(
source=source,
sources=sources,
resolution=resolution,
threshold=threshold,
return_leaf_id=return_leaf_id,
Expand Down
Loading
Loading