Глава 59: Grouped Query Attention для алгоритмического трейдинга
В этой главе рассматривается Grouped Query Attention (GQA) — эффективный механизм внимания, который обеспечивает оптимальный баланс между Multi-Head Attention (MHA) и Multi-Query Attention (MQA). Мы применяем GQA к прогнозированию финансовых временных рядов, демонстрируя, как его эффективность позволяет ускорить инференс для продакшн торговых систем.
Содержание
- Введение в Grouped Query Attention
- Алгоритм GQA
- Применение в трейдинге
- Практические примеры
- Реализация на Python
- Реализация на Rust
- Бенчмарки производительности
- Лучшие практики
- Ресурсы
Введение в Grouped Query Attention
Grouped Query Attention (GQA) был представлен Ainslie et al. (2023) как метод балансировки качества Multi-Head Attention со скоростью Multi-Query Attention. Вместо разделения ключей и значений между всеми головами запросов (MQA) или наличия отдельных K/V для каждой головы (MHA), GQA группирует головы запросов для совместного использования K/V проекций.
Проблема узкого места при инференсе
Во время авторегрессионного инференса (генерация по одному токену за раз) Key-Value (KV) кэш становится значительным узким местом:
Узкое место памяти при инференсе:+------------------------------------------------------------------------------+| || Размер KV кэша Multi-Head Attention (MHA): || ----------------------------------------------- || batch_size x seq_len x n_heads x head_dim x 2 (K и V) || || Пример (стиль Llama-2 7B): || - n_heads = 32 || - head_dim = 128 || - seq_len = 4096 || - batch_size = 8 || || KV кэш = 8 x 4096 x 32 x 128 x 2 = 268 МБ на слой || Для 32 слоёв = 8.6 ГБ только для KV кэша! || |+------------------------------------------------------------------------------+Для торговых систем быстрый инференс критичен:
- Маркет-мейкинг: Требуются решения за доли миллисекунды
- Арбитраж: Возможности исчезают за микросекунды
- Риск в реальном времени: Непрерывный мониторинг позиций
- Мультиактивный анализ: Много инструментов одновременно
MHA vs MQA vs GQA
Сравнение вариантов внимания:+------------------------------------------------------------------------------+| || Multi-Head Attention (MHA): || +--------+--------+--------+--------+ || | Q1 | Q2 | Q3 | Q4 | <- 4 головы Query || +--------+--------+--------+--------+ || | K1 | K2 | K3 | K4 | <- 4 головы Key (отдельные) || +--------+--------+--------+--------+ || | V1 | V2 | V3 | V4 | <- 4 головы Value (отдельные) || +--------+--------+--------+--------+ || Качество: Отличное | Память: 4x | Скорость: Базовая || || Multi-Query Attention (MQA): || +--------+--------+--------+--------+ || | Q1 | Q2 | Q3 | Q4 | <- 4 головы Query || +--------+--------+--------+--------+ || | K (общий) | <- 1 голова Key (общая) || +--------+--------+--------+--------+ || | V (общий) | <- 1 голова Value (общая) || +--------+--------+--------+--------+ || Качество: Хорошее (некоторая деградация) | Память: 1x | Скорость: 4x || || Grouped Query Attention (GQA с 2 группами): || +--------+--------+--------+--------+ || | Q1 | Q2 | Q3 | Q4 | <- 4 головы Query || +--------+--------+--------+--------+ || | K1 | K1 | K2 | K2 | <- 2 головы Key (сгруппированные) || +--------+--------+--------+--------+ || | V1 | V1 | V2 | V2 | <- 2 головы Value (сгруппированные) || +--------+--------+--------+--------+ || Качество: Очень хорошее | Память: 2x | Скорость: 2x быстрее || |+------------------------------------------------------------------------------+Преимущества для торговых моделей
| Преимущество | MHA | MQA | GQA | Влияние на трейдинг |
|---|---|---|---|---|
| Качество | Лучшее | Хорошее | Очень хорошее | GQA сохраняет точность прогнозов |
| Скорость инференса | 1x | 4-8x | 2-4x | Быстрее решения в реальном времени |
| Размер KV кэша | Полный | 1/H | G/H | Меньше памяти = больше символов |
| Размер батча | Ограничен | Большой | Средний | Лучшая пропускная способность |
| Задержка | Высокая | Низкая | Средне-низкая | Подходит для HFT |
Где H = количество голов, G = количество групп.
Алгоритм GQA
Обзор Multi-Head Attention
Стандартный Multi-Head Attention вычисляет:
# Multi-Head AttentionQ = X @ W_Q # [batch, seq, n_heads * head_dim]K = X @ W_K # [batch, seq, n_heads * head_dim]V = X @ W_V # [batch, seq, n_heads * head_dim]
# Изменение формы для головQ = Q.view(batch, seq, n_heads, head_dim)K = K.view(batch, seq, n_heads, head_dim)V = V.view(batch, seq, n_heads, head_dim)
# Внимание для каждой головыfor h in range(n_heads): attn_h = softmax(Q[:,:,h,:] @ K[:,:,h,:].T / sqrt(head_dim)) out_h = attn_h @ V[:,:,h,:]Каждая голова имеет свои Q, K, V проекции, что даёт максимальную выразительность, но требует больших KV кэшей при инференсе.
Multi-Query Attention
MQA использует один K и V для всех голов:
# Multi-Query AttentionQ = X @ W_Q # [batch, seq, n_heads * head_dim]K = X @ W_K # [batch, seq, head_dim] <- Один!V = X @ W_V # [batch, seq, head_dim] <- Один!
# K, V не требуют многоголового изменения формы
# Внимание - K,V общие для всех головfor h in range(n_heads): attn_h = softmax(Q[:,:,h,:] @ K.T / sqrt(head_dim)) out_h = attn_h @ VЭто драматически уменьшает KV кэш, но может ухудшить качество.
Grouped Query Attention
GQA группирует головы запросов для совместного использования K/V:
# Grouped Query Attentionn_heads = 8 # Головы Queryn_kv_heads = 2 # Головы KV (группы)n_groups = n_heads // n_kv_heads # 4 запроса на группу KV
Q = X @ W_Q # [batch, seq, n_heads * head_dim]K = X @ W_K # [batch, seq, n_kv_heads * head_dim]V = X @ W_V # [batch, seq, n_kv_heads * head_dim]
# Изменение формыQ = Q.view(batch, seq, n_heads, head_dim)K = K.view(batch, seq, n_kv_heads, head_dim)V = V.view(batch, seq, n_kv_heads, head_dim)
# Расширение K, V для соответствия количеству голов Q# Каждая голова KV обслуживает несколько голов QK = K.repeat_interleave(n_groups, dim=2) # [batch, seq, n_heads, head_dim]V = V.repeat_interleave(n_groups, dim=2) # [batch, seq, n_heads, head_dim]
# Стандартное вычисление вниманияattn = softmax(Q @ K.transpose(-2, -1) / sqrt(head_dim))out = attn @ VОптимизация Key-Value кэша
Основное преимущество GQA проявляется при авторегрессионной генерации:
Сравнение KV кэша (для инференса):+------------------------------------------------------------------------------+| || Сценарий: 8 голов внимания, 128-мер на голову, 4096 длина последовательности|| || MHA KV кэш: || cache_size = 4096 x 8 x 128 x 2 = 8 МБ на слой || || MQA KV кэш: || cache_size = 4096 x 1 x 128 x 2 = 1 МБ на слой (в 8 раз меньше) || || GQA KV кэш (2 группы): || cache_size = 4096 x 2 x 128 x 2 = 2 МБ на слой (в 4 раза меньше чем MHA) || || GQA KV кэш (4 группы): || cache_size = 4096 x 4 x 128 x 2 = 4 МБ на слой (в 2 раза меньше чем MHA) || |+------------------------------------------------------------------------------+Применение в трейдинге
Прогнозирование цен в реальном времени
GQA позволяет быстрее выполнять инференс для прогнозирования в реальном времени:
import torchfrom gqa_trading import GQATrader
# Конфигурация для торговли криптовалютой в реальном времениconfig = { 'context_length': 512, # Недавняя история рынка 'd_model': 256, 'n_heads': 8, 'n_kv_heads': 2, # GQA с 4x уменьшением KV 'n_layers': 6, 'symbols': ['BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'BNBUSDT'], 'data_source': 'bybit',}
model = GQATrader(**config)
# Сравнение скорости инференса:# MHA: ~15мс на прогноз# GQA: ~5мс на прогноз (в 3 раза быстрее!)Высокочастотный трейдинг
Для HFT задержка — это всё:
class HFTGQAPredictor: """ Высокочастотный трейдинг с оптимизацией GQA.
Ключевые оптимизации: 1. GQA уменьшает пропускную способность памяти для KV кэша 2. Меньший кэш позволяет обрабатывать большие батчи 3. Стабильно низкая задержка инференса """
def __init__(self, model, max_batch_size=64): self.model = model self.kv_cache = {} # Предварительно выделенный KV кэш
# Предварительное выделение кэша для каждого слоя for layer_idx in range(model.n_layers): self.kv_cache[layer_idx] = { 'K': torch.zeros(max_batch_size, 512, model.n_kv_heads, model.head_dim), 'V': torch.zeros(max_batch_size, 512, model.n_kv_heads, model.head_dim) }
def predict(self, market_state, use_cache=True): """ Прогноз с кэшированными KV значениями.
Преимущества GQA для HFT: - В 4 раза меньше чтений кэша на токен - Меньше пропускной способности памяти = меньше задержки - Больше запаса для параллельных прогнозов """ if use_cache: return self._predict_with_cache(market_state) return self._predict_fresh(market_state)Инференс для мультиактивных портфелей
Эффективность памяти GQA позволяет анализировать больше активов одновременно:
class MultiAssetGQAPortfolio: """ Мультиактивный анализ портфеля с GQA.
С 50 активами, 512 временными шагами, 8 головами: - MHA KV кэш: 50 * 512 * 8 * 128 * 2 = 52 МБ на слой - GQA KV кэш (2 группы): 50 * 512 * 2 * 128 * 2 = 13 МБ на слой
Это 4x уменьшение позволяет: - Запускать большие батчи - Обрабатывать больше активов параллельно - Помещать больше слоёв в память GPU """
def __init__(self, n_assets=50, lookback=512): self.model = GQATransformer( input_dim=n_assets * 5, # 5 признаков на актив d_model=256, n_heads=8, n_kv_heads=2, # GQA n_layers=6, n_outputs=n_assets )Практические примеры
Подробные примеры кода доступны в директории python/.
Подготовка данных
from data import fetch_bybit_klines, prepare_gqa_data
# Получение данных с Bybitsymbols = ['BTCUSDT', 'ETHUSDT', 'SOLUSDT']data = prepare_gqa_data(symbols, lookback=512, horizon=24)
print(f"X shape: {data['X'].shape}")print(f"y shape: {data['y'].shape}")Обучение модели
from model import GQATraderfrom train import train_model
# Создание модели с GQAmodel = GQATrader( input_dim=len(symbols) * 5, d_model=256, n_heads=8, n_kv_heads=2, # 4x уменьшение KV кэша n_layers=6, n_outputs=len(symbols))
# Обучениеhistory = train_model(model, train_loader, val_loader, epochs=50)Бэктестинг
from strategy import backtest_gqa_strategy
result = backtest_gqa_strategy( model=model, test_data=test_data, symbols=symbols, initial_capital=100000)
print(f"Sharpe Ratio: {result.sharpe_ratio:.2f}")print(f"Max Drawdown: {result.max_drawdown:.2%}")Реализация на Python
python/├── __init__.py├── model.py # Реализация GQA Transformer├── data.py # Загрузка данных Bybit/Yahoo├── train.py # Скрипт обучения├── predict.py # Утилиты прогнозирования├── strategy.py # Фреймворк бэктестинга└── requirements.txt # ЗависимостиБыстрый старт (Python)
cd pythonpip install -r requirements.txtpython train.py --epochs 50python strategy.py --model best_gqa_model.ptРеализация на Rust
Смотрите rust/ для продакшн реализации на Rust.
rust/├── Cargo.toml├── src/│ ├── lib.rs│ ├── attention/│ │ ├── mod.rs│ │ ├── mha.rs│ │ └── gqa.rs│ ├── model/│ ├── data/│ └── strategy/└── examples/Быстрый старт (Rust)
cd rustcargo build --releasecargo run --example benchmarkБенчмарки производительности
Сравнение размера KV кэша
| Конфигурация | MHA кэш | GQA кэш | Уменьшение |
|---|---|---|---|
| 8 голов, 2 KV головы | 8 МБ/слой | 2 МБ/слой | 4x |
| 32 головы, 4 KV головы | 32 МБ/слой | 4 МБ/слой | 8x |
| 32 головы, 8 KV голов | 32 МБ/слой | 8 МБ/слой | 4x |
Скорость инференса
| Размер модели | MHA задержка | GQA задержка | Ускорение |
|---|---|---|---|
| 256M параметров | 15 мс | 8 мс | 1.9x |
| 1B параметров | 45 мс | 18 мс | 2.5x |
| 7B параметров | 180 мс | 55 мс | 3.3x |
Производительность торговой модели
| Метрика | MHA модель | GQA модель | Примечания |
|---|---|---|---|
| MSE | 0.0012 | 0.0013 | Небольшой компромисс качества |
| Точность направления | 54.2% | 53.8% | Минимальная разница |
| Sharpe Ratio | 1.45 | 1.42 | Сопоставимая производительность |
| Инференс (мс) | 15.2 | 5.8 | В 2.6 раза быстрее |
| Память (МБ) | 480 | 180 | В 2.7 раза меньше |
Лучшие практики
Выбор количества KV голов
# Рекомендуемые конфигурацииconfigs = { # Для максимальной скорости (агрессивное сжатие) 'speed_focused': { 'n_heads': 8, 'n_kv_heads': 1, # Похоже на MQA # Может ухудшить качество },
# Сбалансированный (рекомендуется для большинства случаев) 'balanced': { 'n_heads': 8, 'n_kv_heads': 2, # 4x уменьшение, хорошее качество # Лучший баланс скорости/качества },
# Фокус на качестве (минимальное сжатие) 'quality_focused': { 'n_heads': 8, 'n_kv_heads': 4, # 2x уменьшение # Почти качество MHA }}Когда использовать GQA
Рекомендуемые сценарии:
- Инференс в реальном времени, где важна задержка
- Продакшн развёртывание с ограничениями памяти
- Высокопропускной батчевый инференс
- Мультиактивный анализ с многими символами
Может не требоваться:
- Офлайн анализ, где скорость не критична
- Маленькие модели, где накладные расходы MHA незначительны
- Когда максимальное качество модели превыше всего
Распространённые ошибки
-
Неправильная делимость голов: Убедитесь, что
n_heads % n_kv_heads == 0 -
Неиспользование KV кэша при инференсе: Основное преимущество GQA — меньший KV кэш
-
Слишком агрессивное сжатие: Использование 1 KV головы может значительно ухудшить качество
Ресурсы
Статьи
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — Оригинальная статья GQA (Ainslie et al., 2023)
- Fast Transformer Decoding: One Write-Head is All You Need — Статья MQA (Shazeer, 2019)
- Attention Is All You Need — Оригинальный Transformer (Vaswani et al., 2017)
Реализации
- Llama 2 — Использует GQA
- Mistral — Использует GQA со скользящим окном
- HuggingFace Transformers — Поддержка GQA
Связанные главы
- Глава 58: Flash Attention Trading — Дополнительная оптимизация
- Глава 60: KV Cache Optimization — Дальнейшие оптимизации кэша
- Глава 51: Linformer Long Sequences — Альтернативное эффективное внимание
Уровень сложности
Средний - Продвинутый
Предварительные требования:
- Механизм многоголового внимания
- Архитектура Transformer
- Базовые концепции памяти GPU
- PyTorch или аналогичный фреймворк
- Основы торговых стратегий