Beyond Static Heuristics: Implementing Adaptive Kernel Selection for High-Throughput GPU Inference

Quick Summary
Adaptive Kernel Selection (AKS) is the process of dynamically choosing the most efficient GPU implementation (kernel) for a specific operation, such as GEMM or FlashAttention, based on real-time input shapes and hardware state. Static autotuning often fails in production because batch sizes, sequence lengths, and KV-cache states fluctuate constantly. By implementing a lightweight dispatcher backed by a pre-profiled lookup table or a heuristic-based meta-scheduler, you can reclaim an estimated 15-30% of the "lost" TFLOPS that standard library defaults usually leave on the table.
The Static Kernel Fallacy in Dynamic Inference
If you are running production inference, you’ve likely encountered this: your model benchmarks beautifully at a batch size of 32, but once it hits the "real world"—where traffic is bursty and requests are heterogeneous—latency spikes unexpectedly.
The culprit is often a static kernel choice. Most deep learning frameworks (PyTorch, TensorFlow) or runtimes (ONNX Runtime, TensorRT) perform a "warm-up" or autotuning phase where they select the best kernel for a specific input shape. However, in a dynamic serving environment, the shape $(B, S, H)$—Batch, Sequence, Hidden Dim—is a moving target. A kernel optimized for $B=1$ (latency-focused) is fundamentally different from one optimized for $B=64$ (throughput-focused).
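To make the gap between $B=1$ and $B=64$ concrete, here is a rough sketch that estimates the arithmetic intensity (FLOPs per byte moved) of a single GEMM as the batch dimension grows. It assumes FP16 operands (2 bytes per element), an illustrative hidden size of 4096, and that each matrix is read or written exactly once; the function name is made up for this example.

```python
def gemm_arithmetic_intensity(m, n, k, bytes_per_element=2):
    """FLOPs per byte for C[m, n] = A[m, k] @ B[k, n], assuming each matrix moves once."""
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    bytes_moved = bytes_per_element * (m * k + k * n + m * n)
    return flops / bytes_moved


if __name__ == "__main__":
    # Illustrative shapes only: a 4096x4096 weight matrix with a growing batch dimension.
    for batch in (1, 8, 64):
        intensity = gemm_arithmetic_intensity(batch, 4096, 4096)
        print(f"M={batch:>2}: {intensity:.1f} FLOPs/byte")
```

Under these assumptions the $M=1$ case does roughly one FLOP per byte, so it is bound by memory bandwidth, while $M=64$ does roughly 62 FLOPs per byte and starts to reward large, compute-oriented tiles. That is the basic reason a single kernel choice rarely suits both regimes.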
I have seen teams spend months optimizing MoE models for resource-efficient inference, only to have their gains wiped out because they used a generic cuBLAS GEMM kernel that didn't account for the non-aligned memory accesses that small batch sizes produce.
To solve this, we need an Adaptive Kernel Dispatcher. This system evaluates the incoming tensor shapes and dispatches the request to the specific CUDA kernel implementation that maximizes hardware utilization for that exact footprint.
The Architecture of an Adaptive Dispatcher
Implementation happens in three distinct phases: the Profiling Sandbox, the Dispatcher Logic, and the Runtime Executor.
1. The Profiling Sandbox
You cannot guess performance. Performance is a function of the GPU’s SM (Streaming Multiprocessor) count, L2 cache size, and memory bandwidth. To build your dispatcher, you must first generate a performance map.
I recommend using NVIDIA CUTLASS for this. Unlike cuBLAS, which is a black box, CUTLASS allows you to tune tile sizes, alignment, and pipeline stages.
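# Conceptual Python wrapper for profiling multiple CUTLASS configurations
import json


def profile_kernel_configs(m, n, k, run_cutlass_bench):
    """Return the fastest config for an (m, n, k) GEMM.

    run_cutlass_bench is a caller-supplied callable that invokes the CUTLASS
    profiler or a custom CUDA runner and returns the measured latency.
    """
    configs = [
        {"tile_shape": "128x128x32", "stages": 3, "alignment": 8},
        {"tile_shape": "64x64x32", "stages": 5, "alignment": 8},
        {"tile_shape": "256x128x32", "stages": 2, "alignment": 4},
    ]
    results = {}
    for cfg in configs:
        # Serialize the config so it can be used as a dictionary key
        latency = run_cutlass_bench(m, n, k, cfg)
        results[json.dumps(cfg, sort_keys=True)] = latency
    # Pick the key with the lowest latency and turn it back into a dict
    return json.loads(min(results, key=results.get))


# Example output: "For M=1, N=4096, K=4096, the 64x64x32 tile is 40% faster"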
2. The Heuristic Dispatcher
Once you have your profile data, you don't want to run a search at runtime—that's too slow. Instead, you build a Lookup Table (LUT) or a decision tree.
For LLM workloads, the dispatcher usually keys off the "Current Token Index" and the "Batch Size." If you are speeding up LLMs with speculative decoding, your effective batch size (the M dimension of the GEMM) is doubled or more, because each sequence verifies multiple tokens at once. The kernel that handles single-token generation (low M, high N/K) is rarely the same one that should handle the verification block (higher M).
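A lookup-table dispatcher of this kind can be sketched in a few lines. The bucket edges and kernel names below are placeholders rather than measured results, and `select_kernel` is a hypothetical helper; in practice the table would come from your profiling run.

```python
import bisect

# Hypothetical profiling output: upper bound on M (rows) -> best kernel name.
# These edges and names are placeholders, not measurements.
M_BUCKET_EDGES = [1, 8, 64, 512]
KERNEL_FOR_BUCKET = ["gemv_like", "small_tile", "medium_tile", "large_tile"]
FALLBACK_KERNEL = "large_tile"


def select_kernel(batch_size, tokens_per_sequence=1):
    """Pick a kernel name from the LUT using the effective M dimension.

    tokens_per_sequence is 1 for plain decoding and larger when a
    speculative-decoding verification step scores several tokens at once.
    """
    m = batch_size * tokens_per_sequence
    idx = bisect.bisect_left(M_BUCKET_EDGES, m)  # first bucket whose edge is >= m
    if idx == len(M_BUCKET_EDGES):
        return FALLBACK_KERNEL  # larger than anything that was profiled
    return KERNEL_FOR_BUCKET[idx]


if __name__ == "__main__":
    print(select_kernel(4))                         # plain decode step
    print(select_kernel(4, tokens_per_sequence=5))  # verification step
```

The lookup is a binary search over a handful of integers, so its cost stays negligible next to a kernel launch. The same batch of four requests lands in different buckets depending on whether it is decoding or verifying, which is the behaviour the paragraph above describes.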
3. The Runtime Executor
The executor needs to be low-latency. If your dispatch logic takes 500 microseconds but only saves 200 microseconds of kernel execution time, you’ve failed. This is why we implement the dispatcher in C++ or directly in a custom Triton kernel.
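Before committing to a dispatcher, it helps to measure its cost directly. The sketch below times a stand-in dispatch function with `time.perf_counter` and compares the result with the kernel time you expect to save. `toy_dispatch` and the 200 microsecond figure are illustrative assumptions, and a real measurement should include the kernel launch path as well.

```python
import time


def mean_call_overhead_us(fn, *args, iterations=10_000):
    """Average wall-clock cost of calling fn(*args), in microseconds."""
    start = time.perf_counter()
    for _ in range(iterations):
        fn(*args)
    elapsed = time.perf_counter() - start
    return elapsed / iterations * 1e6


def dispatch_is_worthwhile(dispatch_overhead_us, kernel_time_saved_us):
    """True when the dispatcher costs less than the kernel time it saves."""
    return dispatch_overhead_us < kernel_time_saved_us


def toy_dispatch(m):
    # Stand-in for real dispatch logic; returns a hypothetical kernel name.
    return "small" if m < 16 else "large"


if __name__ == "__main__":
    overhead = mean_call_overhead_us(toy_dispatch, 4)
    print(f"dispatch overhead: {overhead:.3f} us per call")
    print("worth it:", dispatch_is_worthwhile(overhead, kernel_time_saved_us=200.0))
```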
Step-by-Step: Implementing an AKS Layer in Triton
Triton is particularly well-suited for AKS because it allows for "Autotune" decorators that can be customized. However, the default @triton.autotune is often too broad. Here is how I implement a more granular selection for a dynamic GEMM operation.
Step 1: Define Your Configurations
Don't just sweep every possible variable. Focus on BLOCK_SIZE_M, BLOCK_SIZE_N, and num_warps.
import triton
import triton.language as tl


# Define a selection of kernels optimized for different shapes
def get_configs():
    return [
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32}, num_warps=4),  # Small shapes
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=4),  # Medium
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_warps=8),  # Large
    ]


@triton.autotune(
    configs=get_configs(),
    key=['M', 'N', 'K'],  # A change in any of these values triggers a fresh evaluation of the configs
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # num_warps is a launch option supplied by triton.Config, not a kernel
    # argument, so it is not declared in the signature.
    # Kernel logic here...
    pass
Step 2: Overriding the Heuristic
The "Gotcha" here is that Triton's autotuner benchmarks every config the first time it sees a new combination of key values, then caches the winner for that combination. If your workload varies wildly, each unseen (M, N, K) triggers a fresh tuning run on the request path, and that tuning cost becomes the bottleneck. I prefer to use a Meta-Dispatcher that bypasses the autotuner for edge cases (like extremely small batches during cold starts).
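def smart_dispatch_matmul(A, B):
    # grid, mv_kernel_specialized, and matmul_padded_kernel are assumed to be
    # defined elsewhere; the "..." stands for the remaining kernel arguments.
    M, K = A.shape
    K_b, N = B.shape
    assert K == K_b, "inner dimensions of A and B must match"
    # Custom heuristic based on hardware analysis
    if M < 4:
        # Dispatch to a specialized "Vector-Matrix" kernel
        return mv_kernel_specialized[grid](A, B, ...)
    elif M % 16 != 0:
        # Handle non-aligned shapes with a padding-aware kernel
        return matmul_padded_kernel[grid](A, B, ...)
    else:
        # Standard optimized path
        return matmul_kernel[grid](A, B, ...)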
Critical Pitfall: The "Alignment Trap"
I have seen experienced engineers get stuck here frequently. Most optimized GPU kernels rely on vectorized loads (e.g., ld.global.v4.f32 in PTX), which require the memory address to be aligned to 16 bytes.
When you are serving dynamic batch sizes, your memory offsets often break this alignment. If your kernel expects 16-byte alignment but receives 8-byte alignment, it will either:
- Crash (typically reported by CUDA as a misaligned address error).
- Fall back to scalar loads, when the compiler or library provides such a path, which are significantly slower.
The Fix: Your adaptive selection logic must check addr % alignment == 0. If the check fails, you must dispatch to a kernel that does not assume vectorized access or that handles the misalignment manually. Ragged shapes are a related but separate problem: tiles that run past the edge of the tensor need masked loads, or tl.load(..., boundary_check=(0, 1)) when you use block pointers. This is a primary reason why optimizing mobile AI and NAS is so difficult: the hardware constraints are even more rigid.
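The addr % alignment check can be sketched as plain integer arithmetic. The example assumes a row-major 2D tensor and uses made-up path names ("vectorized", "scalar_fallback"). With PyTorch, the inputs would typically come from `tensor.data_ptr()`, `tensor.stride(0)`, and `tensor.element_size()`.

```python
def is_aligned(address, alignment=16):
    """True if a byte address is a multiple of the required alignment."""
    return address % alignment == 0


def choose_load_path(base_address, row_stride_elems, element_size, alignment=16):
    """Return 'vectorized' only if the base pointer and every row start are aligned.

    Row r starts at base_address + r * row_stride_elems * element_size, so all
    rows are aligned exactly when the base address and the row pitch in bytes
    are both multiples of the alignment.
    """
    row_pitch_bytes = row_stride_elems * element_size
    if is_aligned(base_address, alignment) and is_aligned(row_pitch_bytes, alignment):
        return "vectorized"
    return "scalar_fallback"


if __name__ == "__main__":
    # FP16 tensor (2 bytes per element) with an aligned base and a 4096-element row stride.
    print(choose_load_path(0x7F0000000000, 4096, 2))
    # Same tensor viewed from an offset of 4 elements (8 bytes): no longer 16-byte aligned.
    print(choose_load_path(0x7F0000000000 + 8, 4096, 2))
```

Checking the row pitch as well as the base pointer matters because a sliced or oddly strided view can have an aligned first row and misaligned later rows.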
Managing the Warm-up Overhead
Adaptive selection requires a "warm-up" for every new shape encountered. In a dynamic production environment, this can lead to "jitter." To mitigate this, I use two strategies:
- AOT (Ahead-of-Time) Binning: Round your dynamic shapes up to the next power of 2 or multiple of 8, so the padded shape always covers the real one. This reduces the number of unique "keys" in your lookup table, ensuring you hit cached kernels more often.
- Persistent Kernel Cache: Store your profiled LUT on disk. When the inference pod restarts, it loads the pre-profiled optimal configurations instead of re-running the search.
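Both strategies fit in a small helper module. This is a minimal sketch: it rounds shapes up rather than to the nearest bucket (an assumption, so a padded buffer always fits the real data), stores the table as JSON, and uses hypothetical function names and a hypothetical key format.

```python
import json
import os
import tempfile


def round_up_to_bucket(value, multiple=8):
    """Round a dimension up to the next multiple so nearby shapes share one key."""
    return ((value + multiple - 1) // multiple) * multiple


def next_power_of_two(value):
    """Smallest power of two that is >= value (for value >= 1)."""
    return 1 << (value - 1).bit_length()


def lut_key(m, n, k):
    # JSON object keys must be strings, so the binned shape is encoded as text.
    return f"{round_up_to_bucket(m)}x{n}x{k}"


def save_lut(lut, path):
    """Write the LUT via a temp file and rename, so readers never see a partial file."""
    directory = os.path.dirname(os.path.abspath(path))
    with tempfile.NamedTemporaryFile("w", dir=directory, delete=False) as tmp:
        json.dump(lut, tmp)
        tmp_name = tmp.name
    os.replace(tmp_name, path)


def load_lut(path):
    """Return the stored LUT, or an empty dict if nothing has been profiled yet."""
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        return json.load(f)
```

On pod start you would call `load_lut` once and keep the dictionary in memory; only shapes whose binned key is missing need a profiling pass, after which `save_lut` persists the result for the next restart.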
Quantifying the Gains: When is it worth it?
You shouldn't implement AKS for every single layer. Focus on the Bottleneck Ops:
- Linear Layers in the Attention Block: Since these scale with $B \times S$, they are prime candidates.
- Logit Computation: Usually highly memory-bound and sensitive to N-dimension alignment.
- KV-Cache Management: If you are implementing custom PagedAttention, the kernel that reads/writes to the cache should be adaptive based on the "fill percentage" of the blocks.
In my experience, if your GPU utilization (monitored via nvidia-smi dmon) fluctuates wildly while throughput remains stagnant, your kernels are likely poorly matched to your workload. Implementing AKS can stabilize utilization and drop tail latency (p99).
Common Gotchas
- L2 Cache Poisoning: A kernel that is fast in a micro-benchmark might be slow in production because it aggressively flushes the L2 cache, slowing down the next operation in the graph. Always profile with the preceding and following kernels in the pipeline.
- Occupancy vs. Throughput: A kernel with 100% occupancy isn't always the fastest. Sometimes a kernel with 50% occupancy but better memory instruction pipelining wins. Don't let the occupancy metric fool you; rely on execution time.
- CPU Overhead: If you write your dispatcher in Python, the overhead of the if/else logic and the kernel launch can exceed the execution time of the kernel itself for small operations. Keep the dispatcher logic in the "Fast Path" (C++ or TorchScript).
Practical FAQ
Q: How does this differ from CUDA Graphs?
CUDA Graphs capture a sequence of kernels to minimize launch overhead. AKS is about choosing which kernel goes into that sequence. You can combine them: use AKS to determine the optimal graph structure during a "re-graphing" phase when batch sizes change significantly.
Q: Can I use this for FP8 or INT8 quantization?
Absolutely. In fact, it’s more important there. Quantized kernels are extremely sensitive to alignment. An adaptive selection layer can decide between a "Vectorized INT4" kernel and a "Standard INT8" kernel based on whether the input features are sparse enough to justify the overhead of a specialized implementation.
Q: Does this replace TensorRT?
No, it complements it. TensorRT does a great job of static optimization. However, if you are building a custom serving stack or need to support shapes that TensorRT didn't see during the polygraphy or trtexec build phase, a custom AKS layer is your best bet for maintaining performance.
Next Steps
To get started, don't rewrite your whole stack. Pick your most expensive nn.Linear or Attention layer. Profile it across five batch sizes $(1, 4, 8, 16, 32)$ using three different tile sizes in Triton. Build a simple if/else dispatcher and measure the p99 latency. You'll likely see enough of a gain to justify expanding the system to the rest of your pipeline.
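The experiment above comes down to collecting latency samples per (batch size, tile config) pair and picking the config with the lowest p99 for each batch size. Here is a minimal sketch of that bookkeeping, using a nearest-rank percentile. The config names and latency numbers in the usage example are invented for illustration, and collecting the samples is left to your benchmark harness.

```python
import math


def percentile(samples, pct):
    """Nearest-rank percentile: the smallest sample with at least pct% of data at or below it."""
    if not samples:
        raise ValueError("need at least one sample")
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]


def best_config_per_batch(latencies_ms):
    """latencies_ms maps (batch_size, config_name) -> list of latency samples.

    Returns {batch_size: config_name}, choosing the lowest p99 for each batch size.
    """
    best = {}
    for (batch, config), samples in latencies_ms.items():
        p99 = percentile(samples, 99)
        if batch not in best or p99 < best[batch][1]:
            best[batch] = (config, p99)
    return {batch: config for batch, (config, _) in best.items()}


if __name__ == "__main__":
    # Made-up samples purely to show the data layout.
    samples = {
        (1, "tile_16"): [1.0, 1.1, 1.2],
        (1, "tile_64"): [0.9, 1.0, 3.0],
        (32, "tile_16"): [4.0, 4.5],
        (32, "tile_64"): [2.0, 2.1],
    }
    print(best_config_per_batch(samples))
```

Selecting on p99 rather than the mean keeps a config with occasional slow outliers from winning, which matches the tail-latency goal of the exercise. The resulting dictionary is already the if/else dispatcher in table form.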
For those working on more complex setups, like multi-agent orchestration frameworks, the dynamic nature of agent-to-agent communication makes AKS almost mandatory to handle the unpredictable token volumes.
