diff --git a/dedupe/datamodel.py b/dedupe/datamodel.py
index bb4335658..5e2e2f13e 100644
--- a/dedupe/datamodel.py
+++ b/dedupe/datamodel.py
@@ -1,22 +1,17 @@
 from __future__ import annotations
 
 import copyreg
-import pkgutil
+import importlib
 import types
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Type, cast
 
 import numpy
 
-import dedupe.variables
+from dedupe import variables
 from dedupe.variables.base import FieldType as FieldVariable
 from dedupe.variables.base import MissingDataType, Variable
 from dedupe.variables.interaction import InteractionType
 
-for _, module, _ in pkgutil.iter_modules(  # type: ignore
-    dedupe.variables.__path__, "dedupe.variables."
-):
-    __import__(module)
-
 if TYPE_CHECKING:
     from typing import Generator, Iterable, Sequence
 
@@ -28,8 +23,6 @@
     )
     from dedupe.predicates import Predicate
 
-VARIABLE_CLASSES = {k: v for k, v in FieldVariable.all_subclasses() if k}
-
 
 class DataModel(object):
     version = 1
@@ -142,6 +135,45 @@ def __setstate__(self, d):
         self.__dict__ = d
 
 
+def _get_variable_class(variable_type: str) -> Type[Variable]:
+    """Resolve a variable type string to the class that implements it.
+
+    ``variable_type`` is either the name of a built-in type exported by
+    ``dedupe.variables`` (e.g. ``"String"``) or a plugin reference in the
+    form ``mypackage.mymodule:MyClass``.
+
+    Raises ValueError if the string is malformed, the module cannot be
+    imported, or the resolved attribute is not a Variable subclass.
+    """
+    msg = (
+        f"Bad variable type: '{variable_type}'. "
+        "Should be in the format mypackage.mymodule:MyClass "
+        f"or one of the built-in types: {variables.__all__}"
+    )
+    if ":" in variable_type:
+        module_name, _, class_name = variable_type.partition(":")
+    else:
+        module_name, class_name = "dedupe.variables", variable_type
+    # Reject empty parts, extra colons and relative module names up front;
+    # otherwise they surface as unrelated unpacking or import errors.
+    if (
+        not module_name
+        or not class_name
+        or ":" in class_name
+        or module_name.startswith(".")
+    ):
+        raise ValueError(msg)
+    try:
+        module = importlib.import_module(module_name)
+    except ImportError as e:
+        raise ValueError(msg) from e
+    variable_class = getattr(module, class_name, None)
+    # Submodules and helpers are package attributes too, so verify the result.
+    if not (isinstance(variable_class, type) and issubclass(variable_class, Variable)):
+        raise ValueError(msg)
+    return variable_class
+
+
 def typify_variables(
     variable_definitions: Iterable[VariableDefinition],
 ) -> tuple[list[FieldVariable], list[Variable]]:
@@ -180,14 +212,7 @@ def typify_variables(
                 if ("field" in d and d["field"] != definition["field"])
             ]
 
-        try:
-            variable_class = VARIABLE_CLASSES[variable_type]
-        except KeyError:
-            raise KeyError(
-                "Field type %s not valid. Valid types include %s"
-                % (definition["type"], ", ".join(VARIABLE_CLASSES))
-            )
-
+        variable_class = _get_variable_class(variable_type)
         variable_object = variable_class(definition)
         assert isinstance(variable_object, FieldVariable)
 
diff --git a/dedupe/variables/__init__.py b/dedupe/variables/__init__.py
index b36383a61..44a6dfe3e 100644
--- a/dedupe/variables/__init__.py
+++ b/dedupe/variables/__init__.py
@@ -1,3 +1,35 @@
 from pkgutil import extend_path
 
 __path__ = extend_path(__path__, __name__)
