Problem Statement

Our initial GPU implementation was slower than the CPU baseline (10 iter/sec vs 26 iter/sec) because of kernel launch overhead: each EHS computation triggered ~300 small kernel launches at ~50μs each.

Solution Overview

Two complementary optimizations:

Optimization         Purpose                                     Speedup
BatchEHSCollector    Process thousands of hands per GPU call    76x
RiverEHSTable        O(1) lookup for pre-computed EHS            Cache hits = zero GPU

Implementation

1. BatchEHSCollector

Collects EHS queries during tree traversal, then computes them all in one batch:

# OLD: Sequential (300 kernel launches per hand)
for hand in hands:
    ehs = compute_ehs(hand.hole, hand.board)    # 300 kernels

# NEW: Batched (1 kernel per batch of 10K hands)
collector = BatchEHSCollectorV2(rank_table, samples_per_hand=100)
for hand in hands:
    collector.add_query(hand.hole, hand.board)  # Just collect
collector.compute_all()                         # ONE big kernel launch
for hand in hands:
    ehs = collector.get_result(hand.id)         # O(1) retrieval

Location: src/cfr_poker/holdem/gpu/batch_collector.py

2. BatchEHSCollectorV2 (Fully Vectorized)

Uses the Gumbel-max trick for vectorized opponent sampling:

def _sample_opponents_vectorized(self, avail_mask: torch.Tensor) -> torch.Tensor:
    """Sample two opponent cards per row using the Gumbel-max trick."""
    n = avail_mask.shape[0]                          # one row per pending query
    noise = torch.rand((n, 52), device=self.device)
    noise[~avail_mask] = -1e9                        # Exclude dead cards
    _, indices = noise.topk(2, dim=1)                # Top-2 per row = 2 distinct cards
    return indices

Advantage: no Python loops during sampling; the hot path is pure GPU tensor ops.
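
For context, a standalone sketch of how a caller might build avail_mask and what the trick computes. The shapes and dead-card indices here are illustrative, not from the codebase:

import torch

# One row per query; the known hole + board cards are masked out as "dead".
n = 3
avail = torch.ones((n, 52), dtype=torch.bool)
dead = torch.tensor([
    [0, 13, 26, 39, 12, 25, 38],   # hole (2) + river board (5) for query 0
    [1, 14, 27, 40, 11, 24, 37],
    [2, 15, 28, 41, 10, 23, 36],
])
avail.scatter_(1, dead, False)     # opponents can't hold known cards

# Top-2 of random noise over the live cards picks two distinct opponent
# cards per row, uniformly and without replacement.
noise = torch.rand((n, 52))
noise[~avail] = -1e9
opp = noise.topk(2, dim=1).indices  # shape (n, 2)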

3. RiverEHSTable

Persistent cache for river EHS values:

from typing import Optional

class RiverEHSTable:
    """O(1) lookup for pre-computed EHS."""

    def lookup(self, hole, board) -> Optional[float]:
        """Return cached EHS, or None on a miss."""

    def set(self, hole, board, ehs) -> None:
        """Cache a new EHS value."""

    def save(self, path) -> None:
        """Persist the table to disk for the next session."""

Features:

  • FIFO eviction when max entries exceeded
  • Pickle serialization for persistence
  • Sorted tuple keys for order-independent lookup

Location: src/cfr_poker/holdem/gpu/river_ehs_table.py
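
For illustration, a minimal self-contained sketch of how those three features can fit together. MiniRiverEHSTable is a hypothetical stand-in, not the class above:

import pickle
from collections import OrderedDict
from typing import Optional

class MiniRiverEHSTable:
    """Hypothetical minimal cut of RiverEHSTable, for illustration only."""

    def __init__(self, max_entries: int = 1_000_000):
        self._data: OrderedDict = OrderedDict()   # insertion order = FIFO order
        self._max = max_entries

    @staticmethod
    def _key(hole, board):
        # Sorted tuples make lookups order-independent: (As, Kd) == (Kd, As).
        return (tuple(sorted(hole)), tuple(sorted(board)))

    def lookup(self, hole, board) -> Optional[float]:
        return self._data.get(self._key(hole, board))

    def set(self, hole, board, ehs: float) -> None:
        if len(self._data) >= self._max:
            self._data.popitem(last=False)        # FIFO eviction: drop the oldest entry
        self._data[self._key(hole, board)] = ehs

    def save(self, path) -> None:
        with open(path, "wb") as f:
            pickle.dump(self._data, f)            # pickle serialization for persistence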

Benchmark Results

Pre-computation Performance

Method           Rate         Time for 1M samples
GPU Sequential   140/s        7,140s (2 hours)
GPU Batched      10,000+/s    94s (1.5 min)
CPU              ~500/s       ~33 min

Speedup: 76x (sequential → batched)

Why Batching Works

Sequential:                  Batched:
┌───────────┐                ┌─────────────────────────────┐
│ Kernel 1  │ 50μs           │                             │
├───────────┤                │   One kernel: 10K hands     │
│ Kernel 2  │ 50μs           │                             │
├───────────┤                │   Total: ~1ms               │
│    ...    │                │                             │
├───────────┤                │   (amortized: 0.1μs/hand)   │
│ Kernel 300│ 50μs           │                             │
└───────────┘                └─────────────────────────────┘
Total: 15ms/hand             Total: 0.1ms/hand

Key Insight: GPU kernel launch overhead (~50μs) is fixed per launch, not per operation. Batching amortizes this across thousands of hands.
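
This is easy to reproduce with a toy micro-benchmark. The sketch below is not from the codebase, and the absolute numbers will vary by GPU:

import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"  # ROCm also reports as "cuda"
x = torch.rand((10_000, 52), device=device)

def sync():
    if device == "cuda":
        torch.cuda.synchronize()  # kernels run async; wait before reading the clock

# Many small launches: one kernel per row.
sync()
t0 = time.perf_counter()
for i in range(x.shape[0]):
    _ = x[i].topk(2)
sync()
t_seq = time.perf_counter() - t0

# One batched launch over all rows.
t0 = time.perf_counter()
_ = x.topk(2, dim=1)
sync()
t_batch = time.perf_counter() - t0

print(f"sequential: {t_seq:.3f}s  batched: {t_batch:.5f}s")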

Usage

Pre-compute EHS Cache

# Pre-compute 1M river EHS values (~94 seconds)
HSA_OVERRIDE_GFX_VERSION=11.0.0 python scripts/precompute_river_ehs.py \
    --samples 1000000 \
    --gpu \
    --batch-size 10000 \
    -v

Train with Pre-computed Cache

# Use pre-computed cache during training
python scripts/train_holdem.py \
    --multistreet \
    --precomputed-ehs \
    --save-ehs-cache \
    -i 10000
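
At runtime the table acts as a read-through cache in front of the GPU. The sketch below is hypothetical (river_ehs and compute_batch are illustrative names); the real wiring lives in abstraction.py:

def river_ehs(hole, board, table, compute_batch):
    """Hypothetical read-through wrapper around RiverEHSTable."""
    cached = table.lookup(hole, board)
    if cached is not None:
        return cached                        # cache hit: zero GPU work
    ehs = compute_batch([(hole, board)])[0]  # miss: fall back to batched GPU compute
    table.set(hole, board, ehs)              # populate the cache for next time
    return ehs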

CLI Options

Flag                 Description
--precomputed-ehs    Enable pre-computed river EHS table
--ehs-cache-path     Custom cache file location
--save-ehs-cache     Persist cache after training
--batch-size         Hands per batch (default: 10000)
--sequential         Use old sequential mode (for comparison)

Files Added/Modified

File                               Purpose
gpu/batch_collector.py             BatchEHSCollector & V2 implementation
gpu/river_ehs_table.py             RiverEHSTable persistent cache
scripts/precompute_river_ehs.py    Pre-computation script
abstraction.py                     Integration with PostflopAbstraction
mccfr.py                           New CLI parameters
train_holdem.py                    CLI flags

Test Coverage

30 new GPU tests added:

  • test_batch_collector.py - 14 tests
  • test_river_ehs_table.py - 16 tests

All 221 tests pass.

Future Work

  1. Batched MCCFR: Apply batching during actual tree traversal (not just pre-computation)
  2. Turn/Flop Tables: Extend pre-computation to earlier streets
  3. Memory-Mapped Files: For larger caches that exceed RAM
  4. CUDA Comparison: Benchmark on an NVIDIA GPU to compare CUDA against ROCm

Conclusions

  1. Batching is essential for GPU speedup: individual operations are dominated by launch overhead
  2. Pre-computation + caching provides the best of both worlds (GPU compute, O(1) runtime lookup)
  3. 76x speedup achieved by restructuring from per-hand to per-batch processing
  4. ROCm/AMD works with the HSA_OVERRIDE_GFX_VERSION=11.0.0 workaround