Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions skrub/_reporting/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,14 @@ def _adjust_fig_size(fig, ax, target_w, target_h):


def _get_range(values, frac=0.2, factor=3.0):
if np.issubdtype(values.dtype, np.floating):
finite_values = values[np.isfinite(values)]
else:
finite_values = values
if not len(finite_values):
return 0, 0
min_value, low_p, high_p, max_value = np.quantile(
values, [0.0, frac, 1.0 - frac, 1.0]
finite_values, [0.0, frac, 1.0 - frac, 1.0]
)
delta = high_p - low_p
if not delta:
Expand All @@ -193,7 +199,22 @@ def _get_range(values, frac=0.2, factor=3.0):
return low, high


def _get_safe_hist_range(values):
# make sure numpy can find bin edges between the range low and high bounds
if not len(values) or not (
np.issubdtype(values.dtype, np.floating)
or np.issubdtype(values.dtype, np.integer)
):
return None
vmin, vmax = values.min(), values.max()
delta = max(np.spacing(vmin), np.spacing(vmax))
if vmax - vmin > 12 * delta:
return None
return vmin - 6 * delta, vmax + 6 * delta


def _robust_hist(col, ax=None, color=None):
result = {}
col = sbd.drop_nulls(col)
if sbd.is_float(col):
# avoid any issues with pandas nullable dtypes
Expand All @@ -211,19 +232,25 @@ def _robust_hist(col, ax=None, color=None):
np_histogram_values = sbd.to_numpy(
_datetime_encoder.DatetimeEncoder(resolution=None).fit_transform(col)
).ravel()
result["total_seconds_offset"] = np_histogram_values.min()
np_histogram_values = np_histogram_values - result["total_seconds_offset"]
else:
np_histogram_values = values
low, high = _get_range(values)
inlier_mask = (low <= values) & (values <= high)
n_low_outliers = (values < low).sum()
n_high_outliers = (high < values).sum()
result = {"n_low_outliers": n_low_outliers, "n_high_outliers": n_high_outliers}
result.update(n_low_outliers=n_low_outliers, n_high_outliers=n_high_outliers)
np_histogram_inliers = np_histogram_values[inlier_mask]
result["bin_counts"], result["bin_edges"] = np.histogram(
np_histogram_values[inlier_mask]
np_histogram_inliers, range=_get_safe_hist_range(np_histogram_inliers)
)
if ax is None:
return result
n, bins, patches = ax.hist(values[inlier_mask])
histogram_inliers = values[inlier_mask]
n, bins, patches = ax.hist(
histogram_inliers, range=_get_safe_hist_range(histogram_inliers)
)
n_out = n_low_outliers + n_high_outliers
if not n_out:
return result
Expand Down
5 changes: 5 additions & 0 deletions skrub/_reporting/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ def test_histogram():
data = pd.Series([0.0])
_, hist = _plotting.histogram(data)
assert (hist["n_low_outliers"], hist["n_high_outliers"]) == (0, 0)

low = np.float32(10.0)
high = np.nextafter(low, 11.0)
data = pd.Series([low, high])
_, hist = _plotting.histogram(data)
Loading