+
+from dedupe.variables.base import CustomType as Custom
+
+# flake8: noqa
+from dedupe.variables.categorical_type import CategoricalType as Categorical
+from dedupe.variables.date_time import DateTimeType as DateTime
+from dedupe.variables.exact import ExactType as Exact
+from dedupe.variables.exists import ExistsType as Exists
+from dedupe.variables.interaction import InteractionType as Interaction
+from dedupe.variables.latlong import LatLongType as LatLong
+from dedupe.variables.price import PriceType as Price
+from dedupe.variables.set import SetType as Set
+from dedupe.variables.string import ShortStringType as ShortString
+from dedupe.variables.string import StringType as String
+from dedupe.variables.string import TextType as Text
+
+__all__ = sorted(
+    [
+        "Custom",
+        "Categorical",
+        "DateTime",
+        "Exact",
+        "Exists",
+        "Interaction",
+        "LatLong",
+        "Price",
+        "Set",
+        "ShortString",
+        "String",
+        "Text",
+    ]
+)
diff --git a/dedupe/variables/base.py b/dedupe/variables/base.py
index f80b28faa..f0c82de64 100644
--- a/dedupe/variables/base.py
+++ b/dedupe/variables/base.py
@@ -5,7 +5,7 @@
 from dedupe import predicates
 
 if TYPE_CHECKING:
-    from typing import Any, ClassVar, Generator, Iterable, Optional, Sequence, Type
+    from typing import Any, ClassVar, Iterable, Sequence, Type
 
     from dedupe._typing import Comparator, PredicateFunction, VariableDefinition
 
@@ -47,15 +47,6 @@ def __getstate__(self) -> dict[str, Any]:
 
         return odict
 
-    @classmethod
-    def all_subclasses(
-        cls,
-    ) -> Generator[tuple[Optional[str], Type["Variable"]], None, None]:
-        for q in cls.__subclasses__():
-            yield getattr(q, "type", None), q
-            for p in q.all_subclasses():
-                yield p
-
 
 class DerivedType(Variable):
     type = "Derived"
diff --git a/tests/test_dedupe.py b/tests/test_dedupe.py
index 8331fce69..3dda5ad6a 100644
--- a/tests/test_dedupe.py
+++ b/tests/test_dedupe.py
@@ -6,6 +6,9 @@
 import numpy
 
 import dedupe
+from dedupe import datamodel
+from dedupe.datamodel import DataModel
+from dedupe.variables.base import FieldType
 
 DATA = {
     100: {"name": "Bob", "age": "50"},
@@ -31,8 +34,6 @@
 
 class DataModelTest(unittest.TestCase):
     def test_data_model(self):
-        DataModel = dedupe.datamodel.DataModel
-
         self.assertRaises(TypeError, DataModel)
 
         data_model = DataModel(
@@ -75,6 +76,66 @@ def test_data_model(self):
 
         assert data_model._missing_field_indices == []
 
+    def test_builtin_variables_basic(self):
+        """Tests that we can instantiate a DataModel with all of the
+        builtin variable types
+        """
+        builtin_types = [
+            "Exists",
+            "Exact",
+            "LatLong",
+            "Price",
+            "Set",
+            "ShortString",
+            "String",
+            "Text",
+            "DateTime",
+        ]
+        defs = [
+            {
+                "field": "a",
+                "variable name": "a",
+                "type": t,
+            }
+            for t in builtin_types
+        ]
+        defs.append({"type": "Categorical", "field": "a", "categories": ["foo", "bar"]})
+        defs.append({"type": "Custom", "field": "a", "comparator": lambda x, y: 0})
+        DataModel(defs)
+
+    def test_plugin_variables(self):
+        """Tests that we can instantiate a DataModel with a variable from a plugin"""
+
+        class Mock(FieldType):
+            def __init__(self, definition):
+                super().__init__(definition)
+
+        def make_defs(type_str):
+            return [
+                {
+                    "field": "a",
+                    "variable name": "a",
+                    "type": type_str,
+                }
+            ]
+
+        # Expose Mock on an importable module so the plugin lookup can find it,
+        # and remove it again afterwards so other tests don't see it.
+        setattr(datamodel, "Mock", Mock)
+        self.addCleanup(delattr, datamodel, "Mock")
+
+        with self.assertRaises(ValueError):
+            DataModel(make_defs("Mock"))
+
+        dm = DataModel(make_defs("dedupe.datamodel:Mock"))
+        assert isinstance(dm.primary_variables[0], Mock)
+
+    def test_bad_variable_types(self):
+        """Tests that malformed or non-variable type strings are rejected"""
+        for bad_type in ["", "Mock", "base", "a:b:c", ":String", ".base:Variable"]:
+            with self.assertRaises(ValueError):
+                datamodel._get_variable_class(bad_type)
+
 
 class ConnectedComponentsTest(unittest.TestCase):
     def test_components(self):