Skip to content

Commit b6cd55a

Browse files
authored
fix: make invalid location warning case-insensitive (#1044)
* fix: make invalid location warning case-insensitive * fix failing unit test * add system tests for non canonical location setting
1 parent d204603 commit b6cd55a

File tree

4 files changed

+73
-32
lines changed

4 files changed

+73
-32
lines changed

‎bigframes/_config/bigquery_options.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,36 @@
3636
UNKNOWN_LOCATION_MESSAGE = "The location '{location}' is set to an unknown value. Did you mean '{possibility}'?"
3737

3838

39-
def _validate_location(value: Optional[str]):
40-
41-
if value is None:
42-
return
43-
44-
if value not in bigframes.constants.ALL_BIGQUERY_LOCATIONS:
45-
location = str(value)
46-
possibility = min(
47-
bigframes.constants.ALL_BIGQUERY_LOCATIONS,
48-
key=lambda item: jellyfish.levenshtein_distance(location, item),
49-
)
50-
warnings.warn(
51-
UNKNOWN_LOCATION_MESSAGE.format(location=location, possibility=possibility),
52-
# There are many layers before we get to (possibly) the user's code:
53-
# -> bpd.options.bigquery.location = "us-central-1"
54-
# -> location.setter
55-
# -> _validate_location
56-
stacklevel=3,
57-
category=bigframes.exceptions.UnknownLocationWarning,
58-
)
39+
def _get_validated_location(value: Optional[str]) -> Optional[str]:
40+
41+
if value is None or value in bigframes.constants.ALL_BIGQUERY_LOCATIONS:
42+
return value
43+
44+
location = str(value)
45+
46+
location_lowercase = location.lower()
47+
if location_lowercase in bigframes.constants.BIGQUERY_REGIONS:
48+
return location_lowercase
49+
50+
location_uppercase = location.upper()
51+
if location_uppercase in bigframes.constants.BIGQUERY_MULTIREGIONS:
52+
return location_uppercase
53+
54+
possibility = min(
55+
bigframes.constants.ALL_BIGQUERY_LOCATIONS,
56+
key=lambda item: jellyfish.levenshtein_distance(location, item),
57+
)
58+
warnings.warn(
59+
UNKNOWN_LOCATION_MESSAGE.format(location=location, possibility=possibility),
60+
# There are many layers before we get to (possibly) the user's code:
61+
# -> bpd.options.bigquery.location = "us-central-1"
62+
# -> location.setter
63+
# -> _get_validated_location
64+
stacklevel=3,
65+
category=bigframes.exceptions.UnknownLocationWarning,
66+
)
67+
68+
return value
5969

6070

6171
def _validate_ordering_mode(value: str) -> bigframes.enums.OrderingMode:
@@ -135,8 +145,7 @@ def location(self) -> Optional[str]:
135145
def location(self, value: Optional[str]):
136146
if self._session_started and self._location != value:
137147
raise ValueError(SESSION_STARTED_MESSAGE.format(attribute="location"))
138-
_validate_location(value)
139-
self._location = value
148+
self._location = _get_validated_location(value)
140149

141150
@property
142151
def project(self) -> Optional[str]:

‎bigframes/constants.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
DEFAULT_EXPIRATION = datetime.timedelta(days=7)
2323

2424
# https://cloud.google.com/bigquery/docs/locations
25-
ALL_BIGQUERY_LOCATIONS = frozenset(
25+
BIGQUERY_REGIONS = frozenset(
2626
{
27-
# regions
2827
"us-east5",
2928
"us-south1",
3029
"us-central1",
@@ -68,11 +67,15 @@
6867
"me-central1",
6968
"me-west1",
7069
"africa-south1",
71-
# multi-regions
70+
}
71+
)
72+
BIGQUERY_MULTIREGIONS = frozenset(
73+
{
7274
"US",
7375
"EU",
7476
}
7577
)
78+
ALL_BIGQUERY_LOCATIONS = frozenset(BIGQUERY_REGIONS.union(BIGQUERY_MULTIREGIONS))
7679

7780
# https://cloud.google.com/storage/docs/regional-endpoints
7881
REP_ENABLED_BIGQUERY_LOCATIONS = frozenset(

‎tests/system/large/test_location.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import bigframes.session.clients
2323

2424

25-
def _assert_bq_execution_location(session: bigframes.Session):
25+
def _assert_bq_execution_location(
26+
session: bigframes.Session, expected_location: typing.Optional[str] = None
27+
):
2628
df = session.read_gbq(
2729
"""
2830
SELECT "aaa" as name, 111 as number
@@ -33,10 +35,10 @@ def _assert_bq_execution_location(session: bigframes.Session):
3335
"""
3436
)
3537

36-
assert (
37-
typing.cast(bigquery.QueryJob, df.query_job).location
38-
== session.bqclient.location
39-
)
38+
if expected_location is None:
39+
expected_location = session._location
40+
41+
assert typing.cast(bigquery.QueryJob, df.query_job).location == expected_location
4042

4143
result = (
4244
df[["name", "number"]]
@@ -47,8 +49,7 @@ def _assert_bq_execution_location(session: bigframes.Session):
4749
)
4850

4951
assert (
50-
typing.cast(bigquery.QueryJob, result.query_job).location
51-
== session.bqclient.location
52+
typing.cast(bigquery.QueryJob, result.query_job).location == expected_location
5253
)
5354

5455

@@ -87,6 +88,30 @@ def test_bq_location(bigquery_location):
8788
_assert_bq_execution_location(session)
8889

8990

91+
@pytest.mark.parametrize(
92+
("set_location", "resolved_location"),
93+
# Sort the set to avoid nondeterminism.
94+
[
95+
(loc.capitalize(), loc)
96+
for loc in sorted(bigframes.constants.ALL_BIGQUERY_LOCATIONS)
97+
],
98+
)
99+
def test_bq_location_non_canonical(set_location, resolved_location):
100+
session = bigframes.Session(
101+
context=bigframes.BigQueryOptions(location=set_location)
102+
)
103+
104+
assert session.bqclient.location == set_location
105+
106+
# by default global endpoint is used
107+
assert (
108+
session.bqclient._connection.API_BASE_URL == "https://bigquery.googleapis.com"
109+
)
110+
111+
# assert that bigframes session honors the location
112+
_assert_bq_execution_location(session, resolved_location)
113+
114+
90115
@pytest.mark.parametrize(
91116
"bigquery_location",
92117
# Sort the set to avoid nondeterminism.

‎tests/unit/_config/test_bigquery_options.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ def test_setter_if_session_started_but_setting_the_same_value(attribute):
9090
[
9191
(None,),
9292
("us-central1",),
93+
("us-Central1",),
94+
("US-CENTRAL1",),
95+
("US",),
96+
("us",),
9397
],
9498
)
9599
def test_location_set_to_valid_no_warning(valid_location):

0 commit comments

Comments
 (0)