Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:

# Mandatory internal hooks
- repo: https://github.com/uktrade/github-standards
rev: v1.2.1 # update periodically with pre-commit autoupdate
rev: v1.3.1 # update periodically with pre-commit autoupdate
hooks:
- id: run-security-scan
verbose: false
Expand Down
7 changes: 6 additions & 1 deletion src/matchbox/client/cli/eval/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@ def evaluate(
resolution = None
else:
# Get resolution name from --resolution or DAG's final_step
model = dag.get_model(resolution) if resolution is not None else dag.final_step
if resolution is not None:
model = dag.get_model(resolution)
elif dag.root is not None:
model = dag.root
else:
raise RuntimeError("No single root in DAG: a resolution name is needed.")
resolution = model.resolution_path

try:
Expand Down
33 changes: 16 additions & 17 deletions src/matchbox/client/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,23 +144,18 @@ def final_steps(self) -> list[Source | Model]:
return [self.nodes[name] for name in apex_node_names]

@property
def final_step(self) -> Source | Model:
"""Returns the root node in the DAG.
def root(self) -> Source | Model | None:
"""Looks for single root node in the DAG.

Returns:
The root node in the DAG

Raises:
ValueError: If the DAG does not have exactly one final step
The node corresponding to the root if there is one, otherwise None.
"""
steps = self.final_steps

if len(steps) == 0:
raise ValueError("No root node found, DAG might contain cycles")
elif len(steps) > 1:
raise ValueError("Some models or sources are unreachable")
else:
return steps[0]
if len(steps) != 1:
return None

return steps[0]

def source(self, *args: Any, **kwargs: Any) -> Source:
"""Create Source and add it to the DAG."""
Expand Down Expand Up @@ -538,9 +533,6 @@ def set_default(self) -> None:

Makes it immutable, then moves the default pointer to it.
"""
# Trigger error if there isn't a single root
_ = self.final_step

# tries to get apex, error if it doesn't exist
_handler.set_run_mutable(collection=self.name, run_id=self.run, mutable=False)
_handler.set_run_default(collection=self.name, run_id=self.run, default=True)
Expand All @@ -550,6 +542,7 @@ def lookup_key(
from_source: str,
to_sources: list[str],
key: str,
node: str | None = None,
threshold: float | None = None,
) -> dict[str, list[str]]:
"""Matches IDs against the selected backend.
Expand All @@ -558,6 +551,8 @@ def lookup_key(
from_source: Name of source the provided key belongs to
to_sources: Names of sources to find keys in
key: The value to match from the source. Usually a primary key
node: Name of node to use as point of truth. If None, will look for a
single root in the DAG.
threshold (optional): The threshold to use for creating clusters.
If None, uses the resolutions' default threshold
If a float, uses that threshold for the specified resolution, and the
Expand All @@ -582,14 +577,16 @@ def lookup_key(
if not isinstance(threshold, float):
raise ValueError("If passed, threshold must be a float")
threshold = threshold_float_to_int(threshold)
if not node and not self.root:
raise ValueError("No single root in DAG: a node to query from must be set.")
matches = _handler.match(
targets=[
ResolutionPath(name=target, collection=self.name, run=self.run)
for target in to_sources
],
source=ResolutionPath(name=from_source, collection=self.name, run=self.run),
key=key,
resolution=self.final_step.resolution_path,
resolution=node or self.root.resolution_path,
threshold=threshold,
)

Expand Down Expand Up @@ -621,7 +618,9 @@ def resolve(
if not isinstance(threshold, float):
raise ValueError("If passed, threshold must be a float")
threshold = threshold_float_to_int(threshold)
point_of_truth = self.nodes[node] if node else self.final_step
if not node and not self.root:
raise ValueError("No single root in DAG: a node to query from must be set.")
point_of_truth = self.nodes[node] if node else self.root

available_sources = {
node_name: self.get_source(node_name)
Expand Down
4 changes: 3 additions & 1 deletion src/matchbox/client/eval/samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ def _get_samples_from_server(
if resolution:
resolution_path: ModelResolutionPath = dag.get_model(resolution).resolution_path
else:
resolution_path: ModelResolutionPath = dag.final_step.resolution_path
if not dag.root:
raise ValueError("Must set a resolution if DAG does not have single root.")
resolution_path: ModelResolutionPath = dag.root.resolution_path
return pl.from_arrow(
_handler.sample_for_eval(n=n, resolution=resolution_path, user_name=user_name)
)
Expand Down
92 changes: 92 additions & 0 deletions src/matchbox/client/mappings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Interface to external record mapping."""

from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar

from pyarrow import Table as ArrowTable

from matchbox.common.dtos import ResolutionPath
from matchbox.common.logging import profile_time

if TYPE_CHECKING:
from matchbox.client.dags import DAG
from matchbox.client.locations import Location
else:
DAG = Any
Location = Any

T = TypeVar("T")


def post_run(method: Callable[..., T]) -> Callable[..., T]:
"""Decorator to ensure that a method is called after mapping run.

Raises:
RuntimeError: If run hasn't happened.
"""

@wraps(method)
def wrapper(self: "Mapping", *args: Any, **kwargs: Any) -> T:
if self.data is None:
raise RuntimeError(
"The mapping must be run before attempting this operation."
)
return method(self, *args, **kwargs)

return wrapper


class Mapping:
"""Links encoded by external data."""

def __init__(
self,
dag: DAG,
location: Location,
name: str,
extract_transform: str,
description: str | None = None,
) -> None:
"""Initialise mapping node."""
self.dag = dag
self.location = location
self.name = name
self.extract_transform = extract_transform
self.description = description

TODO: TypeAlias = None

@property
def config(self) -> TODO:
"""Generate MappingConfig from Mapping."""
raise NotImplementedError

def to_resolution(self) -> None:
"""Generate MappingConfig from Mapping."""
raise NotImplementedError

def from_resolution(self) -> None:
"""Generate MappingConfig from Mapping."""
raise NotImplementedError

@property
def resolution_path(self) -> ResolutionPath:
"""Returns the resolution path."""
return ResolutionPath(
collection=self.dag.name, run=self.dag.run, name=self.name
)

@profile_time(attr="name")
def run(self, batch_size: int | None = None) -> ArrowTable:
"""."""
pass

@post_run
@profile_time(attr="name")
def sync(self) -> None:
"""."""

def clear_data(self) -> None:
"""Deletes data computed for node."""
self.data = None
2 changes: 1 addition & 1 deletion src/matchbox/client/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
description: str | None = None,
infer_types: bool = False,
validate_etl: bool = True,
):
) -> None:
"""Initialise source.

Args:
Expand Down
25 changes: 1 addition & 24 deletions test/client/test_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ def test_dag_set_client(sqla_sqlite_warehouse: Engine) -> None:
assert dag.get_source("bar").location.client == sqla_sqlite_warehouse


def test_dag_set_default_ok(
def test_dag_set_default(
matchbox_api: MockRouter,
sqla_sqlite_warehouse: Engine,
) -> None:
Expand Down Expand Up @@ -1163,26 +1163,3 @@ def test_dag_set_default_ok(
# Verify both endpoints were called
assert api_mutable.called
assert api_default.called


def test_dag_set_default_not_connected() -> None:
"""Set default raises error when DAG is not connected to server."""
dag = DAG(name="test_collection")
dag.source(**source_factory().into_dag())

with pytest.raises(RuntimeError, match="has not been connected"):
dag.set_default()


def test_dag_set_default_unreachable_nodes(sqla_sqlite_warehouse: Engine) -> None:
"""Nodes cannot be unreachable from root when setting a default run."""
dag = TestkitDAG().dag

foo_tkit = source_factory(name="foo", engine=sqla_sqlite_warehouse)
bar_tkit = source_factory(name="bar", engine=sqla_sqlite_warehouse)

dag.source(**foo_tkit.into_dag())
dag.source(**bar_tkit.into_dag())

with pytest.raises(ValueError, match="unreachable"):
dag.set_default()
11 changes: 4 additions & 7 deletions test/client/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,7 @@ def test_get_samples_remote(
# Check results - test with samples that include all three sources
# All three sources (foo, bar, baz) are in loaded_dag with the warehouse location
samples_all = get_samples(
n=10,
resolution=dag.final_step.resolution_path.name,
user_name=user.user_name,
dag=loaded_dag,
n=10, resolution=dag.root.name, user_name=user.user_name, dag=loaded_dag
)

assert sorted(samples_all.keys()) == [10, 11]
Expand Down Expand Up @@ -252,7 +249,7 @@ def test_get_samples_remote(

samples = get_samples(
n=10,
resolution=dag.final_step.resolution_path.name,
resolution=dag.root.name,
user_name=user.user_name,
dag=loaded_dag,
)
Expand Down Expand Up @@ -303,7 +300,7 @@ def test_get_samples_remote(

no_samples = get_samples(
n=10,
resolution=dag.final_step.resolution_path.name,
resolution=dag.root.name,
user_name=user.user_name,
dag=loaded_dag,
)
Expand All @@ -322,7 +319,7 @@ def test_get_samples_remote(
with pytest.raises(MatchboxSourceTableError, match="Could not query source"):
get_samples(
n=10,
resolution=dag.final_step.resolution_path.name,
resolution=dag.root.name,
user_name=user.user_name,
dag=bad_dag,
)
7 changes: 2 additions & 5 deletions test/common/factories/test_testkit_dag.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Unit tests for TestkitDAG."""

import pytest

from matchbox.common.factories.dags import TestkitDAG
from matchbox.common.factories.entities import FeatureConfig
from matchbox.common.factories.models import model_factory
Expand Down Expand Up @@ -177,6 +175,5 @@ def test_empty_dag_properties() -> None:
# DAG should be empty but valid
assert len(dag_testkit.dag.nodes) == 0

# Final_step should raise on empty DAG
with pytest.raises(ValueError):
_ = dag_testkit.dag.final_step
# final_step should return nothing
assert dag_testkit.dag.root is None
2 changes: 1 addition & 1 deletion test/e2e/test_e2e_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ async def test_evaluation_workflow_server(self) -> None:

# Create app and verify it can load samples from real data
app = EntityResolutionApp(
resolution=dag.final_step.resolution_path,
resolution=dag.root.resolution_path,
num_samples=2,
session_tag="eval_session1",
user="alice",
Expand Down
Loading