Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions skrub/_tabular_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sklearn import ensemble
from sklearn.base import BaseEstimator
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline as skpipeline
Comment thread
MarieSacksick marked this conversation as resolved.
Outdated
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder

Expand Down Expand Up @@ -48,6 +49,7 @@ def tabular_pipeline(estimator, *, n_jobs=None):
Parameters
----------
estimator : {"regressor", "regression", "classifier", "classification"} or sklearn.base.BaseEstimator
or sklearn.pipeline.Pipeline
The estimator to use as the final step in the pipeline. Based on the type of
estimator, the previous preprocessing steps and their respective parameters are
chosen. The possible values are:
Expand All @@ -59,6 +61,7 @@ def tabular_pipeline(estimator, *, n_jobs=None):
:obj:`~sklearn.ensemble.HistGradientBoostingClassifier` is used as the final
step;
- a scikit-learn estimator: the provided estimator is used as the final step.
- a scikit-learn pipeline: the last step of the pipeline is used as the final step.
Comment thread
khaoulariad marked this conversation as resolved.
Outdated

n_jobs : int, default=None
Number of jobs to run in parallel in the :obj:`TableVectorizer` step. ``None``
Expand Down Expand Up @@ -225,6 +228,11 @@ def tabular_pipeline(estimator, *, n_jobs=None):
vectorizer = TableVectorizer(n_jobs=n_jobs)
cat_feat_kwargs = {"categorical_features": "from_dtype"}

if isinstance(estimator, skpipeline):
return tabular_pipeline(
estimator.steps[-1][-1],
n_jobs=n_jobs,
)
if isinstance(estimator, str):
if estimator in ("classifier", "classification"):
return tabular_pipeline(
Expand Down
24 changes: 22 additions & 2 deletions skrub/tests/test_tabular_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from sklearn import ensemble
from sklearn.impute import SimpleImputer
from sklearn.linear_model import Ridge
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.pipeline import Pipeline as skpipeline
Comment thread
MarieSacksick marked this conversation as resolved.
Outdated
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder

from skrub import (
Expand All @@ -14,7 +15,13 @@


@pytest.mark.parametrize(
"learner_kind", ["regressor", "regression", "classifier", "classification"]
"learner_kind",

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal at all, but in general please try to avoid changes that are unrelated to your pull request: this makes it easier to review and avoids cluttering the git history.

[
"regressor",
"regression",
"classifier",
"classification",
],
)
def test_default_pipeline(learner_kind):
p = tabular_pipeline(learner_kind)
Expand Down Expand Up @@ -74,3 +81,16 @@ def test_from_dtype():
ensemble.HistGradientBoostingRegressor(categorical_features="from_dtype")
)
assert isinstance(p.named_steps["tablevectorizer"].low_cardinality, ToCategorical)


def test_skpipeline_learner():

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_skpipeline_learner():
def test_estimator_is_a_pipeline():

original_learner = LogisticRegression()
sk_pipeline = skpipeline([("imputer", SimpleImputer()), ("clf", original_learner)])
p = tabular_pipeline(sk_pipeline)
Comment thread
khaoulariad marked this conversation as resolved.
Outdated
tv, imputer, scaler, learner = (e for _, e in p.steps)
assert learner is original_learner
assert isinstance(tv.high_cardinality, StringEncoder)
assert isinstance(tv.low_cardinality, OneHotEncoder)
assert isinstance(imputer, SimpleImputer)
assert isinstance(scaler, SquashingScaler)
Comment thread
khaoulariad marked this conversation as resolved.
Outdated
assert tv.datetime.periodic_encoding == "spline"
Comment thread
khaoulariad marked this conversation as resolved.
Outdated
Loading