Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

This project adheres to [Semantic Versioning](http://semver.org/).

### Unreleased (or v0.18.1)
* Added optional `seed` parameter to `sample_proportions()` for reproducible results
* Added optional `seed` parameter to `proportions_from_distribution()` for reproducible results
* Fixed bug where `proportions_from_distribution()` ignored the `column_name` parameter

### v0.18.0
* Improved color contrast for charts (gold, blue)
* Fixed make_array() so it doesn't auto-convert booleans to integers
Expand Down
14 changes: 9 additions & 5 deletions datascience/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def plot_normal_cdf(rbound=None, lbound=None, mean=0, sd=1):
plot_cdf_area = plot_normal_cdf


def sample_proportions(sample_size: int, probabilities):
def sample_proportions(sample_size: int, probabilities, seed=None):
"""Return the proportion of random draws for each outcome in a distribution.

This function is similar to np.random.Generator.multinomial, but returns proportions
Expand All @@ -137,15 +137,17 @@ def sample_proportions(sample_size: int, probabilities):

``probabilities``: An array of probabilities that forms a distribution.

``seed``: Optional seed passed to ``np.random.default_rng`` for reproducibility. If None, the generator is freshly seeded and results vary between calls.

Returns:
An array with the same length as ``probabilities`` that sums to 1.
"""
rng = np.random.default_rng()
rng = np.random.default_rng(seed)
return rng.multinomial(sample_size, probabilities) / sample_size


def proportions_from_distribution(table, label, sample_size,
column_name='Random Sample'):
column_name='Random Sample', seed=None):
"""
Adds a column named ``column_name`` containing the proportions of a random
draw using the distribution in ``label``.
Expand All @@ -165,6 +167,8 @@ def proportions_from_distribution(table, label, sample_size,
``column_name``: The name of the new column that contains the sampled
proportions. Defaults to ``'Random Sample'``.

``seed``: Optional seed forwarded to ``sample_proportions`` for reproducibility. If None, results vary between calls.

Returns:
A copy of ``table`` with a column ``column_name`` containing the
sampled proportions. The proportions will sum to 1.
Expand All @@ -173,8 +177,8 @@ def proportions_from_distribution(table, label, sample_size,
``ValueError``: If the ``label`` is not in the table, or if
``table.column(label)`` does not sum to 1.
"""
proportions = sample_proportions(sample_size, table.column(label))
return table.with_column('Random Sample', proportions)
proportions = sample_proportions(sample_size, table.column(label), seed)
return table.with_column(column_name, proportions)


def table_apply(table, func, subset=None):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ def test_proportions_from_distribution():
assert [x in (0, 0.5, 1) for x in ds.sample_proportions(2, ds.make_array(.2, .3, .5))]


def test_sample_proportions_seed():
    """Check that ``seed`` makes ``sample_proportions`` reproducible.

    Covers three cases: equal seeds give identical output, different seeds
    give different output, and omitting ``seed`` (the pre-existing call
    form) still returns one proportion per outcome, summing to 1.
    """
    probabilities = [0.5, 0.5]

    # Same seed -> identical draws.
    first = ds.sample_proportions(1000, probabilities, seed=42)
    repeat = ds.sample_proportions(1000, probabilities, seed=42)
    assert np.array_equal(first, repeat)

    # Different seed -> different draws. Both calls are seeded, so the
    # outcome of this comparison does not vary from run to run.
    other = ds.sample_proportions(1000, probabilities, seed=99)
    assert not np.array_equal(first, other)

    # Backward compatibility: ``seed`` stays optional for existing callers.
    unseeded = ds.sample_proportions(1000, probabilities)
    assert len(unseeded) == len(probabilities)
    assert np.isclose(sum(unseeded), 1)


def test_proportions_from_distribution_seed_and_column_name():
    """Verify seeded reproducibility and that ``column_name`` is honored."""
    probs_table = ds.Table().with_column('probs', [0.6, 0.4])

    # Two calls with the same seed must produce the same sampled column,
    # and that column must be a set of proportions summing to 1.
    seeded_first = ds.proportions_from_distribution(
        probs_table, 'probs', 1000, seed=42)
    seeded_second = ds.proportions_from_distribution(
        probs_table, 'probs', 1000, seed=42)
    assert np.array_equal(seeded_first.column(1), seeded_second.column(1))
    assert _round_eq(1, sum(seeded_first.column(1)))

    # A custom column name must appear in the result: the original column
    # plus exactly one new, correctly named column.
    renamed = ds.proportions_from_distribution(
        probs_table, 'probs', 1000, column_name='My Sample')
    assert 'My Sample' in renamed.labels
    assert renamed.num_columns == 2


def test_is_non_string_iterable():
is_string = 'hello'
assert ds.is_non_string_iterable(is_string) == False
Expand Down
Loading