From 8851ed6c3128acf692e79f53d02f617c6be5f94a Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Thu, 2 Apr 2026 18:08:42 +0200 Subject: [PATCH 1/2] add caching to .skb.apply --- skrub/_config.py | 13 +++++++ skrub/_data_ops/_cached_helpers.py | 19 ++++++++++ skrub/_data_ops/_data_ops.py | 58 +++++++++++++++++++++++++----- skrub/tests/test_config.py | 2 ++ 4 files changed, 84 insertions(+), 8 deletions(-) create mode 100644 skrub/_data_ops/_cached_helpers.py diff --git a/skrub/_config.py b/skrub/_config.py index f3bd185f8..26e63c29a 100644 --- a/skrub/_config.py +++ b/skrub/_config.py @@ -5,6 +5,7 @@ from contextlib import contextmanager from pathlib import Path +import joblib import numpy as np @@ -40,6 +41,12 @@ def _get_default_data_dir(): return str(data_home) +def _get_default_cache_dir(): + cache_dir = Path(_get_default_data_dir()) / "_cache" + cache_dir.mkdir(exist_ok=True) + return str(cache_dir) + + def _parse_env_bool(env_variable_name, default): value = os.getenv(env_variable_name, default) if isinstance(value, bool): @@ -64,6 +71,8 @@ def _parse_env_bool(env_variable_name, default): "float_precision": int(os.environ.get("SKB_FLOAT_PRECISION", 3)), "cardinality_threshold": int(os.environ.get("SKB_CARDINALITY_THRESHOLD", 40)), "data_dir": _get_default_data_dir(), + "cache_dir": _get_default_cache_dir(), + "memory": joblib.Memory(_get_default_cache_dir(), verbose=0), "eager_data_ops": _parse_env_bool("SKB_EAGER_DATA_OPS", True), } _threadlocal = threading.local() @@ -113,6 +122,8 @@ def set_config( float_precision=None, cardinality_threshold=None, data_dir=None, + cache_dir=None, + memory=None, eager_data_ops=None, ): """Set global skrub configuration. @@ -314,6 +325,8 @@ def config_context( float_precision=None, cardinality_threshold=None, data_dir=None, + cache_dir=None, + memory=None, eager_data_ops=None, ): """Context manager for global skrub configuration. 
diff --git a/skrub/_data_ops/_cached_helpers.py b/skrub/_data_ops/_cached_helpers.py new file mode 100644 index 000000000..082c46ac3 --- /dev/null +++ b/skrub/_data_ops/_cached_helpers.py @@ -0,0 +1,19 @@ +""" +functions meant to be cached with joblib. + +They are in their own module so the cache is less likely to be invalidated due +to the line number of the function definition changing. +""" + +import joblib + + +def _call_fitting_method(estimator, method_name, args, kwargs): + # we could also just generate a str(uuid.uuid4()) 🤔 + estimator_id = joblib.hash((estimator, method_name, args, kwargs)) + result = getattr(estimator, method_name)(*args, **kwargs) + return estimator, result, estimator_id + + +def _call_non_fitting_method(estimator, method_name, args, kwargs, estimator_id): + return getattr(estimator, method_name)(*args, **kwargs) diff --git a/skrub/_data_ops/_data_ops.py b/skrub/_data_ops/_data_ops.py index 80fbd10da..906158340 100644 --- a/skrub/_data_ops/_data_ops.py +++ b/skrub/_data_ops/_data_ops.py @@ -33,6 +33,7 @@ import itertools import operator import pathlib +import pickle import re import textwrap import traceback @@ -49,7 +50,7 @@ from .._reporting._utils import strip_xml_declaration from .._utils import PassThrough, set_module, short_repr from .._wrap_transformer import wrap_transformer -from . import _utils +from . 
import _cached_helpers, _utils from ._choosing import get_chosen_or_default from ._utils import FITTED_PREDICTOR_METHODS, NULL, attribute_error @@ -1325,6 +1326,36 @@ def check_subsampled_X_y_shape(X_op, y_op, X_value, y_value, mode, environment, ) +def _call_fitting_method(estimator, method_name, args, kwargs): + memory = _config.get_config()["memory"] + if memory is None: + result = getattr(estimator, method_name)(*args, **kwargs) + return estimator, result, None + try: + return memory.cache(_cached_helpers._call_fitting_method)( + estimator, method_name, args, kwargs + ) + except pickle.PicklingError: + pass + # Fall back to non-cached call if arguments cannot be serialized + result = getattr(estimator, method_name)(*args, **kwargs) + return estimator, result, None + + +def _call_non_fitting_method(estimator, method_name, args, kwargs, estimator_id): + memory = _config.get_config()["memory"] + if memory is None or estimator_id is None: + return getattr(estimator, method_name)(*args, **kwargs) + try: + return memory.cache( + _cached_helpers._call_non_fitting_method, ignore=["estimator"] + )(estimator, method_name, args, kwargs, estimator_id) + except pickle.PicklingError: + pass + # Fall back to non-cached call if arguments cannot be serialized + return getattr(estimator, method_name)(*args, **kwargs) + + class Apply(DataOpImpl): """.skb.apply() nodes.""" @@ -1389,9 +1420,13 @@ def eval(self, *, mode, environment): # with `.predict()` if method_name == "fit_transform": fit_kwargs = yield from self._eval_kwargs("fit") - self.estimator_.fit(X, y, **fit_kwargs) + self.estimator_, _, self.estimator_id_ = _call_fitting_method( + self.estimator_, "fit", (X, y), fit_kwargs + ) predict_kwargs = yield from self._eval_kwargs("predict") - pred = self.estimator_.predict(X, **predict_kwargs) + pred = _call_non_fitting_method( + self.estimator_, "predict", (X,), predict_kwargs, self.estimator_id_ + ) # In `(fit_)transform` mode only, format the predictions as a # dataframe or 
column if y was one during `fit()` return self._format_predictions(X, pred) @@ -1402,13 +1437,20 @@ def eval(self, *, mode, environment): method_name = "fit_transform" if "fit" in method_name: - y_arg = () if self.unsupervised else (y,) + args = (X,) if self.unsupervised else (X, y) elif method_name == "score": - y_arg = (y,) + args = (X, y) else: - y_arg = () - method_kwargs = yield from self._eval_kwargs(method_name) - return getattr(self.estimator_, method_name)(X, *y_arg, **method_kwargs) + args = (X,) + kwargs = yield from self._eval_kwargs(method_name) + if "fit" in method_name: + self.estimator_, result, self.estimator_id_ = _call_fitting_method( + self.estimator_, method_name, args, kwargs + ) + return result + return _call_non_fitting_method( + self.estimator_, method_name, args, kwargs, self.estimator_id_ + ) def _store_y_format(self, y): if sbd.is_dataframe(y): diff --git a/skrub/tests/test_config.py b/skrub/tests/test_config.py index 40fbbe60d..e071c3cba 100644 --- a/skrub/tests/test_config.py +++ b/skrub/tests/test_config.py @@ -50,6 +50,8 @@ def test_default_config(): expected_keys = { "use_table_report_data_ops", "data_dir", + "cache_dir", + "memory", "table_report_verbosity", "max_plot_columns", "max_association_columns", From c7f4a09cfabbbb441fa511ba6352d0c3f255afda Mon Sep 17 00:00:00 2001 From: Jerome Dockes Date: Fri, 3 Apr 2026 08:39:04 +0200 Subject: [PATCH 2/2] update get_config doctest with cache_dir and memory keys --- doc/guides/utilities/customizing_configuration.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/guides/utilities/customizing_configuration.rst b/doc/guides/utilities/customizing_configuration.rst index c0b38a286..a1b140fd8 100644 --- a/doc/guides/utilities/customizing_configuration.rst +++ b/doc/guides/utilities/customizing_configuration.rst @@ -39,7 +39,7 @@ are available by using >>> import skrub >>> config = skrub.get_config() >>> config.keys() -dict_keys(['use_table_report_data_ops', 'table_report_verbosity', 'max_plot_columns', 
'max_association_columns', 'subsampling_seed', 'enable_subsampling', 'float_precision', 'cardinality_threshold', 'data_dir', 'eager_data_ops']) +dict_keys(['use_table_report_data_ops', 'table_report_verbosity', 'max_plot_columns', 'max_association_columns', 'subsampling_seed', 'enable_subsampling', 'float_precision', 'cardinality_threshold', 'data_dir', 'cache_dir', 'memory', 'eager_data_ops']) These are the parameters currently available in the global configuration: