Skip to content

fix: Properly identify non-unique index in non-pk tables #699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 17, 2024
Merged
2 changes: 1 addition & 1 deletion bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def to_sql(
output_columns = [
col_id_overrides.get(col, col) for col in baked_ir.column_ids
]
sql = bigframes.core.sql.select_from(output_columns, sql)
sql = bigframes.core.sql.select_from_subquery(output_columns, sql)

# Single row frames may not have any ordering columns
if len(baked_ir._ordering.all_ordering_columns) > 0:
Expand Down
19 changes: 15 additions & 4 deletions bigframes/core/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def infix_op(opname: str, left_arg: str, right_arg: str):


### Writing SELECT expressions
def select_from(columns: Iterable[str], subquery: str, distinct: bool = False):
def select_from_subquery(columns: Iterable[str], subquery: str, distinct: bool = False):
selection = ", ".join(map(identifier, columns))
distinct_clause = "DISTINCT " if distinct else ""

Expand All @@ -120,16 +120,27 @@ def select_from(columns: Iterable[str], subquery: str, distinct: bool = False):
)


def select_from_table_ref(
    columns: Iterable[str], table_ref: bigquery.TableReference, distinct: bool = False
):
    """Build a SELECT statement that reads *columns* from a BigQuery table.

    Args:
        columns: column names to project; each is escaped via ``identifier``.
        table_ref: fully-qualified table to read from.
        distinct: when True, emit ``SELECT DISTINCT`` instead of ``SELECT``.

    Returns:
        The SQL text of the SELECT statement.
    """
    projected = ", ".join(identifier(col) for col in columns)
    select_keyword = "SELECT DISTINCT" if distinct else "SELECT"

    return textwrap.dedent(
        f"{select_keyword} {projected}\nFROM {table_reference(table_ref)}"
    )


def select_table(table_ref: bigquery.TableReference):
    """Build a ``SELECT *`` statement over the given table reference."""
    sql = f"SELECT * FROM {table_reference(table_ref)}"
    return textwrap.dedent(sql)


