-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[Data] Add constant folding optimization to logical plan optimizer #60635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[Data] Add constant folding optimization to logical plan optimizer #60635
Conversation
This PR implements compile-time **constant folding** and basic algebraic simplification for Ray Data logical expressions.
The optimization evaluates constant expressions during the logical planning phase rather than at execution time, which can significantly reduce runtime computation, simplify expression trees, and improve overall pipeline performance — especially for complex chained transformations and user-defined filters/column expressions.
Main optimizations supported:
- Pure constant folding: `lit(3) + lit(5)` → `lit(8)`
- Identity elimination: `col("x") * lit(1)` → `col("x")`, `col("x") + lit(0)` → `col("x")`
- Null & zero propagation: `lit(0) * col("x")` → `lit(0)`
- Boolean short-circuit & constant propagation:
- `lit(False) & anything` → `lit(False)`
- `lit(True) | anything` → `lit(True)`
- `~lit(True)` → `lit(False)`
- Simplification of nested expressions and complex predicates
Key implementation aspects:
- Introduced `ConstantFoldingRule` as a logical optimization rule
- Implemented expression rewriting via `_ConstantFoldingVisitor` (visitor pattern)
- Supports arbitrary nesting depth and most common expression types
- Multi-pass fixpoint iteration until no more changes occur
- Runs early in the logical optimizer pipeline (before most other rules)
Performance impact:
- Eliminates redundant runtime computation for constant sub-expressions
- Reduces expression evaluation cost, especially in map/filter/project heavy pipelines
- Shrinks logical plan tree size → better optimization opportunities for subsequent rules
Added / modified files:
- New: `python/ray/data/_internal/planner/plan_expression/constant_folder.py`
- New: `python/ray/data/_internal/logical/rules/constant_folding.py`
- New: `python/ray/data/tests/test_constant_folding.py`
- Modified: `python/ray/data/_internal/logical/rules/__init__.py`
- Modified: `python/ray/data/_internal/logical/optimizers.py`
- Modified: `python/ray/data/BUILD.bazel`
Signed-off-by: slfan1989 <slfan1989@apache.org>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a valuable constant folding optimization for Ray Data's logical optimizer. The implementation is well-structured, using a visitor pattern to traverse expression trees and apply simplifications. It correctly handles pure constant folding, algebraic simplifications, and short-circuit evaluation for boolean expressions. The multi-pass approach ensures that complex nested expressions are fully optimized. The addition of comprehensive unit and integration tests is also a great inclusion.
I've found one potential correctness issue with an algebraic simplification and a minor inconsistency in a test comment. Overall, this is a solid contribution that should improve query performance.
| elif op == Operation.MOD: | ||
| # x % 1 → 0 | ||
| if isinstance(right, LiteralExpr) and right.value == 1: | ||
| return LiteralExpr(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The algebraic simplification for the MOD operation, x % 1 → 0, is only correct for integer types. For floating-point types, x % 1 yields the fractional part of x (e.g., 1.5 % 1 is 0.5). Since the type of the left expression is not checked here, applying this simplification can lead to incorrect results for columns with floating-point data.
Given that type information on expressions may not be reliably available at this optimization stage, I recommend removing this simplification to prevent potential data corruption. There are no tests covering this case, which further suggests it might be an oversight.
| assert folded.value == 10 | ||
|
|
||
| def test_deeply_nested(self): | ||
| """Test 2*(3+(4*(5+6))) → lit(86)""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The expected result in the docstring is lit(86), but the correct calculation is 94, as reflected in the assertion on line 89 and the comment on line 88. The docstring should be updated to reflect the correct value to avoid confusion.
| """Test 2*(3+(4*(5+6))) → lit(86)""" | |
| """Test 2*(3+(4*(5+6))) → lit(94)""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cursor Bugbot has reviewed your changes and found 7 potential issues.
| return LiteralExpr(0) | ||
| # 0 * x → 0 | ||
| if isinstance(left, LiteralExpr) and left.value == 0: | ||
| return LiteralExpr(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Multiplication by zero incorrect for NaN and Infinity
High Severity
The algebraic simplification x * 0 → 0 and 0 * x → 0 violates IEEE 754 floating-point semantics. According to IEEE 754, NaN * 0 = NaN and Inf * 0 = NaN, not 0. When a column contains NaN or Infinity values, this optimization produces incorrect results, silently converting NaN/Inf to 0 at query time.
| # x != x → False | ||
| if isinstance(left, ColumnExpr) and isinstance(right, ColumnExpr): | ||
| if left.name == right.name: | ||
| return LiteralExpr(False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Column self-comparison incorrectly simplifies NaN values
High Severity
The simplifications col("x") == col("x") → True, col("x") != col("x") → False, and col("x") - col("x") → 0 violate IEEE 754. NaN is not equal to itself (NaN == NaN is False, NaN != NaN is True, and NaN - NaN is NaN). These optimizations produce incorrect results for columns containing NaN values.
Additional Locations (1)
| elif op == Operation.MOD: | ||
| # x % 1 → 0 | ||
| if isinstance(right, LiteralExpr) and right.value == 1: | ||
| return LiteralExpr(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| return Filter( | ||
| op.input_dependency, | ||
| predicate_expr=folded_predicate, | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Filter operator loses compute and ray_remote_args attributes
High Severity
When creating a new Filter operator after constant folding, the code only passes input_dependency and predicate_expr, losing any compute strategy and ray_remote_args that were set on the original operator. The Filter class accepts these parameters, and they control execution behavior. Compare to _fold_project_operator which correctly preserves op._compute and op._ray_remote_args.
| elif op == Operation.FLOORDIV: | ||
| # x // 1 → x (for integers) | ||
| if isinstance(right, LiteralExpr) and right.value == 1: | ||
| return left |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Floor division by one incorrect for floats
Medium Severity
The simplification x // 1 → x (annotated as "for integers" in the comment) is applied unconditionally but is only valid for integers. For floating-point numbers, x // 1 returns the floor of x (e.g., 2.5 // 1 = 2.0). This produces incorrect results for float columns.
| if not args_changed and not kwargs_changed: | ||
| return expr | ||
|
|
||
| return UDFExpr(expr.fn, *folded_args, **folded_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
UDFExpr reconstruction uses incorrect argument unpacking
High Severity
The UDFExpr dataclass constructor expects three arguments: fn, args (a List[Expr]), and kwargs (a Dict[str, Expr]). However, the code uses UDFExpr(expr.fn, *folded_args, **folded_kwargs), which incorrectly unpacks the list elements as separate positional arguments and dict entries as keyword arguments. This causes a TypeError when a UDFExpr with multiple arguments needs reconstruction after folding. The correct call is UDFExpr(expr.fn, folded_args, folded_kwargs).
| return LiteralExpr(False) | ||
| # False & x → False | ||
| if isinstance(left, LiteralExpr) and left.value is False: | ||
| return LiteralExpr(False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unreachable code duplicates short-circuit logic
Low Severity
In _try_algebraic_simplification, the checks for left operand being a boolean literal in AND/OR operations (lines 218-220: True & x, lines 224-226: False & x, lines 236-238: False | x, lines 242-244: True | x) are unreachable. The _try_short_circuit method runs first in visit_binary and handles all cases where left is a boolean literal for AND/OR, always returning non-None. These duplicate checks can never execute.


Description
This PR implements compile-time constant folding optimization for Ray Data's logical optimizer. The optimization evaluates constant expressions and applies algebraic simplifications during query planning rather than at execution time, reducing runtime computation overhead and simplifying expression trees.
What this PR does:
lit(3) + lit(5)→lit(8)col("x") * 1→col("x"),col("x") + 0→col("x")False & col("x")→lit(False),True | col("x")→lit(True)((lit(True) & col("a")) | lit(False))→col("a")Why this is needed:
Currently, Ray Data evaluates expressions like
lit(3) + lit(5)at runtime for every row, even though the result is constant. This PR moves such evaluations to query planning time, eliminating unnecessary runtime overhead and enabling better optimization opportunities for downstream rules.Implementation:
ConstantFoldingRuleas a new logical optimization rule_ConstantFoldingVisitorusing the visitor pattern for expression tree traversalRelated issues
Additional information
Files added:
python/ray/data/_internal/planner/plan_expression/constant_folder.pypython/ray/data/_internal/logical/rules/constant_folding.pypython/ray/data/tests/test_constant_folding.pyFiles modified:
python/ray/data/_internal/logical/rules/__init__.py- ExportConstantFoldingRulepython/ray/data/_internal/logical/optimizers.py- Add rule to optimizer pipelinepython/ray/data/BUILD.bazel- Add test configurationExample optimizations:
Before and after query planning:
Performance impact: