Optimise the BlackjackFunctional usable_ace and sum_hand functions
#1426
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
π 1,089% (10.89x) speedup for
BlackjackFunctional.observationingymnasium/envs/tabular/blackjack.pyβ±οΈ Runtime :
104 millisecondsβ8.75 milliseconds(best of33runs)π Explanation and details
The optimized code achieves an 1088% speedup through two key JAX-specific optimizations:
1. JIT Compilation (@jax.jit decorators)
@jax.jitto bothusable_ace()andsum_hand()functions2. Vectorized JAX Operations
jnp.count_nonzero(hand == 1) > 0withjnp.any(hand == 1)- more direct and efficient for checking existencesum(hand)withjnp.sum(hand)- keeps computation entirely within JAX's accelerated frameworkPerformance Impact by Test Case:
jnp.sumandjnp.anyscale much better than Python loops andcount_nonzeroThe line profiler shows the optimized version spends much less time in
sum_hand()calls (50.6% vs 87.1%), indicating the JIT-compiled functions execute significantly faster than their interpreted counterparts.β Correctness verification report:
π Generated Regression Tests and Runtime
To edit these changes
git checkout codeflash/optimize-BlackjackFunctional.observation-mdzuyjzpand push.