fix: Extend row hash to 128 bits to guarantee unique row id #632

Merged
merged 9 commits on Apr 26, 2024
50 changes: 35 additions & 15 deletions bigframes/session/__init__.py
@@ -54,6 +54,7 @@
 import google.api_core.gapic_v1.client_info
 import google.auth.credentials
 import google.cloud.bigquery as bigquery
+import google.cloud.bigquery.table
 import google.cloud.bigquery_connection_v1
 import google.cloud.bigquery_storage_v1
 import google.cloud.functions_v2
@@ -693,7 +694,7 @@ def read_gbq_table(

     def _get_snapshot_sql_and_primary_key(
         self,
-        table_ref: bigquery.table.TableReference,
+        table: google.cloud.bigquery.table.Table,
         *,
         api_name: str,
         use_cache: bool = True,
@@ -709,7 +710,7 @@ def _get_snapshot_sql_and_primary_key(
             table,
         ) = bigframes_io.get_snapshot_datetime_and_table_metadata(
             self.bqclient,
-            table_ref=table_ref,
+            table_ref=table.reference,
             api_name=api_name,
             cache=self._df_snapshot,
             use_cache=use_cache,
@@ -735,7 +736,7 @@ def _get_snapshot_sql_and_primary_key(

         try:
             table_expression = self.ibis_client.sql(
-                bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
+                bigframes_io.create_snapshot_sql(table.reference, snapshot_timestamp)
             )
         except google.api_core.exceptions.Forbidden as ex:
             if "Drive credentials" in ex.message:
@@ -763,8 +764,9 @@ def _read_gbq_table(
             query, default_project=self.bqclient.project
         )

+        table = self.bqclient.get_table(table_ref)
         (table_expression, primary_keys,) = self._get_snapshot_sql_and_primary_key(
-            table_ref, api_name=api_name, use_cache=use_cache
+            table, api_name=api_name, use_cache=use_cache
         )
         total_ordering_cols = primary_keys

@@ -836,9 +838,13 @@ def _read_gbq_table(
                     ordering=ordering,
                 )
             else:
-                array_value = self._create_total_ordering(table_expression)
+                array_value = self._create_total_ordering(
+                    table_expression, table_rows=table.num_rows
+                )
         else:
-            array_value = self._create_total_ordering(table_expression)
+            array_value = self._create_total_ordering(
+                table_expression, table_rows=table.num_rows
+            )

         value_columns = [col for col in array_value.column_ids if col not in index_cols]
         block = blocks.Block(
@@ -1459,10 +1465,19 @@ def _create_empty_temp_table(
     def _create_total_ordering(
         self,
         table: ibis_types.Table,
+        table_rows: Optional[int],
     ) -> core.ArrayValue:
         # Since this might also be used as the index, don't use the default
         # "ordering ID" name.
+
+        # For small tables, 64 bits is enough to avoid collisions, 128 bits will never ever collide no matter what
+        # Assume table is large if table row count is unknown
+        use_double_hash = (
+            (table_rows is None) or (table_rows == 0) or (table_rows > 100000)
+        )
+
         ordering_hash_part = guid.generate_guid("bigframes_ordering_")
+        ordering_hash_part2 = guid.generate_guid("bigframes_ordering_")
         ordering_rand_part = guid.generate_guid("bigframes_ordering_")

         # All inputs into hash must be non-null or resulting hash will be null
@@ -1475,25 +1490,30 @@
             else str_values[0]
         )
         full_row_hash = full_row_str.hash().name(ordering_hash_part)
+        # By modifying value slightly, we get another hash uncorrelated with the first
+        full_row_hash_p2 = (full_row_str + "_").hash().name(ordering_hash_part2)
         # Used to disambiguate between identical rows (which will have identical hash)
         random_value = ibis.random().name(ordering_rand_part)

+        order_values = (
+            [full_row_hash, full_row_hash_p2, random_value]
+            if use_double_hash
+            else [full_row_hash, random_value]
+        )
+
         original_column_ids = table.columns
         table_with_ordering = table.select(
-            itertools.chain(original_column_ids, [full_row_hash, random_value])
+            itertools.chain(original_column_ids, order_values)
         )

-        ordering_ref1 = order.ascending_over(ordering_hash_part)
-        ordering_ref2 = order.ascending_over(ordering_rand_part)
         ordering = order.ExpressionOrdering(
-            ordering_value_columns=(ordering_ref1, ordering_ref2),
-            total_ordering_columns=frozenset([ordering_hash_part, ordering_rand_part]),
+            ordering_value_columns=tuple(
+                order.ascending_over(col.get_name()) for col in order_values
+            ),
+            total_ordering_columns=frozenset(col.get_name() for col in order_values),
         )
         columns = [table_with_ordering[col] for col in original_column_ids]
-        hidden_columns = [
-            table_with_ordering[ordering_hash_part],
-            table_with_ordering[ordering_rand_part],
-        ]
+        hidden_columns = [table_with_ordering[col.get_name()] for col in order_values]
         return core.ArrayValue.from_ibis(
             self,
             table_with_ordering,
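A quick sanity check of the collision arithmetic behind the new use_double_hash cutoff; this is a standalone sketch, not code from the PR, and the 100,000-row threshold and 64/128-bit widths are taken from the comment in the diff above. By the birthday bound, the chance that any two of n row hashes collide in a k-bit space is roughly n^2 / 2^(k+1):

# Birthday-bound approximation of the probability that any two of n rows
# share the same k-bit hash: p ~= n * (n - 1) / 2 / 2**k.
def collision_probability(n_rows: int, hash_bits: int) -> float:
    return min(1.0, n_rows * (n_rows - 1) / 2 / 2**hash_bits)

print(collision_probability(100_000, 64))        # ~2.7e-10: one 64-bit hash is fine below the cutoff
print(collision_probability(1_000_000_000, 64))  # ~2.7e-02: a few percent at a billion rows
print(collision_probability(1_000_000_000, 128)) # ~1.5e-21: negligible once the second hash is added

So below the 100,000-row cutoff a single 64-bit hash leaves collision odds around 10^-10, while the combined 128 bits keep them negligible even at billions of rows, which is the trade-off the use_double_hash flag encodes.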
10 changes: 10 additions & 0 deletions tests/unit/resources.py
@@ -44,6 +44,12 @@ def create_bigquery_session(
         google.auth.credentials.Credentials, instance=True
     )

+    if anonymous_dataset is None:
+        anonymous_dataset = google.cloud.bigquery.DatasetReference(
+            "test-project",
+            "test_dataset",
+        )
+
     if bqclient is None:
         bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
         bqclient.project = "test-project"
@@ -53,6 +59,10 @@ def create_bigquery_session(
         table._properties = {}
         type(table).location = mock.PropertyMock(return_value="test-region")
         type(table).schema = mock.PropertyMock(return_value=table_schema)
+        type(table).reference = mock.PropertyMock(
+            return_value=anonymous_dataset.table("test_table")
+        )
+        type(table).num_rows = mock.PropertyMock(return_value=1000000000)
         bqclient.get_table.return_value = table

     if anonymous_dataset is None:
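One note on the PropertyMock pattern above, since it differs from the plain attribute assignment used elsewhere in the fixture: reference and num_rows are properties on google.cloud.bigquery.Table, and the documented way to stub a property is to attach a PropertyMock to the mock's type rather than to the instance. A small standalone illustration of that pattern, with illustrative values:

from unittest import mock

import google.cloud.bigquery

table = mock.create_autospec(google.cloud.bigquery.Table, instance=True)
# Properties are looked up on the class, so stub them via PropertyMock on type(table).
type(table).num_rows = mock.PropertyMock(return_value=1_000_000_000)
type(table).reference = mock.PropertyMock(
    return_value=google.cloud.bigquery.DatasetReference(
        "test-project", "test_dataset"
    ).table("test_table")
)

assert table.num_rows == 1_000_000_000
assert table.reference.table_id == "test_table"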
21 changes: 17 additions & 4 deletions tests/unit/session/test_session.py
@@ -49,6 +49,16 @@ def test_read_gbq_cached_table():
         table,
     )

+    def get_table_mock(table_ref):
+        table = google.cloud.bigquery.Table(
+            table_ref, (google.cloud.bigquery.SchemaField("col", "INTEGER"),)
+        )
+        table._properties["numRows"] = "1000000000"
+        table._properties["location"] = session._location
+        return table
+
+    session.bqclient.get_table = get_table_mock
+
     with pytest.warns(UserWarning, match=re.escape("use_cache=False")):
         df = session.read_gbq("my-project.my_dataset.my_table")

@@ -137,10 +147,13 @@ def query_mock(query, *args, **kwargs):

     session.bqclient.query = query_mock

-    def get_table_mock(dataset_ref):
-        dataset = google.cloud.bigquery.Dataset(dataset_ref)
-        dataset.location = session._location
-        return dataset
+    def get_table_mock(table_ref):
+        table = google.cloud.bigquery.Table(
+            table_ref, (google.cloud.bigquery.SchemaField("col", "INTEGER"),)
+        )
+        table._properties["numRows"] = 1000000000
+        table._properties["location"] = session._location
+        return table

     session.bqclient.get_table = get_table_mock

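For context on the get_table mocks above: building a real google.cloud.bigquery.Table and seeding its _properties dict is enough to make the num_rows property return an integer, which is what _read_gbq_table now forwards as table_rows. A minimal sketch using only the public google-cloud-bigquery API; the table name is illustrative:

import google.cloud.bigquery

ref = google.cloud.bigquery.TableReference.from_string("my-project.my_dataset.my_table")
table = google.cloud.bigquery.Table(
    ref, (google.cloud.bigquery.SchemaField("col", "INTEGER"),)
)
# num_rows is derived from the "numRows" entry; the REST API reports it as a string,
# and the property converts it to an int, which is why both the string form and the
# int form used in the two tests behave the same.
table._properties["numRows"] = "1000000000"
assert table.num_rows == 1_000_000_000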