A JAX-based framework for multi-agent reinforcement learning in high-frequency trading, built on the JAX-LOB simulator and extending JaxMARL to the financial trading domain.
- GPU-Accelerated: Built on JAX for high-performance parallel computation with JIT compilation
- Two Levels of Parallelization: Parallel processing across episodes and agent types using vmap (see the sketch below)
- Multi-Agent RL: Supports market making, execution, and directional trading agents
- LOBSTER Data Integration: Real market data support with efficient GPU memory usage
- Scalable: Handles thousands of parallel environments
- Heterogeneous Agents: Supports different observation/action spaces
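A minimal JAX sketch of the two levels of vectorization, nesting `vmap` over parallel episodes and agent types and JIT-compiling the result. The environment step below is a placeholder, not the framework's actual API:

```python
import jax
import jax.numpy as jnp

# Placeholder single-environment, single-agent step function; it stands in
# for the JAX-LOB environment step and is NOT the framework's actual API.
def env_step(state, action):
    new_state = state + action            # toy dynamics
    reward = -jnp.abs(new_state)          # toy reward
    return new_state, reward

# Level 1: vectorize the step across parallel episodes.
step_over_episodes = jax.vmap(env_step)

# Level 2: vectorize again across agent types, then JIT-compile for the GPU.
step_fn = jax.jit(jax.vmap(step_over_episodes))

num_agent_types, num_envs = 2, 4096
states = jnp.zeros((num_agent_types, num_envs))
actions = jnp.ones((num_agent_types, num_envs))
states, rewards = step_fn(states, actions)
print(rewards.shape)  # (2, 4096)
```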
```bash
# Set up data directory
mkdir -p ~/data
```

Note: Configure the Makefile for your specific environment (GPU device, data directory path, etc.)

```bash
# Build and run with Docker
make build
make run
```

```bash
# Run IPPO training
python3 gymnax_exchange/jaxrl/MARL/ippo_rnn_JAXMARL.py
```

**Market Making Agent**
- Purpose: Provide liquidity by posting bid/ask orders
- Action Spaces: Multiple discrete action spaces (spread_skew, fixed_quants, AvSt, directional_trading, simple)
- Reward Functions: Various PnL-based rewards with configurable inventory penalties (a rough reward sketch follows this list)
**Execution Agent**
- Purpose: Execute large orders with minimal market impact
- Action Spaces: Discrete quantity selection at reference prices (fixed_quants, fixed_prices, complex variants)
- Reward Functions: Slippage-based with configurable end-of-episode penalties
**Directional Trading Agent**
- Purpose: Simple directional trading strategy
- Action Spaces: Bid/ask at best prices or no action
- Reward Function: Portfolio value
- Note: Uses the same class as the market making agent
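As a rough illustration of the PnL-plus-inventory-penalty idea mentioned for the market making agent, one possible reward shape is shown below. The penalty coefficient and functional form are assumptions for illustration, not the framework's exact reward implementation:

```python
import jax.numpy as jnp

# Illustrative market-making reward: change in PnL minus a penalty that grows
# with the absolute inventory held. Coefficient and form are assumptions.
def mm_reward(pnl_change, inventory, inventory_penalty=0.01):
    return pnl_change - inventory_penalty * jnp.abs(inventory)

# Example: +5.0 PnL while holding -120 units -> 5.0 - 0.01 * 120 = 3.8
print(mm_reward(jnp.float32(5.0), jnp.float32(-120.0)))
```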
```
gymnax_exchange/
├── jaxen/          # Environment implementations
│   ├── marl_env.py # Multi-agent RL environment
│   ├── mm_env.py   # Market making (and directional trading) environment
│   └── exec_env.py # Execution environment
├── jaxrl/          # Reinforcement learning algorithms
│   └── MARL/       # IPPO implementation
├── jaxob/          # Order book implementation
└── jaxlobster/     # LOBSTER data integration
```
The framework uses a comprehensive configuration system with dataclasses for different components:
- MultiAgentConfig: Main configuration combining world and agent settings
- World_EnvironmentConfig: Global environment parameters (data paths, episode settings, market hours)
- MarketMaking_EnvironmentConfig: Market making and directional trading agent configuration (action spaces, reward functions, observation spaces)
- Execution_EnvironmentConfig: Execution agent configuration (task types, action spaces, reward parameters)
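A minimal sketch of how these dataclasses compose; the field names and defaults below are illustrative placeholders, not the framework's actual attributes:

```python
from dataclasses import dataclass, field

# Placeholder versions of the framework's config dataclasses. Field names and
# default values are assumptions, shown only to illustrate the composition.
@dataclass
class World_EnvironmentConfig:
    data_path: str = "~/data"
    episode_length_s: int = 1800

@dataclass
class MarketMaking_EnvironmentConfig:
    action_space: str = "fixed_quants"
    inventory_penalty: float = 0.01

@dataclass
class Execution_EnvironmentConfig:
    action_space: str = "fixed_prices"
    end_of_episode_penalty: float = 1.0

@dataclass
class MultiAgentConfig:
    world: World_EnvironmentConfig = field(default_factory=World_EnvironmentConfig)
    market_making: MarketMaking_EnvironmentConfig = field(default_factory=MarketMaking_EnvironmentConfig)
    execution: Execution_EnvironmentConfig = field(default_factory=Execution_EnvironmentConfig)

config = MultiAgentConfig()
print(config.market_making.action_space)  # fixed_quants
```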
Edit YAML files in gymnax_exchange/jaxrl/MARL/config/ to customize:
- Number of parallel environments (default: 4096)
- Training parameters (steps, learning rates, etc.)
- Agent configurations (action spaces, reward functions)
- Market data settings (resolution, episode length)
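For example, a run could load one of these YAML files and override the environment count before training. The file name and key below are assumptions about the config layout, shown only as a sketch:

```python
import yaml  # PyYAML

# Hypothetical config file name and key; adjust to the actual files in
# gymnax_exchange/jaxrl/MARL/config/.
with open("gymnax_exchange/jaxrl/MARL/config/ippo_config.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["NUM_ENVS"] = 2048  # e.g. halve the default of 4096 parallel environments
print(cfg["NUM_ENVS"])
```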
- Python 3.8+
- CUDA-compatible GPU (recommended)
- JAX, Flax, and related dependencies (see requirements.txt)