Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions packages/bigframes/bigframes/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@
import google.cloud.bigquery._job_helpers
import google.cloud.bigquery.job.query
import google.cloud.bigquery.table
from google.cloud.bigquery.job.query import QueryPlanEntry

import bigframes.session.executor

_FALLBACK_TO_GLOBAL = "fallback_to_global"


class Subscriber:
def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher):
def __init__(
self, callback: Callable[[Event], None], *, publisher: Publisher
): # noqa: E501
self._publisher = publisher
self._callback = callback
self._subscriber_id = uuid.uuid4()
Expand Down Expand Up @@ -125,15 +130,21 @@ class BigQuerySentEvent(ExecutionRunning):
location: Optional[str] = None
job_id: Optional[str] = None
request_id: Optional[str] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(
    cls,
    event: google.cloud.bigquery._job_helpers.QuerySentEvent,
    progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
    """Build a ``BigQuerySentEvent`` from the BigQuery client's event.

    Args:
        event: The ``QuerySentEvent`` emitted by the BigQuery client
            library when a query request is sent.
        progress_bar: Progress bar mode captured at query-start time.
            Defaults to the sentinel value that tells the renderer to
            fall back to the global display option.
    """
    # Diff artifact removed: the span previously contained both the old
    # single-line and new multi-line signatures; only the new one is kept.
    return cls(
        query=event.query,
        billing_project=event.billing_project,
        location=event.location,
        job_id=event.job_id,
        request_id=event.request_id,
        progress_bar=progress_bar,
    )


Expand All @@ -146,15 +157,21 @@ class BigQueryRetryEvent(ExecutionRunning):
location: Optional[str] = None
job_id: Optional[str] = None
request_id: Optional[str] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(
    cls,
    event: google.cloud.bigquery._job_helpers.QueryRetryEvent,
    progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
    """Build a ``BigQueryRetryEvent`` from the BigQuery client's event.

    Args:
        event: The ``QueryRetryEvent`` emitted by the BigQuery client
            library when a query request is retried.
        progress_bar: Progress bar mode captured at query-start time.
            Defaults to the sentinel value that tells the renderer to
            fall back to the global display option.
    """
    # Diff artifact removed: the span previously contained both the old
    # single-line and new multi-line signatures; only the new one is kept.
    return cls(
        query=event.query,
        billing_project=event.billing_project,
        location=event.location,
        job_id=event.job_id,
        request_id=event.request_id,
        progress_bar=progress_bar,
    )


Expand All @@ -167,14 +184,17 @@ class BigQueryReceivedEvent(ExecutionRunning):
job_id: Optional[str] = None
statement_type: Optional[str] = None
state: Optional[str] = None
query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]] = None
query_plan: Optional[list[QueryPlanEntry]] = None
created: Optional[datetime.datetime] = None
started: Optional[datetime.datetime] = None
ended: Optional[datetime.datetime] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(
cls, event: google.cloud.bigquery._job_helpers.QueryReceivedEvent
cls,
event: google.cloud.bigquery._job_helpers.QueryReceivedEvent,
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
return cls(
billing_project=event.billing_project,
Expand All @@ -186,6 +206,7 @@ def from_bqclient(
created=event.created,
started=event.started,
ended=event.ended,
progress_bar=progress_bar,
)


Expand All @@ -204,10 +225,13 @@ class BigQueryFinishedEvent(ExecutionRunning):
created: Optional[datetime.datetime] = None
started: Optional[datetime.datetime] = None
ended: Optional[datetime.datetime] = None
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL

@classmethod
def from_bqclient(
cls, event: google.cloud.bigquery._job_helpers.QueryFinishedEvent
cls,
event: google.cloud.bigquery._job_helpers.QueryFinishedEvent,
progress_bar: Optional[str] = _FALLBACK_TO_GLOBAL,
):
return cls(
billing_project=event.billing_project,
Expand All @@ -221,6 +245,7 @@ def from_bqclient(
created=event.created,
started=event.started,
ended=event.ended,
progress_bar=progress_bar,
)


Expand Down
7 changes: 6 additions & 1 deletion packages/bigframes/bigframes/formatting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ def progress_callback(
# This will allow cleanup to continue.
return

progress_bar = bigframes._config.options.display.progress_bar
# Prioritize progress_bar set on the event, falling back to thread-local option.
progress_bar = getattr(
event, "progress_bar", bigframes.core.events._FALLBACK_TO_GLOBAL
)
if progress_bar == bigframes.core.events._FALLBACK_TO_GLOBAL:
progress_bar = bigframes._config.options.display.progress_bar

if progress_bar == "auto":
progress_bar = "notebook" if in_ipython() else "terminal"
Expand Down
39 changes: 24 additions & 15 deletions packages/bigframes/bigframes/session/_io/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,18 +245,23 @@ def add_and_trim_labels(job_config, session=None):


def create_bq_event_callback(publisher):
    """Return a callback that republishes BigQuery client events as bigframes events.

    The ``display.progress_bar`` option is read once, when the callback is
    created (i.e. at query-start time). Events published later therefore
    render with the option value that was in effect when the query began,
    rather than whatever the (possibly thread-local) option holds when the
    client event eventually arrives.

    Args:
        publisher: The bigframes event publisher that receives the
            translated events.

    Returns:
        A one-argument callable suitable for the BigQuery client's
        query-event callback hook.
    """
    # Imported lazily to avoid an import cycle at module load time.
    import bigframes._config

    progress_bar = bigframes._config.options.display.progress_bar

    def publish_bq_event(event):
        # Explicit if/elif dispatch on the concrete client event type.
        # This avoids iterating a type->class table with isinstance(),
        # which is insertion-order-sensitive and needed a `type: ignore`,
        # and it only constructs the fallback Unknown event when no
        # specific type matches.
        if isinstance(event, google.cloud.bigquery._job_helpers.QueryFinishedEvent):
            bf_event = bigframes.core.events.BigQueryFinishedEvent.from_bqclient(
                event, progress_bar=progress_bar
            )
        elif isinstance(event, google.cloud.bigquery._job_helpers.QueryReceivedEvent):
            bf_event = bigframes.core.events.BigQueryReceivedEvent.from_bqclient(
                event, progress_bar=progress_bar
            )
        elif isinstance(event, google.cloud.bigquery._job_helpers.QueryRetryEvent):
            bf_event = bigframes.core.events.BigQueryRetryEvent.from_bqclient(
                event, progress_bar=progress_bar
            )
        elif isinstance(event, google.cloud.bigquery._job_helpers.QuerySentEvent):
            bf_event = bigframes.core.events.BigQuerySentEvent.from_bqclient(
                event, progress_bar=progress_bar
            )
        else:
            bf_event = bigframes.core.events.BigQueryUnknownEvent(event)
        publisher.publish(bf_event)

    return publish_bq_event
Expand All @@ -275,7 +280,8 @@ def start_query_with_client(
query_with_job: Literal[True],
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
...


@overload
Expand All @@ -291,7 +297,8 @@ def start_query_with_client(
query_with_job: Literal[False],
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
...


@overload
Expand All @@ -308,7 +315,8 @@ def start_query_with_client(
job_retry: google.api_core.retry.Retry,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
...


@overload
Expand All @@ -325,7 +333,8 @@ def start_query_with_client(
job_retry: google.api_core.retry.Retry,
publisher: bigframes.core.events.Publisher,
session=None,
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ...
) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]:
...


def start_query_with_client(
Expand Down
17 changes: 17 additions & 0 deletions packages/bigframes/tests/system/small/test_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,23 @@ def test_progress_bar_load_jobs(
assert_loading_msg_exist(capsys.readouterr().out, pattern="Load")


def test_progress_bar_uniqueness_check(session: bf.Session, capsys):
    """The index-uniqueness check should surface a progress-bar message."""
    # The uniqueness check only runs under strict ordering, which is the
    # session default.
    assert session._strictly_ordered

    # Discard anything already captured on stdout/stderr.
    capsys.readouterr()

    with bf.option_context("display.progress_bar", "terminal"):
        # Reading a public table with a non-unique index_col triggers the
        # uniqueness check against the real service.
        session.read_gbq_table(
            "bigquery-public-data.ml_datasets.penguins",
            index_col="island",
        )

    output = capsys.readouterr().out
    assert_loading_msg_exist(output)


def assert_loading_msg_exist(capstdout: str, pattern=job_load_message_regex):
num_loading_msg = 0
lines = capstdout.split("\n")
Expand Down
25 changes: 25 additions & 0 deletions packages/bigframes/tests/unit/test_formatting_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,28 @@ def test_get_job_url():
job_id=job_id, location=location, project_id=project_id
)
assert actual_url == expected_url


def test_progress_callback_respects_event_progress_bar():
    """An event that disables its own progress bar must not print, even
    when the global option requests a terminal progress bar."""
    event = bfevents.BigQuerySentEvent(
        query="SELECT * FROM my_table",
        progress_bar=None,
    )

    with mock.patch(
        "bigframes._config.options.display.progress_bar", "terminal"
    ), mock.patch(
        "bigframes.formatting_helpers.in_ipython", return_value=False
    ), mock.patch("builtins.print") as mock_print:
        formatting_helpers.progress_callback(event)
        mock_print.assert_not_called()


def test_progress_callback_falls_back_to_global():
    """Without a per-event setting, the global display option decides
    whether the progress bar prints."""
    event = bfevents.BigQuerySentEvent(
        query="SELECT * FROM my_table",
    )

    with mock.patch(
        "bigframes._config.options.display.progress_bar", "terminal"
    ), mock.patch(
        "bigframes.formatting_helpers.in_ipython", return_value=False
    ), mock.patch("builtins.print") as mock_print:
        formatting_helpers.progress_callback(event)
        mock_print.assert_called_once()
Loading