Refactor SamplerLoop class into functional API#550
Open
KumarADITHYA123 wants to merge 1 commit intogoogle-deepmind:mainfrom
Open
Refactor SamplerLoop class into functional API#550KumarADITHYA123 wants to merge 1 commit intogoogle-deepmind:mainfrom
KumarADITHYA123 wants to merge 1 commit intogoogle-deepmind:mainfrom
Conversation
Convert SamplerLoop class to standalone functions following JAX functional programming patterns. This addresses the TODO comment in _sampler_loop.py. Changes: - Replace SamplerLoop class with SamplerConfig frozen dataclass - Implement autoregressive_sample() for non-streaming sampling - Implement autoregressive_stream_sample() for streaming sampling - Refactor _sample_step() to module-level function - Update _sampler.py to use new functional API The refactoring improves code maintainability by separating configuration from behavior while preserving all existing functionality and maintaining full backwards compatibility with the public API.
424ffc4 to
a5326ac
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
closes issue #549
This addresses the TODO in _sampler_loop.py by converting the SamplerLoop class into standalone functions that better fit JAX's functional programming style.
The SamplerLoop class was basically just a container for static config with methods attached. I've split this into:
SamplerConfig: frozen dataclass holding the configuration
autoregressive_sample(): JIT-compiled function for regular sampling
autoregressive_stream_sample(): generator function for streaming
_sample_step(): helper function (was a method, now standalone)
Updated _sampler.py to use the new API. The caller now creates a SamplerConfig and calls the appropriate function instead of instantiating SamplerLoop and calling methods on it.
No functional changes. All the sampling logic is identical, just reorganized to be more functional. The public API (Sampler, ChatSampler, ToolSampler) is completely unchanged.
I've verified the syntax and logic preservation locally, but can't run the full test suite due to Python 3.14 environment issues. The CI should handle that when it runs in Python 3.11.