Llama-3.2-1B-MTP-k8: Multi-Token Prediction on Consumer Hardware
This model is a 1-billion-parameter Llama-3.2 variant that implements Multi-Token Prediction (MTP), a technique for accelerating language model inference. Unlike standard autoregressive models, which predict one token at a time, an MTP model is trained via online self-distillation to predict several future tokens concurrently.
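The training objective can be sketched as a distillation loss over the k future positions: a frozen teacher supplies soft probability targets, and the student is penalized by the KL divergence at each predicted offset. This is a minimal NumPy sketch under assumed shapes (the function and variable names are illustrative, not the repository's actual API):

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mtp_distillation_loss(student_logits, teacher_probs):
    """Mean KL(teacher || student) across k future-token positions.

    student_logits: (k, vocab) raw scores, one row per predicted offset
    teacher_probs:  (k, vocab) soft targets from the frozen teacher
    """
    p = teacher_probs
    q = softmax(student_logits)
    kl = (p * (np.log(p + 1e-12) - np.log(q + 1e-12))).sum(axis=-1)
    return kl.mean()

# Toy example: k=8 predicted offsets over a 16-token vocabulary
rng = np.random.default_rng(0)
k, vocab = 8, 16
teacher = softmax(rng.normal(size=(k, vocab)))
student = rng.normal(size=(k, vocab))
loss = mtp_distillation_loss(student, teacher)
```

The loss goes to zero when the student's distributions match the teacher's exactly, which is the self-distillation fixed point.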
Key Features and Capabilities
- Accelerated Inference: Achieves higher generation throughput by emitting multiple tokens per decoding step, measured at 1.8x faster on GSM8K with ConfAdapt decoding.
- ConfAdapt Decoding: Utilizes a confidence-adaptive decoding strategy that emits multiple tokens when confident and falls back to single-token prediction when uncertain, preserving generation quality.
- Consumer Hardware Optimization: This specific reproduction was scaled down and successfully trained on a single NVIDIA RTX 5090 (32GB), making MTP accessible on consumer-grade GPUs.
- Self-Distillation: Employs a frozen teacher model to generate soft probability distributions, which a trainable student model learns from to predict k future tokens.
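The ConfAdapt feature above can be illustrated with a small acceptance loop: at each step the model proposes up to k tokens, and tokens are emitted left to right only while confidence stays above a threshold, with a guaranteed single-token fallback. A minimal sketch, assuming greedy decoding and a 0.9 threshold (the "90%" configuration); the function name and data layout are hypothetical:

```python
def confadapt_accept(step_probs, threshold=0.9):
    """Confidence-adaptive acceptance for one MTP decoding step.

    step_probs: list of up to k probability distributions
    (dicts mapping token -> probability), one per predicted offset.
    Emits greedy tokens left to right while confidence stays at or
    above `threshold`; always emits at least the first token,
    mirroring the single-token (k=1) fallback.
    """
    accepted = []
    for i, dist in enumerate(step_probs):
        token, prob = max(dist.items(), key=lambda kv: kv[1])
        if i > 0 and prob < threshold:
            break  # fall back to normal decoding from here
        accepted.append(token)
    return accepted

# Hypothetical step: confident about two tokens, unsure about the third
step = [
    {"The": 0.97, "A": 0.03},
    {"answer": 0.95, "result": 0.05},
    {"is": 0.60, "was": 0.40},
]
out = confadapt_accept(step, threshold=0.9)  # emits ["The", "answer"]
```

Because low-confidence positions are rejected rather than emitted, quality degrades gracefully: in the worst case the model behaves like ordinary one-token-at-a-time decoding.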
Performance Highlights
On the GSM8K 8-shot Chain-of-Thought benchmark, the MTP k=8 + ConfAdapt 90% configuration averaged 1.3 seconds per sample, compared to ~2.4 seconds per sample for MTP k=1, while maintaining comparable exact match scores (5.08% vs 5.23%). A small accuracy drop (about 2%) relative to the baseline is noted, consistent with the smaller model size and training data used in this reproduction.
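As a sanity check, the stated 1.8x figure follows directly from the reported per-sample times:

```python
# Per-sample times reported on GSM8K 8-shot CoT
baseline_s = 2.4   # MTP k=1
confadapt_s = 1.3  # MTP k=8 + ConfAdapt 90%

speedup = baseline_s / confadapt_s  # ~1.85x, consistent with the ~1.8x claim
```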
Training Details
The model was trained for approximately 17 hours on a single RTX 5090 using the MetaMathQA dataset, with a fixed prediction horizon of k=8 tokens. Unlike the original paper, which randomized k during training across multiple GPUs, this reproduction keeps k fixed, specializing the model for k=8 prediction.