Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions skrub/_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,39 @@ def _get_preprocessors(
return steps


def _list_transformations(estimator):
message = ""
for step in estimator._pipeline.named_steps:
if step == "checkinputdataframe":
continue
transformer = estimator._pipeline.named_steps[step]
match transformer.transformer:
case DropUninformative():
dropped = set(transformer.all_inputs_) - set(transformer.all_outputs_)
if dropped != set():
message += "DropUninformative - " + "\n"
message += f"Dropped columns {dropped}" + "\n"
message += f"Used inputs: {transformer.used_inputs_}" + "\n"
case ToFloat():
message += "ToFloat - " + "\n"
message += (
f"Columns transformed to float: {transformer.used_inputs_}" + "\n"
)
case ToDatetime():
message += "ToDatetime - " + "\n"
message += (
f"Columns transformed to datetime: {transformer.used_inputs_}"
+ "\n"
)
case CleanNullStrings():
message += "CleanNullStrings - " + "\n"
message += (
f"Columns with standardized nulls: {transformer.used_inputs_}"
+ "\n"
)
return message


class Cleaner(TransformerMixin, BaseEstimator):
"""Column-wise consistency checks and sanitization of dtypes, null values and dates.

Expand Down Expand Up @@ -538,6 +571,9 @@ def get_feature_names_out(self, input_features=None):
check_is_fitted(self, "all_outputs_")
return np.asarray(self.all_outputs_)

def list_transformations(self):
return _list_transformations(self)


class TableVectorizer(TransformerMixin, BaseEstimator):
"""Transform a dataframe to a numeric (vectorized) representation.
Expand Down Expand Up @@ -1164,3 +1200,40 @@ def get_feature_names_out(self, input_features=None):
"""
check_is_fitted(self, "all_outputs_")
return np.asarray(self.all_outputs_)

def list_transformations(self):
preprocessing_transformations = _list_transformations(self)
vectorize_transformations = ""
specific_transformations = ""

all_transformers = self.kind_to_columns_
specific = all_transformers.pop("specific")

for transformer_type, transformer_cols in all_transformers.items():
if transformer_cols != []:
vectorize_transformations += (
f"{transformer_type} transformer is \
{getattr(self, transformer_type)} \
and was applied to {transformer_cols}."
+ "\n"
)
else:
vectorize_transformations += (
f"{transformer_type} transformer is \
{getattr(self, transformer_type)} \
and was applied to nothing."
+ "\n"
)

if self.specific_transformers != ():
for t in self.specific_transformers:
specific_transformations += f"specific transformer \
{t} was applied to {specific}"

return (
preprocessing_transformations
+ "\n\n"
+ vectorize_transformations
+ "\n\n"
+ specific_transformations
)
13 changes: 13 additions & 0 deletions skrub/tests/test_table_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from skrub._to_float import ToFloat
from skrub._to_str import ToStr
from skrub.conftest import _POLARS_INSTALLED
from skrub.datasets._generating import toy_cities

MSG_PANDAS_DEPRECATED_WARNING = "Skip deprecation warning"

Expand Down Expand Up @@ -1277,3 +1278,15 @@ def test_duration_to_float(df_module):
vectorizer = Cleaner()
transformed = vectorizer.fit_transform(df)
df_module.assert_column_equal(transformed["duration"], df["duration"])


def test_list_transformations(df_module):
df = toy_cities()

vectorizer = TableVectorizer()
_ = vectorizer.fit_transform(df)
_ = vectorizer.list_transformations()

vectorizer = Cleaner()
_ = vectorizer.fit_transform(df)
_ = vectorizer.list_transformations()