Skip to content

Commit fca9f9e

Browse files
committed
feat: add index get_loc API
1 parent d5c54fc commit fca9f9e

File tree

3 files changed

+229
-0
lines changed

3 files changed

+229
-0
lines changed

‎bigframes/core/indexes/base.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,21 @@
2727
import pandas
2828

2929
from bigframes import dtypes
30+
from bigframes.core.array_value import ArrayValue
3031
import bigframes.core.block_transforms as block_ops
3132
import bigframes.core.blocks as blocks
3233
import bigframes.core.expression as ex
34+
import bigframes.core.identifiers as ids
35+
import bigframes.core.nodes as nodes
3336
import bigframes.core.ordering as order
3437
import bigframes.core.utils as utils
3538
import bigframes.core.validations as validations
39+
import bigframes.core.window_spec as window_spec
3640
import bigframes.dtypes
3741
import bigframes.formatting_helpers as formatter
3842
import bigframes.operations as ops
3943
import bigframes.operations.aggregations as agg_ops
44+
import bigframes.series
4045

4146
if typing.TYPE_CHECKING:
4247
import bigframes.dataframe
@@ -247,6 +252,95 @@ def query_job(self) -> bigquery.QueryJob:
247252
self._query_job = query_job
248253
return self._query_job
249254

255+
def get_loc(self, key) -> typing.Union[int, slice, "bigframes.series.Series"]:
    """Get integer location, slice or boolean mask for requested label.

    Args:
        key: The label to search for in the index.

    Returns:
        An ``int`` position when the label occurs exactly once, a ``slice``
        covering the matching rows when the label occurs several times in a
        monotonic index, and otherwise a ``bigframes.series.Series`` mask
        built from the ``index == key`` comparison column.

    Raises:
        NotImplementedError: If the index has more than one level.
        KeyError: If the key is not found in the index.
    """
    if self.nlevels != 1:
        raise NotImplementedError("get_loc only supports single-level indexes")

    block = self._block
    index_column = block.index_columns[0]

    # Attach a row-number column to the original data so that positions
    # survive the filtering done below.
    row_num_col_id = ids.ColumnId.unique()
    window_node = nodes.WindowOpNode(
        child=block._expr.node,
        expression=ex.NullaryAggregation(agg_ops.RowNumberOp()),
        window_spec=window_spec.unbound(),
        output_name=row_num_col_id,
        never_skip_nulls=True,
    )
    windowed_block = type(block)(
        ArrayValue(window_node),
        index_columns=block.index_columns,
        # The extra row-number value column needs a (blank) label of its own.
        column_labels=block.column_labels.insert(len(block.column_labels), None),
        index_labels=block._index_labels,
    )

    # Flag the rows whose index value equals the key, then keep only those.
    match_expr = ops.eq_op.as_expr(ex.deref(index_column), ex.const(key))
    windowed_block, match_col_id = windowed_block.project_expr(match_expr)
    filtered_block = windowed_block.filter_by_id(match_col_id)

    def _aggregate_positions(op, label):
        # Run one aggregation over the row numbers of the matching rows and
        # return the resulting scalar. Each call executes a query.
        aggregation = ex.UnaryAggregation(op, ex.deref(row_num_col_id.name))
        result = filtered_block._expr.aggregate([(aggregation, label)])
        return block.session._executor.execute(result).to_py_scalar()

    match_count = _aggregate_positions(agg_ops.count_op, "count")
    if match_count == 0:
        # repr() keeps the key's type visible: 'd' for a string, 4 for an int.
        raise KeyError(f"{key!r} is not in index")

    if match_count > 1 and not (
        self.is_monotonic_increasing or self.is_monotonic_decreasing
    ):
        # Duplicates scattered through a non-monotonic index cannot be
        # described by a slice, so hand back the comparison column as a mask.
        mask_block = windowed_block.select_columns([match_col_id])
        return bigframes.series.Series(mask_block)

    start = int(_aggregate_positions(agg_ops.min_op, "position"))
    if match_count == 1:
        return start

    # In a monotonic index equal labels are contiguous, so the last match sits
    # at start + count - 1; no separate max query is required.
    return slice(start, start + int(match_count), None)
343+
250344
def __repr__(self) -> str:
251345
# Protect against errors with uninitialized Series. See:
252346
# https://github.com/googleapis/python-bigquery-dataframes/issues/728

‎tests/system/small/test_index.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,106 @@ def test_index_construct_from_list():
3232
pd.testing.assert_index_equal(bf_result, pd_result)
3333

3434

35+
@pytest.mark.parametrize("key, expected_loc", [("a", 0), ("b", 1), ("c", 2)])
def test_get_loc_should_return_int_for_unique_index(key, expected_loc):
    """Behavior: get_loc on a unique index returns an integer position."""
    labels = ["a", "b", "c"]
    bf_index = bpd.Index(labels)
    # pandas is the reference implementation: cross-check the hard-coded
    # expectation against it so the parametrized table cannot silently drift.
    pd_result = pd.Index(labels).get_loc(key)

    result = bf_index.get_loc(key)

    assert result == expected_loc
    assert result == pd_result
    assert isinstance(result, int)
