@aseembits93
Contributor

📄 1,089% (10.89x) speedup for BlackjackFunctional.observation in gymnasium/envs/tabular/blackjack.py

⏱️ Runtime: 104 milliseconds → 8.75 milliseconds (best of 33 runs)
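
The headline percentage can be reproduced from the two runtimes. The sketch below assumes the speedup is computed as (old - new) / new, which matches both the 1,089% and the 10.89x figures reported here; the helper name `speedup_percent` is hypothetical and not part of Codeflash or Gymnasium.

```python
def speedup_percent(old_runtime: float, new_runtime: float) -> float:
    """Relative speedup in percent, assuming the formula (old - new) / new."""
    return (old_runtime - new_runtime) / new_runtime * 100.0


# Headline runtimes from this PR, both in milliseconds.
headline = speedup_percent(104.0, 8.75)
print(f"{headline:.0f}% ({headline / 100:.2f}x)")
```

Under this convention a 1,089% speedup means the original takes about 11.89 times as long as the optimized version, not 10.89 times.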

πŸ“ Explanation and details

The optimized code achieves a 1,089% speedup through two JAX-specific optimizations:

1. JIT Compilation (@jax.jit decorators)

  • Added @jax.jit to both usable_ace() and sum_hand() functions
  • JIT compilation traces these functions once and compiles them into XLA kernels, so repeated calls no longer execute the Python body operation by operation
  • Reduces per-operation Python dispatch overhead and lets the compiler fuse the operations
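
A minimal sketch of what the decorated helpers might look like. The function bodies are reconstructed from the description in this PR and are an assumption, not a copy of the Gymnasium source.

```python
import jax
import jax.numpy as jnp


@jax.jit
def usable_ace(hand):
    # True when the hand holds an ace that can count as 11 without busting.
    return jnp.logical_and(jnp.any(hand == 1), jnp.sum(hand) + 10 <= 21)


@jax.jit
def sum_hand(hand):
    # Hand total, counting one ace as 11 when that does not bust the hand.
    return jnp.sum(hand) + 10 * usable_ace(hand)


soft_18 = sum_hand(jnp.array([1, 7]))       # the ace counts as 11
hard_16 = sum_hand(jnp.array([1, 10, 5]))   # the ace must count as 1
print(int(soft_18), int(hard_16))
```

The first call for each input shape pays the compilation cost; later calls with the same shape reuse the compiled kernel.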

2. Vectorized JAX Operations

  • Replaced jnp.count_nonzero(hand == 1) > 0 with jnp.any(hand == 1) - more direct and efficient for checking existence
  • Replaced Python's sum(hand) with jnp.sum(hand) - keeps computation entirely within JAX's accelerated framework
  • Eliminated mixed Python/JAX operations that prevent full optimization
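
The two replacements can be checked for equivalence on a small hand. This is an illustrative sketch; the variable names are made up for the example.

```python
import jax.numpy as jnp

hand = jnp.array([1, 1, 9])

# Existence check: counting the matches versus asking whether any match exists.
has_ace_before = jnp.count_nonzero(hand == 1) > 0
has_ace_after = jnp.any(hand == 1)

# Reduction: Python's built-in sum iterates over the array one element at a
# time, while jnp.sum performs a single array reduction.
total_before = sum(hand)
total_after = jnp.sum(hand)

print(bool(has_ace_before), bool(has_ace_after), int(total_before), int(total_after))
```

Both pairs produce the same values, so the change affects only how the result is computed.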

Performance Impact by Test Case:

  • Standard hands (2-6 cards): roughly 70-110% speedup - the one-time JIT compilation cost is amortized over repeated calls
  • Large hands (100+ cards): roughly 1,000-8,600% speedup - vectorized jnp.sum and jnp.any scale much better than Python loops and count_nonzero
  • Maximum benefit: 999-1000 card hands see 8,000%+ improvement because jnp.sum is a single array reduction, while Python's built-in sum iterates over the hand one element at a time
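
The gap on large hands can be observed with a small timing sketch like the one below. It is not the benchmark used for this PR; the hand size mirrors the 1000-card test case, and absolute timings depend on the machine and backend.

```python
import timeit

import jax.numpy as jnp

hand = jnp.full(1000, 2)


def python_sum():
    # Built-in sum: one JAX addition is dispatched per card.
    return sum(hand).block_until_ready()


def jax_sum():
    # Single reduction over the whole array.
    return jnp.sum(hand).block_until_ready()


# Warm up both paths so one-time setup costs are not measured.
python_sum()
jax_sum()

t_python = timeit.timeit(python_sum, number=3)
t_jax = timeit.timeit(jax_sum, number=3)
print(f"built-in sum: {t_python:.4f}s, jnp.sum: {t_jax:.4f}s")
```

`block_until_ready()` is called because JAX dispatches work asynchronously; without it the timer could stop before the computation has finished.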

The line profiler shows the optimized version spends much less time in sum_hand() calls (50.6% vs 87.1%), indicating the JIT-compiled functions execute significantly faster than their interpreted counterparts.

✅ Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 20 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |

🌀 Generated Regression Tests and Runtime
from typing import TypeAlias

# function to test
import jax
import jax.numpy as jnp
import numpy as np
# imports
import pytest  # used for our unit tests
from gymnasium import spaces
from gymnasium.envs.tabular.blackjack import BlackjackFunctional
from gymnasium.experimental.functional import FuncEnv
from gymnasium.vector import AutoresetMode


# Minimal stub for BlackJackParams (not used in observation)
class BlackJackParams:
    pass

# Minimal stub for EnvState used in observation
class EnvState:
    def __init__(self, player_hand, dealer_hand):
        # player_hand and dealer_hand should be jax arrays for compatibility
        self.player_hand = jnp.array(player_hand)
        self.dealer_hand = jnp.array(dealer_hand)


# Helper function for test readability
def get_obs(player_hand, dealer_hand):
    state = EnvState(player_hand, dealer_hand)
    rng = jax.random.PRNGKey(0)
    # observation is an instance method, so call it on an instance.
    return BlackjackFunctional().observation(state, rng)

# -------------------------
# Basic Test Cases
# -------------------------

#------------------------------------------------
from typing import TypeAlias

# function to test
import jax
import jax.numpy as jnp
import numpy as np
# imports
import pytest  # used for our unit tests
from gymnasium import spaces
from gymnasium.envs.tabular.blackjack import BlackjackFunctional
from gymnasium.experimental.functional import FuncEnv
from gymnasium.vector import AutoresetMode

PRNGKeyType: TypeAlias = jax.Array

# Minimal EnvState and BlackJackParams for testing
class EnvState:
    def __init__(self, player_hand, dealer_hand):
        self.player_hand = player_hand
        self.dealer_hand = dealer_hand

class BlackJackParams:
    pass  # Placeholder for compatibility

# Instantiate the class for testing
env = BlackjackFunctional()

# Helper for making a jax PRNG key (not used in observation, but required by signature)
def dummy_rng():
    return jax.random.PRNGKey(0)

# ------------------------
# Basic Test Cases
# ------------------------

def test_observation_basic_no_ace():
    # Player: 10, 7 (sum=17), Dealer: 8
    state = EnvState(jnp.array([10, 7]), jnp.array([8, 5]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 948μs -> 526μs (80.0% faster)

def test_observation_basic_usable_ace():
    # Player: Ace, 7 (sum=18 with usable ace), Dealer: 10
    state = EnvState(jnp.array([1, 7]), jnp.array([10, 4]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 805μs -> 466μs (72.7% faster)

def test_observation_basic_nonusable_ace():
    # Player: Ace, 10, 5 (Ace not usable, sum=16), Dealer: 9
    state = EnvState(jnp.array([1, 10, 5]), jnp.array([9, 2]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 833μs -> 449μs (85.5% faster)

def test_observation_basic_multiple_aces_one_usable():
    # Player: Ace, Ace, 9 (sum=21, one ace usable), Dealer: 2
    state = EnvState(jnp.array([1, 1, 9]), jnp.array([2, 10]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 814μs -> 435μs (87.1% faster)

def test_observation_basic_multiple_aces_none_usable():
    # Player: Ace, Ace, 10, 10 (sum=22, both aces not usable), Dealer: 7
    state = EnvState(jnp.array([1, 1, 10, 10]), jnp.array([7, 3]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 851μs -> 438μs (94.2% faster)

# ------------------------
# Edge Test Cases
# ------------------------

def test_observation_edge_minimum_player_sum():
    # Player: Ace only (sum=11, usable ace), Dealer: Ace
    state = EnvState(jnp.array([1]), jnp.array([1, 8]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 727μs -> 429μs (69.2% faster)

def test_observation_edge_maximum_player_sum():
    # Player: 10, 10, 2 (sum=22), Dealer: 10
    state = EnvState(jnp.array([10, 10, 2]), jnp.array([10, 5]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 796μs -> 431μs (84.6% faster)

def test_observation_edge_blackjack_natural():
    # Player: Ace, 10 (sum=21, usable ace), Dealer: 10
    state = EnvState(jnp.array([1, 10]), jnp.array([10, 7]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 753μs -> 428μs (75.6% faster)

def test_observation_edge_dealer_showing_ace():
    # Player: 5, 6 (sum=11), Dealer: Ace
    state = EnvState(jnp.array([5, 6]), jnp.array([1, 8]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 750μs -> 429μs (74.8% faster)

def test_observation_edge_dealer_showing_ten():
    # Player: 9, 8 (sum=17), Dealer: 10
    state = EnvState(jnp.array([9, 8]), jnp.array([10, 2]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 745μs -> 422μs (76.3% faster)

def test_observation_edge_all_face_cards():
    # Player: 10, 10, 10 (sum=30), Dealer: 10
    state = EnvState(jnp.array([10, 10, 10]), jnp.array([10, 10]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 795μs -> 429μs (85.0% faster)

def test_observation_edge_all_aces():
    # Player: Ace, Ace, Ace (sum=13, one ace usable), Dealer: Ace
    state = EnvState(jnp.array([1, 1, 1]), jnp.array([1, 2]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 791μs -> 429μs (84.3% faster)

def test_observation_edge_no_cards_player():
    # Player: [], Dealer: 5 (invalid but should not crash, sum=0)
    state = EnvState(jnp.array([]), jnp.array([5, 6]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 654μs -> 461μs (42.0% faster)

def test_observation_edge_no_cards_dealer():
    # Player: 5, 5, Dealer: [] (invalid: there is no dealer card to show)
    state = EnvState(jnp.array([5, 5]), jnp.array([]))
    # Should raise IndexError as dealer_hand[0] is out of bounds
    with pytest.raises(IndexError):
        env.observation(state, dummy_rng()) # 409μs -> 193μs (112% faster)

def test_observation_edge_large_card_values():
    # Player: 11, 12 (non-standard, sum=23), Dealer: 13 (non-standard)
    state = EnvState(jnp.array([11, 12]), jnp.array([13, 2]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 762μs -> 434μs (75.5% faster)

# ------------------------
# Large Scale Test Cases
# ------------------------

def test_observation_large_player_hand():
    # Player: 100 cards of value 1 (Aces), Dealer: 10
    player_hand = jnp.ones(100, dtype=jnp.int32)
    state = EnvState(player_hand, jnp.array([10, 2]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 5.38ms -> 483μs (1014% faster)

def test_observation_large_player_hand_with_one_usable_ace():
    # Player: 1, 2, 3, 4, 5, 6 (sum=21, ace not usable: counting it as 11 would give 31)
    player_hand = jnp.array([1,2,3,4,5,6])
    state = EnvState(player_hand, jnp.array([7, 8]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 934μs -> 443μs (111% faster)

def test_observation_large_dealer_hand():
    # Dealer: 100 cards, first is 5, rest are 2s, Player: 10, 7
    dealer_hand = jnp.concatenate([jnp.array([5]), jnp.full(99, 2)])
    state = EnvState(jnp.array([10, 7]), dealer_hand)
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 768μs -> 424μs (80.8% faster)

def test_observation_large_random_hands():
    # Player: 999 cards, random values 1-10; Dealer: 10, 7
    rng = np.random.default_rng(42)
    player_hand = jnp.array(rng.integers(1, 11, size=999))
    state = EnvState(player_hand, jnp.array([10, 7]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 43.2ms -> 495μs (8612% faster)

def test_observation_performance_large_hands():
    # Player: 1000 cards, all 2s; Dealer: 10, 2
    player_hand = jnp.full(1000, 2)
    state = EnvState(player_hand, jnp.array([10, 2]))
    codeflash_output = env.observation(state, dummy_rng()); obs = codeflash_output # 42.4ms -> 498μs (8389% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-BlackjackFunctional.observation-mdzuyjzp and push.

@aseembits93
Contributor Author

@pseudo-rnd-thoughts would love to know your thoughts

Member

@pseudo-rnd-thoughts left a comment

Thanks for spotting the problems. However, we intentionally didn't jax.jit the functions, as that is done separately. Could you remove that? Then I'll merge.
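
A sketch of the pattern requested in this review: the helpers stay undecorated and `jax.jit` is applied once at the call site. The `observation` function here is a simplified, hypothetical stand-in, and where Gymnasium actually applies the jit is not shown in this thread.

```python
import jax
import jax.numpy as jnp


def usable_ace(hand):
    # No decorator: this remains a plain function built from JAX operations.
    return jnp.logical_and(jnp.any(hand == 1), jnp.sum(hand) + 10 <= 21)


def sum_hand(hand):
    return jnp.sum(hand) + 10 * usable_ace(hand)


def observation(player_hand, dealer_hand):
    # Hypothetical simplified observation:
    # (player total, dealer's first card, usable-ace flag).
    return jnp.stack(
        [
            sum_hand(player_hand),
            dealer_hand[0],
            usable_ace(player_hand).astype(jnp.int32),
        ]
    )


# The caller decides whether and where to compile.
jitted_observation = jax.jit(observation)
obs = jitted_observation(jnp.array([1, 7]), jnp.array([10, 4]))
print(obs.tolist())
```

Because the helpers use only JAX operations, they are traced and compiled as part of whichever outer function is jitted, so decorating them individually is not required.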

@pseudo-rnd-thoughts changed the title from "⚡️ Speed up method BlackjackFunctional.observation by 1,089%" to "Optimise the BlackjackFunctional usable_ace and sum_hand functions" on Aug 11, 2025
@aseembits93
Contributor Author

Hi @pseudo-rnd-thoughts! I undid the jit and ran the benchmark on 20 random episodes, which gave a ~20% speedup. I got the same speedup for up to 100 episodes. Here's the benchmark code: https://gist.github.com/aseembits93/2c0b232e93498a14e7686b0c8411ed10

Hope to hear back soon! Best,

@pseudo-rnd-thoughts
Member

Amazing, thanks @aseembits93

@pseudo-rnd-thoughts merged commit e401d51 into Farama-Foundation:main on Aug 12, 2025
11 checks passed