# Batched EHS Optimization
## Problem Statement
Our initial GPU implementation was slower than the CPU baseline (10 iter/sec vs 26 iter/sec) because of kernel launch overhead: each EHS (expected hand strength) computation triggered ~300 small kernel launches at ~50μs each.
## Solution Overview
Two complementary optimizations:

| Optimization | Purpose | Effect |
|---|---|---|
| BatchEHSCollector | Process thousands of hands per GPU call | 76x speedup |
| RiverEHSTable | O(1) lookup for pre-computed EHS | Cache hits need no GPU work |
## Implementation
### 1. BatchEHSCollector
Collects EHS queries during tree traversal, then computes them all in one batch:

```python
# OLD: Sequential (300 kernel launches per hand)
for hand in hands:
    ehs = compute_ehs(hand.hole, hand.board)  # 300 kernels

# NEW: Batched (1 kernel per batch of 10K hands)
collector = BatchEHSCollectorV2(rank_table, samples_per_hand=100)
for hand in hands:
    collector.add_query(hand.hole, hand.board)  # Just collect
collector.compute_all()                         # ONE big kernel launch
for hand in hands:
    ehs = collector.get_result(hand.id)         # O(1) retrieval
```

Location: `src/cfr_poker/holdem/gpu/batch_collector.py`
### 2. BatchEHSCollectorV2 (Fully Vectorized)
Uses the Gumbel-max trick for vectorized opponent sampling:

```python
def _sample_opponents_vectorized(self, avail_mask: torch.Tensor) -> torch.Tensor:
    """Sample using Gumbel-max trick for per-row sampling."""
    n = avail_mask.shape[0]                          # one row per query
    noise = torch.rand((n, 52), device=self.device)
    noise[~avail_mask] = -1e9                        # Exclude dead cards
    _, indices = noise.topk(2, dim=1)                # Top-2 per row
    return indices
```

Advantage: no Python loops during sampling; everything is pure GPU tensor ops.
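With uniform card probabilities, the Gumbel perturbation reduces to plain uniform noise, so rand-plus-topk draws two distinct cards uniformly from each row's live cards. A minimal standalone sketch of the same idea (the dead-card layout here is purely illustrative):

```python
import torch

n = 4                                    # batch of 4 queries
avail_mask = torch.ones((n, 52), dtype=torch.bool)
avail_mask[:, [0, 13, 26, 39]] = False   # pretend these cards are dead

noise = torch.rand((n, 52))
noise[~avail_mask] = -1e9                # dead cards can never reach the top-2
_, opp = noise.topk(2, dim=1)            # (n, 2) opponent card indices
assert avail_mask.gather(1, opp).all()   # every sampled card is live
```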
### 3. RiverEHSTable

Persistent cache for river EHS values:

```python
class RiverEHSTable:
    """O(1) lookup for pre-computed EHS."""

    def lookup(self, hole, board) -> Optional[float]:
        """Return cached EHS or None."""

    def set(self, hole, board, ehs) -> None:
        """Cache new EHS value."""

    def save(self, path) -> None:
        """Persist to disk for next session."""
```

Features:

- FIFO eviction when max entries exceeded
- Pickle serialization for persistence
- Sorted tuple keys for order-independent lookup

A sketch of how these features fit together is shown below.

Location: `src/cfr_poker/holdem/gpu/river_ehs_table.py`
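A minimal sketch of how the three listed features could combine (assumptions: an OrderedDict-backed store and a `max_entries` cap; the real implementation may differ):

```python
import pickle
from collections import OrderedDict
from typing import Optional

class RiverEHSTableSketch:
    """Illustrative only: sorted-tuple keys, FIFO eviction, pickle persistence."""

    def __init__(self, max_entries: int = 1_000_000):
        self.max_entries = max_entries
        self._table: OrderedDict = OrderedDict()  # insertion order = FIFO order

    @staticmethod
    def _key(hole, board):
        # Sorted tuples make the lookup order-independent:
        # (As, Kd) and (Kd, As) map to the same entry.
        return (tuple(sorted(hole)), tuple(sorted(board)))

    def lookup(self, hole, board) -> Optional[float]:
        return self._table.get(self._key(hole, board))

    def set(self, hole, board, ehs: float) -> None:
        if len(self._table) >= self.max_entries:
            self._table.popitem(last=False)  # FIFO: drop the oldest entry
        self._table[self._key(hole, board)] = ehs

    def save(self, path) -> None:
        with open(path, "wb") as f:
            pickle.dump(self._table, f)
```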
## Benchmark Results
### Pre-computation Performance
| Method | Rate | Time for 1M samples |
|---|---|---|
| GPU sequential | 140 samples/s | 7,140 s (~2 hours) |
| GPU batched | 10,000+ samples/s | 94 s (~1.6 min) |
| CPU | ~500 samples/s | ~33 min |
Speedup: 76x (sequential → batched: 7,140 s → 94 s)
### Why Batching Works
```
Sequential:                Batched:
┌──────────┐               ┌─────────────────────────────┐
│ Kernel 1 │ 50μs          │                             │
├──────────┤               │  One kernel: 10K hands      │
│ Kernel 2 │ 50μs          │                             │
├──────────┤               │  Total: ~1ms                │
│   ...    │               │                             │
├──────────┤               │  (amortized: 0.1μs/hand)    │
│Kernel 300│ 50μs          │                             │
└──────────┘               └─────────────────────────────┘
Total: 15ms/hand           Total: 0.1ms/hand
```

Key Insight: GPU kernel launch overhead (~50μs) is fixed per launch, not per operation. Batching amortizes this across thousands of hands.
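Spelling out the arithmetic behind the diagram:

```python
launch_overhead_us = 50                # fixed cost per kernel launch
kernels_per_hand = 300                 # sequential launches per EHS computation
sequential_ms = kernels_per_hand * launch_overhead_us / 1000
print(sequential_ms)                   # 15.0 ms/hand, all launch overhead

batch_size = 10_000                    # hands covered by one batched launch
batched_total_ms = 1.0                 # ~1 ms for the single big kernel
print(batched_total_ms * 1000 / batch_size)  # 0.1 μs/hand amortized
```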
## Usage
### Pre-compute EHS Cache
```bash
# Pre-compute 1M river EHS values (~94 seconds)
HSA_OVERRIDE_GFX_VERSION=11.0.0 python scripts/precompute_river_ehs.py \
    --samples 1000000 \
    --gpu \
    --batch-size 10000 \
    -v
```

### Train with Pre-computed Cache
```bash
# Use pre-computed cache during training
python scripts/train_holdem.py \
    --multistreet \
    --precomputed-ehs \
    --save-ehs-cache \
    -i 10000
```

### CLI Options
| Flag | Description |
|---|---|
| `--precomputed-ehs` | Enable pre-computed river EHS table |
| `--ehs-cache-path` | Custom cache file location |
| `--save-ehs-cache` | Persist cache after training |
| `--batch-size` | Hands per batch (default: 10000) |
| `--sequential` | Use old sequential mode (for comparison) |
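For example, to keep the cache at a custom location and persist it after training (the cache path here is illustrative, not a documented default):

```bash
# Illustrative: data/river_ehs.pkl is a hypothetical path
python scripts/train_holdem.py \
    --multistreet \
    --precomputed-ehs \
    --ehs-cache-path data/river_ehs.pkl \
    --save-ehs-cache \
    -i 10000
```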
## Files Added/Modified
| File | Purpose |
|---|---|
| `gpu/batch_collector.py` | BatchEHSCollector & V2 implementation |
| `gpu/river_ehs_table.py` | RiverEHSTable persistent cache |
| `scripts/precompute_river_ehs.py` | Pre-computation script |
| `abstraction.py` | Integration with PostflopAbstraction |
| `mccfr.py` | New CLI parameters |
| `train_holdem.py` | CLI flags |
## Test Coverage
30 new GPU tests added:
- `test_batch_collector.py` - 14 tests
- `test_river_ehs_table.py` - 16 tests
All 221 tests pass.
## Future Work
- Batched MCCFR: Apply batching during actual tree traversal (not just pre-computation)
- Turn/Flop Tables: Extend pre-computation to earlier streets
- Memory-Mapped Files: For larger caches that exceed RAM
- CUDA Comparison: Test on NVIDIA GPU for comparison
## Conclusions
- Batching is essential for GPU speedup: individual operations are dominated by launch overhead
- Pre-computation + caching provides the best of both worlds (GPU compute, O(1) runtime lookup)
- 76x speedup achieved by restructuring from per-hand to per-batch processing
- ROCm/AMD works with the `HSA_OVERRIDE_GFX_VERSION=11.0.0` workaround