def is_distinct_sql(columns: Iterable[str], table_sql: str) -> str:
def is_distinct_sql(columns: Iterable[str], table_ref: bigquery.TableReference) -> str:
is_unique_sql = f"""WITH full_table AS (
{select_from(columns, table_sql)}
{select_from_table_ref(columns, table_ref)}
),
distinct_table AS (
{select_from(columns, table_sql, distinct=True)}
{select_from_table_ref(columns, table_ref, distinct=True)}
)

SELECT (SELECT COUNT(*) FROM full_table) AS `total_count`,
Expand Down
1 change: 0 additions & 1 deletion bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,6 @@ def _read_gbq_table(
# check.
is_index_unique = bf_read_gbq_table.are_index_cols_unique(
bqclient=self.bqclient,
ibis_client=self.ibis_client,
table=table,
index_cols=index_cols,
api_name=api_name,
Expand Down
8 changes: 4 additions & 4 deletions bigframes/session/_io/bigquery/read_gbq_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,21 @@ def get_ibis_time_travel_table(

def are_index_cols_unique(
bqclient: bigquery.Client,
ibis_client: ibis.BaseBackend,
table: bigquery.table.Table,
index_cols: List[str],
api_name: str,
) -> bool:
if len(index_cols) == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add some unit tests specifically for this method with index_cols = [] and primary_keys = []; index_cols = ['some_col']

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some unit tests

return False
# If index_cols contain the primary_keys, the query engine assumes they
# provide a unique index.
primary_keys = frozenset(_get_primary_keys(table))
if primary_keys <= frozenset(index_cols):
if (len(primary_keys) > 0) and primary_keys <= frozenset(index_cols):
return True

# TODO(b/337925142): Avoid a "SELECT *" subquery here by ensuring
# table_expression only selects just index_cols.
table_sql = ibis_client.compile(table)
is_unique_sql = bigframes.core.sql.is_distinct_sql(index_cols, table_sql)
is_unique_sql = bigframes.core.sql.is_distinct_sql(index_cols, table.reference)
job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
results = bqclient.query_and_wait(is_unique_sql, job_config=job_config)
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/session/test_read_gbq_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
"""Unit tests for read_gbq_table helper functions."""

import datetime
import unittest.mock as mock

import google.cloud.bigquery
import google.cloud.bigquery as bigquery
import pytest

import bigframes.session._io.bigquery.read_gbq_table as bf_read_gbq_table

Expand Down Expand Up @@ -45,3 +48,71 @@ def test_get_ibis_time_travel_table_doesnt_timetravel_anonymous_datasets():

# Need fully-qualified table name.
assert "my-test-project" in sql


@pytest.mark.parametrize(
    ("index_cols", "primary_keys", "values_distinct", "expected"),
    (
        (["col1", "col2"], ["col1", "col2", "col3"], False, False),
        (["col1", "col2", "col3"], ["col1", "col2", "col3"], True, True),
        (["col2", "col3", "col1"], ["col3", "col2"], True, True),
        (["col1", "col2"], [], False, False),
        ([], ["col1", "col2", "col3"], False, False),
        ([], [], False, False),
    ),
)
def test_are_index_cols_unique(index_cols, primary_keys, values_distinct, expected):
    """Uniqueness detection for index columns under various primary-key setups.

    Covers index columns that are a subset or superset of the primary key,
    an empty primary key, and empty index columns.

    See internal issue 335727141.
    """
    table_resource = {
        "tableReference": {
            "projectId": "my-project",
            "datasetId": "my_dataset",
            "tableId": "my_table",
        },
        "clustering": {"fields": ["col1", "col2"]},
    }
    table = google.cloud.bigquery.Table.from_api_repr(table_resource)
    table.schema = (
        google.cloud.bigquery.SchemaField("col1", "INT64"),
        google.cloud.bigquery.SchemaField("col2", "INT64"),
        google.cloud.bigquery.SchemaField("col3", "INT64"),
        google.cloud.bigquery.SchemaField("col4", "INT64"),
    )

    # TODO(b/305264153): use setter for table_constraints in client library
    # when available.
    table._properties["tableConstraints"] = {
        "primaryKey": {"columns": primary_keys},
    }

    bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
    bqclient.project = "test-project"
    bqclient.get_table.return_value = table

    # Single-row query result: counts match only when values are distinct.
    distinct_count = 3 if values_distinct else 2
    bqclient.query_and_wait.return_value = (
        {"total_count": 3, "distinct_count": distinct_count},
    )

    session = resources.create_bigquery_session(
        bqclient=bqclient, table_schema=table.schema
    )
    table._properties["location"] = session._location

    result = bf_read_gbq_table.are_index_cols_unique(bqclient, table, index_cols, "")

    assert result == expected
5 changes: 5 additions & 0 deletions tests/unit/session/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def get_table_mock(table_ref):
return table

session.bqclient.get_table = get_table_mock
session.bqclient.query_and_wait.return_value = (
{"total_count": 3, "distinct_count": 2},
)

with pytest.warns(UserWarning, match=re.escape("use_cache=False")):
df = session.read_gbq("my-project.my_dataset.my_table")
Expand All @@ -200,6 +203,7 @@ def test_default_index_warning_raised_by_read_gbq(table):
bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
bqclient.project = "test-project"
bqclient.get_table.return_value = table
bqclient.query_and_wait.return_value = ({"total_count": 3, "distinct_count": 2},)
session = resources.create_bigquery_session(bqclient=bqclient)
table._properties["location"] = session._location

Expand All @@ -222,6 +226,7 @@ def test_default_index_warning_not_raised_by_read_gbq_index_col_sequential_int64
bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
bqclient.project = "test-project"
bqclient.get_table.return_value = table
bqclient.query_and_wait.return_value = ({"total_count": 4, "distinct_count": 3},)
session = resources.create_bigquery_session(bqclient=bqclient)
table._properties["location"] = session._location

Expand Down