Support stateful ops in TransformIterator (#9627)#9630
Conversation
A stateful op (a callable capturing a CUDA device array) works as a direct algorithm op but failed inside a TransformIterator used as an algorithm input: without a return annotation it raised "get_return_type not implemented for _StatefulOp", and with one it crashed at launch with cudaErrorLaunchFailure. The op's compiled device function takes its packed state pointers as a leading argument, but the iterator's generated dereference called it as a stateless op and never carried the op's state. Implement _StatefulOp.get_return_type, and compose the op's state after the underlying iterator's state (keeping the underlying at offset 0) so the dereference can pass the op its state pointer, mirroring PermutationIterator's state composition. Closes NVIDIA#9627 Signed-off-by: nethum529 <nethumweerasinghe.nw@gmail.com>
📝 WalkthroughSummary by CodeRabbit
WalkthroughRefactors stateful-op return-type inference in ChangesStateful TransformIterator Support
Assessment against linked issues
Out-of-scope changesNo out-of-scope changes identified. important: In suggestion: suggestion: Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
python/cuda_cccl/cuda/compute/_jit.py (1)
846-850: 🎯 Functional Correctness | 🔵 Trivial | ⚡ Quick winsuggestion: Captured state arrays are always rebuilt as flat 1-D buffers (
shape=len(state_array),strides=itemsize). Reject non-1D CUDA arrays here or document the 1-D-only contract; otherwise a multidimensional capture compiles against the wrong parameter shape.Source: Path instructions
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: e71a519d-89cb-4d9a-9d9d-091d49fcff60
📒 Files selected for processing (3)
python/cuda_cccl/cuda/compute/_jit.pypython/cuda_cccl/cuda/compute/iterators/_transform.pypython/cuda_cccl/tests/compute/test_reduce.py
|
The failing |
Description
closes #9627
A stateful op — a Python callable that closes over a CUDA device array — works as a direct algorithm op (e.g.
select'scond,unary_transform'sop), but failed when wrapped in aTransformIteratorused as an algorithm's input (reduce_into,segmented_reduce,histogram_even). Two symptoms:NotImplementedError: get_return_type not implemented for _StatefulOp.cudaErrorLaunchFailure: unspecified launch failure.Root cause
_StatefulOpnever implementedget_return_type, soTransformIteratorcould not infer the transformed value type (symptom 1).(void* op_state, void* input, void* output), whereop_statepoints to the packed device-array pointers. ButTransformIterator's generated dereference code declared and called the op as if it were stateless —(void* input, void* output)— and never carried the op's state into the iterator's state. On device, the op read garbage in place of its state pointers, crashing the launch (symptom 2).Fix
_jit.py: implement_StatefulOp.get_return_type, reusing the same Numba inference path as compilation (extracted into the helpers_state_array_numba_typesand_infer_stateful_return_type, now shared with_compile_stateful_op).iterators/_transform.py: for a stateful transform op, append the op's state bytes after the underlying iterator's state (with alignment padding; the underlying state stays at offset 0 where its child ops expect it) and passstatic_cast<char*>(state) + op_state_offsetas the op's state argument in the generated input/output dereference. This mirrors the existing state-composition pattern inPermutationIterator/compose_iterator_states. Stateless ops are unchanged.This makes index-gather and multi-array transforms fuse directly into reductions/histograms without first materializing an intermediate array.
Tests
Added to
tests/compute/test_reduce.py:TransformOutputIterator.Verified on an NVIDIA RTX 4070 (CUDA 13.3): the 6 new tests pass, and the reduce/select/segmented_reduce/histogram/iterators/transform/permutation/zip suites pass with no regressions.
Checklist