46+
47+
48+
def test_get_loc_should_return_slice_for_monotonic_duplicates():
    """Behavior: get_loc on a monotonic string index with duplicates returns a slice."""
    labels = ["a", "b", "b", "c"]
    expected = pd.Index(labels).get_loc("b")

    actual = bpd.Index(labels).get_loc("b")

    assert isinstance(actual, slice)
    # pandas yields slice(1, 3, None) for this input.
    assert actual == expected
58+
59+
60+
def test_get_loc_should_return_slice_for_monotonic_numeric_duplicates():
    """Behavior: get_loc on a monotonic numeric index with duplicates returns a slice."""
    values = [1, 2, 2, 3]
    expected = pd.Index(values).get_loc(2)

    actual = bpd.Index(values).get_loc(2)

    assert isinstance(actual, slice)
    # pandas yields slice(1, 3, None) for this input.
    assert actual == expected
70+
71+
72+
def test_get_loc_should_return_mask_for_non_monotonic_duplicates():
    """Behavior: get_loc on a non-monotonic string index returns a boolean array."""
    labels = ["a", "b", "c", "b"]

    bf_mask = bpd.Index(labels).get_loc("b")
    # Materialize the result locally, preferring a direct numpy conversion
    # when the returned object offers one.
    bf_array = (
        bf_mask.to_numpy()
        if hasattr(bf_mask, "to_numpy")
        else bf_mask.to_pandas().to_numpy()
    )
    expected_mask = pd.Index(labels).get_loc("b")

    numpy.testing.assert_array_equal(bf_array, expected_mask)
85+
86+
87+
def test_get_loc_should_return_mask_for_non_monotonic_numeric_duplicates():
    """Behavior: get_loc on a non-monotonic numeric index returns a boolean array."""
    values = [1, 2, 3, 2]

    bf_mask = bpd.Index(values).get_loc(2)
    # Materialize the result locally, preferring a direct numpy conversion
    # when the returned object offers one.
    bf_array = (
        bf_mask.to_numpy()
        if hasattr(bf_mask, "to_numpy")
        else bf_mask.to_pandas().to_numpy()
    )
    expected_mask = pd.Index(values).get_loc(2)

    numpy.testing.assert_array_equal(bf_array, expected_mask)
100+
101+
102+
def test_get_loc_should_raise_error_for_missing_key():
    """Behavior: get_loc raises KeyError when a string key is not found."""
    index_without_d = bpd.Index(["a", "b", "c"])

    # "d" is absent from the index, so the lookup must fail.
    with pytest.raises(KeyError):
        index_without_d.get_loc("d")
108+
109+
110+
def test_get_loc_should_raise_error_for_missing_numeric_key():
    """Behavior: get_loc raises KeyError when a numeric key is not found."""
    index_without_4 = bpd.Index([1, 2, 3])

    # 4 is absent from the index, so the lookup must fail.
    with pytest.raises(KeyError):
        index_without_4.get_loc(4)
116+
117+
118+
def test_get_loc_should_work_for_single_element_index():
    """Behavior: get_loc on a single-element index returns 0."""
    bf_result = bpd.Index(["a"]).get_loc("a")
    pd_result = pd.Index(["a"]).get_loc("a")

    assert bf_result == pd_result
121+
122+
123+
def test_get_loc_should_return_slice_when_all_elements_are_duplicates():
    """Behavior: get_loc returns a full slice if all elements match the key."""
    labels = ["a", "a", "a"]
    expected = pd.Index(labels).get_loc("a")

    actual = bpd.Index(labels).get_loc("a")

    assert isinstance(actual, slice)
    # pandas yields slice(0, 3, None) for this input.
    assert actual == expected
133+
134+
35135
def test_index_construct_from_series():
36136
bf_result = bpd.Index(
37137
bpd.Series([3, 14, 159], dtype=pd.Float64Dtype(), name="series_name"),

‎third_party/bigframes_vendored/pandas/core/indexes/base.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,41 @@ def argmin(self) -> int:
741741
"""
742742
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
743743

744+
def get_loc(self, key):
    """
    Get integer location, slice or boolean mask for requested label.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None

        >>> unique_index = bpd.Index(list('abc'))
        >>> unique_index.get_loc('b')
        1

        >>> monotonic_index = bpd.Index(list('abbc'))
        >>> monotonic_index.get_loc('b')
        slice(1, 3, None)

        >>> non_monotonic_index = bpd.Index(list('abcb'))
        >>> non_monotonic_index.get_loc('b')
        array([False, True, False, True])

    Args:
        key: Label to get the location for.

    Returns:
        int if unique index, slice if monotonic index with duplicates, else boolean array:
            Integer position of the label for unique indexes.
            Slice object for monotonic indexes with duplicates.
            Boolean array mask for non-monotonic indexes with duplicates.

    Raises:
        KeyError: If the key is not found in the index.
    """
    # Abstract placeholder: this stub only raises; it holds the documentation.
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
778+
744779
def argmax(self) -> int:
745780
"""
746781
Return int position of the largest value in the Series.

0 commit comments

Comments
 (0)