Fix KTOTrainer CUDA error for large-vocab models via tensor indexing #4635
Pull request overview
This PR addresses a CUDA error that occurs when training KTOTrainer with models that have extremely large vocabularies (e.g., Qwen3-VL with ~151k vocab size). The fix converts Python list-based fancy indexing to tensor-based indexing operations to prevent invalid CUDA kernel configurations.
Key Changes
- Replaced Python list fancy indexing with `torch.tensor()` conversion + `index_select()` in the `forward` method
- Applied the same fix to the `get_batch_loss_metrics` method when using pre-computed reference log probabilities
- Ensures indices are on the correct device for CUDA operations
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Very good fix, thanks. I was able to reproduce locally both the bug and the fix.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This PR resolves a CUDA error in KTOTrainer when training models with extremely large vocabularies (e.g., Qwen3-VL, ~151k vocab size), enabling KTO training on multimodal models with extended vocabularies such as Qwen3-VL.
The problem was first observed downstream in Unsloth during KTO training. Investigation traced the root cause to the Hugging Face TRL implementation. This PR fixes it upstream so that downstream projects (e.g., Unsloth) no longer need local patches. Tested successfully in Unsloth with Qwen3-VL after this fix.
The error resulted from Python list-based fancy indexing of tensors with shapes like `[batch_size, seq_len, vocab_size]`. For extremely large vocab sizes, this can fail with:
`CUDA error: invalid configuration argument`
Downstream issue reference:
unslothai/unsloth#3675
Root Cause
In `trl/trainer/kto_trainer.py`, Python lists are used to split the batch, and the same lists are then used for fancy indexing on large tensors.
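A minimal sketch of the problematic pattern (illustrative variable names, not the exact TRL source; dimensions are shrunk from a realistic `[batch, seq_len, ~151k vocab]` shape so it runs anywhere):

```python
import torch

# A Python list is built while splitting the batch into chosen/rejected
# examples, then reused directly for fancy indexing on the logits tensor.
all_logits = torch.randn(4, 8, 32)  # stand-in for [batch, seq_len, vocab]
chosen_idx = [0, 2]                 # plain Python list from batch splitting
rejected_idx = [1, 3]

# On CUDA with very large vocab sizes, list-based fancy indexing like this
# can launch a kernel with an invalid configuration.
chosen_logits = all_logits[chosen_idx]
rejected_logits = all_logits[rejected_idx]
```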
A similar pattern can be found in the `reference_logps` branch of `get_batch_loss_metrics(...)`.
Fix
Python list indexing is replaced with device-aware tensor indices and `torch.index_select`, which dispatches to optimized CUDA kernels.
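A hedged sketch of the fix (hypothetical variable names; the real change lives in `kto_trainer.py`): the list is first converted to a long tensor on the same device as the logits, then `index_select` performs the gather.

```python
import torch

all_logits = torch.randn(4, 8, 32)  # stand-in for [batch, seq_len, vocab]
chosen_idx = [0, 2]

# Convert the Python list to a device-aware index tensor, then use
# torch.index_select instead of list-based fancy indexing.
idx = torch.tensor(chosen_idx, dtype=torch.long, device=all_logits.device)
chosen_logits = all_logits.index_select(0, idx)

# The result is identical to all_logits[chosen_idx], without the fragile
# kernel launch path.
```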
Why This Works
- `index_select` uses optimized CUDA kernels
- indices are on the correct device
- avoids Python list fancy indexing
- no CPU-GPU sync overhead
- identical behavior for normal vocab sizes
- robust for 150k+ vocab models
All existing logic paths remain unchanged for normal vocab scenarios.
Testing
Unit Tests
Executed:
python3.10 -m pytest tests/test_kto_trainer.py -v
All KTO tests pass:
- TestKTOTrainer::test_kto_trainer[...]
- test_kto_trainer_with_ref_model_is_model
- test_tokenize_and_process_tokens
- test_kto_trainer_without_providing_ref_model
- test_kto_trainer_generate_during_eval_no_wandb
- test_compute_metrics
(LoRA / liger tests skipped as per upstream config.)
Large-Vocab Simulation
A dummy tensor of shape `[4, 256, 151_936]` was created, and the two indexing methods were compared:
- old: `tensor[list, ...]`
- new: `tensor.index_select(0, idx_tensor)`
Verification:
- `torch.allclose(old, new) == True`
- shapes match
- values match
So behavior is unchanged — only the indexing method is safer.
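The verification above can be reproduced with a short script (dimensions shrunk from `[4, 256, 151_936]` so it runs quickly on CPU; index values are illustrative):

```python
import torch

t = torch.randn(4, 16, 64)
idx_list = [0, 3]
idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=t.device)

old = t[idx_list]                    # Python list fancy indexing
new = t.index_select(0, idx_tensor)  # tensor-based indexing

# Both methods must produce the same shape and the same values.
shapes_match = old.shape == new.shape
values_match = torch.allclose(old, new)
```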
Backward Compatibility
- no API changes
- no behavior changes for normal vocab sizes
- only internal indexing logic updated
- fixes KTO training for large vocab models (e.g. Qwen family)
Impact
This allows KTO training for multilingual and multimodal models with extended vocabularies without triggering CUDA kernel launch errors. Downstream frameworks like Unsloth will get this fix automatically when they update TRL; hence, they don't need local patches.