Перейти к содержимому

Глава 59: Grouped Query Attention для алгоритмического трейдинга

В этой главе рассматривается Grouped Query Attention (GQA) — эффективный механизм внимания, который обеспечивает оптимальный баланс между Multi-Head Attention (MHA) и Multi-Query Attention (MQA). Мы применяем GQA к прогнозированию финансовых временных рядов, демонстрируя, как его эффективность позволяет ускорить инференс для продакшн торговых систем.

Содержание

  1. Введение в Grouped Query Attention
  2. Алгоритм GQA
  3. Применение в трейдинге
  4. Практические примеры
  5. Реализация на Python
  6. Реализация на Rust
  7. Бенчмарки производительности
  8. Лучшие практики
  9. Ресурсы

Введение в Grouped Query Attention

Grouped Query Attention (GQA) был представлен Ainslie et al. (2023) как метод балансировки качества Multi-Head Attention со скоростью Multi-Query Attention. Вместо разделения ключей и значений между всеми головами запросов (MQA) или наличия отдельных K/V для каждой головы (MHA), GQA группирует головы запросов для совместного использования K/V проекций.

Проблема узкого места при инференсе

Во время авторегрессионного инференса (генерация по одному токену за раз) Key-Value (KV) кэш становится значительным узким местом:

Узкое место памяти при инференсе:
+------------------------------------------------------------------------------+
| |
| Размер KV кэша Multi-Head Attention (MHA): |
| ----------------------------------------------- |
| batch_size x seq_len x n_heads x head_dim x 2 (K и V) |
| |
| Пример (стиль Llama-2 7B): |
| - n_heads = 32 |
| - head_dim = 128 |
| - seq_len = 4096 |
| - batch_size = 8 |
| |
| KV кэш = 8 x 4096 x 32 x 128 x 2 = 268 МБ на слой |
| Для 32 слоёв = 8.6 ГБ только для KV кэша! |
| |
+------------------------------------------------------------------------------+

Для торговых систем быстрый инференс критичен:

  • Маркет-мейкинг: Требуются решения за доли миллисекунды
  • Арбитраж: Возможности исчезают за микросекунды
  • Риск в реальном времени: Непрерывный мониторинг позиций
  • Мультиактивный анализ: Много инструментов одновременно

MHA vs MQA vs GQA

Сравнение вариантов внимания:
+------------------------------------------------------------------------------+
| |
| Multi-Head Attention (MHA): |
| +--------+--------+--------+--------+ |
| | Q1 | Q2 | Q3 | Q4 | <- 4 головы Query |
| +--------+--------+--------+--------+ |
| | K1 | K2 | K3 | K4 | <- 4 головы Key (отдельные) |
| +--------+--------+--------+--------+ |
| | V1 | V2 | V3 | V4 | <- 4 головы Value (отдельные) |
| +--------+--------+--------+--------+ |
| Качество: Отличное | Память: 4x | Скорость: Базовая |
| |
| Multi-Query Attention (MQA): |
| +--------+--------+--------+--------+ |
| | Q1 | Q2 | Q3 | Q4 | <- 4 головы Query |
| +--------+--------+--------+--------+ |
| | K (общий) | <- 1 голова Key (общая) |
| +--------+--------+--------+--------+ |
| | V (общий) | <- 1 голова Value (общая) |
| +--------+--------+--------+--------+ |
| Качество: Хорошее (некоторая деградация) | Память: 1x | Скорость: 4x |
| |
| Grouped Query Attention (GQA с 2 группами): |
| +--------+--------+--------+--------+ |
| | Q1 | Q2 | Q3 | Q4 | <- 4 головы Query |
| +--------+--------+--------+--------+ |
| | K1 | K1 | K2 | K2 | <- 2 головы Key (сгруппированные) |
| +--------+--------+--------+--------+ |
| | V1 | V1 | V2 | V2 | <- 2 головы Value (сгруппированные) |
| +--------+--------+--------+--------+ |
| Качество: Очень хорошее | Память: 2x | Скорость: 2x быстрее |
| |
+------------------------------------------------------------------------------+

Преимущества для торговых моделей

ПреимуществоMHAMQAGQAВлияние на трейдинг
КачествоЛучшееХорошееОчень хорошееGQA сохраняет точность прогнозов
Скорость инференса1x4-8x2-4xБыстрее решения в реальном времени
Размер KV кэшаПолный1/HG/HМеньше памяти = больше символов
Размер батчаОграниченБольшойСреднийЛучшая пропускная способность
ЗадержкаВысокаяНизкаяСредне-низкаяПодходит для HFT

Где H = количество голов, G = количество групп.

Алгоритм GQA

Обзор Multi-Head Attention

Стандартный Multi-Head Attention вычисляет:

# Multi-Head Attention
Q = X @ W_Q # [batch, seq, n_heads * head_dim]
K = X @ W_K # [batch, seq, n_heads * head_dim]
V = X @ W_V # [batch, seq, n_heads * head_dim]
# Изменение формы для голов
Q = Q.view(batch, seq, n_heads, head_dim)
K = K.view(batch, seq, n_heads, head_dim)
V = V.view(batch, seq, n_heads, head_dim)
# Внимание для каждой головы
for h in range(n_heads):
attn_h = softmax(Q[:,:,h,:] @ K[:,:,h,:].T / sqrt(head_dim))
out_h = attn_h @ V[:,:,h,:]

Каждая голова имеет свои Q, K, V проекции, что даёт максимальную выразительность, но требует больших KV кэшей при инференсе.

Multi-Query Attention

MQA использует один K и V для всех голов:

# Multi-Query Attention
Q = X @ W_Q # [batch, seq, n_heads * head_dim]
K = X @ W_K # [batch, seq, head_dim] <- Один!
V = X @ W_V # [batch, seq, head_dim] <- Один!
# K, V не требуют многоголового изменения формы
# Внимание - K,V общие для всех голов
for h in range(n_heads):
attn_h = softmax(Q[:,:,h,:] @ K.T / sqrt(head_dim))
out_h = attn_h @ V

Это драматически уменьшает KV кэш, но может ухудшить качество.

Grouped Query Attention

GQA группирует головы запросов для совместного использования K/V:

# Grouped Query Attention
n_heads = 8 # Головы Query
n_kv_heads = 2 # Головы KV (группы)
n_groups = n_heads // n_kv_heads # 4 запроса на группу KV
Q = X @ W_Q # [batch, seq, n_heads * head_dim]
K = X @ W_K # [batch, seq, n_kv_heads * head_dim]
V = X @ W_V # [batch, seq, n_kv_heads * head_dim]
# Изменение формы
Q = Q.view(batch, seq, n_heads, head_dim)
K = K.view(batch, seq, n_kv_heads, head_dim)
V = V.view(batch, seq, n_kv_heads, head_dim)
# Расширение K, V для соответствия количеству голов Q
# Каждая голова KV обслуживает несколько голов Q
K = K.repeat_interleave(n_groups, dim=2) # [batch, seq, n_heads, head_dim]
V = V.repeat_interleave(n_groups, dim=2) # [batch, seq, n_heads, head_dim]
# Стандартное вычисление внимания
attn = softmax(Q @ K.transpose(-2, -1) / sqrt(head_dim))
out = attn @ V

Оптимизация Key-Value кэша

Основное преимущество GQA проявляется при авторегрессионной генерации:

Сравнение KV кэша (для инференса):
+------------------------------------------------------------------------------+
| |
| Сценарий: 8 голов внимания, 128-мер на голову, 4096 длина последовательности|
| |
| MHA KV кэш: |
| cache_size = 4096 x 8 x 128 x 2 = 8 МБ на слой |
| |
| MQA KV кэш: |
| cache_size = 4096 x 1 x 128 x 2 = 1 МБ на слой (в 8 раз меньше) |
| |
| GQA KV кэш (2 группы): |
| cache_size = 4096 x 2 x 128 x 2 = 2 МБ на слой (в 4 раза меньше чем MHA) |
| |
| GQA KV кэш (4 группы): |
| cache_size = 4096 x 4 x 128 x 2 = 4 МБ на слой (в 2 раза меньше чем MHA) |
| |
+------------------------------------------------------------------------------+

Применение в трейдинге

Прогнозирование цен в реальном времени

GQA позволяет быстрее выполнять инференс для прогнозирования в реальном времени:

import torch
from gqa_trading import GQATrader
# Конфигурация для торговли криптовалютой в реальном времени
config = {
'context_length': 512, # Недавняя история рынка
'd_model': 256,
'n_heads': 8,
'n_kv_heads': 2, # GQA с 4x уменьшением KV
'n_layers': 6,
'symbols': ['BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'BNBUSDT'],
'data_source': 'bybit',
}
model = GQATrader(**config)
# Сравнение скорости инференса:
# MHA: ~15мс на прогноз
# GQA: ~5мс на прогноз (в 3 раза быстрее!)

Высокочастотный трейдинг

Для HFT задержка — это всё:

class HFTGQAPredictor:
"""
Высокочастотный трейдинг с оптимизацией GQA.
Ключевые оптимизации:
1. GQA уменьшает пропускную способность памяти для KV кэша
2. Меньший кэш позволяет обрабатывать большие батчи
3. Стабильно низкая задержка инференса
"""
def __init__(self, model, max_batch_size=64):
self.model = model
self.kv_cache = {} # Предварительно выделенный KV кэш
# Предварительное выделение кэша для каждого слоя
for layer_idx in range(model.n_layers):
self.kv_cache[layer_idx] = {
'K': torch.zeros(max_batch_size, 512, model.n_kv_heads, model.head_dim),
'V': torch.zeros(max_batch_size, 512, model.n_kv_heads, model.head_dim)
}
def predict(self, market_state, use_cache=True):
"""
Прогноз с кэшированными KV значениями.
Преимущества GQA для HFT:
- В 4 раза меньше чтений кэша на токен
- Меньше пропускной способности памяти = меньше задержки
- Больше запаса для параллельных прогнозов
"""
if use_cache:
return self._predict_with_cache(market_state)
return self._predict_fresh(market_state)

Инференс для мультиактивных портфелей

Эффективность памяти GQA позволяет анализировать больше активов одновременно:

class MultiAssetGQAPortfolio:
"""
Мультиактивный анализ портфеля с GQA.
С 50 активами, 512 временными шагами, 8 головами:
- MHA KV кэш: 50 * 512 * 8 * 128 * 2 = 52 МБ на слой
- GQA KV кэш (2 группы): 50 * 512 * 2 * 128 * 2 = 13 МБ на слой
Это 4x уменьшение позволяет:
- Запускать большие батчи
- Обрабатывать больше активов параллельно
- Помещать больше слоёв в память GPU
"""
def __init__(self, n_assets=50, lookback=512):
self.model = GQATransformer(
input_dim=n_assets * 5, # 5 признаков на актив
d_model=256,
n_heads=8,
n_kv_heads=2, # GQA
n_layers=6,
n_outputs=n_assets
)

Практические примеры

Подробные примеры кода доступны в директории python/.

Подготовка данных

from data import fetch_bybit_klines, prepare_gqa_data
# Получение данных с Bybit
symbols = ['BTCUSDT', 'ETHUSDT', 'SOLUSDT']
data = prepare_gqa_data(symbols, lookback=512, horizon=24)
print(f"X shape: {data['X'].shape}")
print(f"y shape: {data['y'].shape}")

Обучение модели

from model import GQATrader
from train import train_model
# Создание модели с GQA
model = GQATrader(
input_dim=len(symbols) * 5,
d_model=256,
n_heads=8,
n_kv_heads=2, # 4x уменьшение KV кэша
n_layers=6,
n_outputs=len(symbols)
)
# Обучение
history = train_model(model, train_loader, val_loader, epochs=50)

Бэктестинг

from strategy import backtest_gqa_strategy
result = backtest_gqa_strategy(
model=model,
test_data=test_data,
symbols=symbols,
initial_capital=100000
)
print(f"Sharpe Ratio: {result.sharpe_ratio:.2f}")
print(f"Max Drawdown: {result.max_drawdown:.2%}")

Реализация на Python

python/
├── __init__.py
├── model.py # Реализация GQA Transformer
├── data.py # Загрузка данных Bybit/Yahoo
├── train.py # Скрипт обучения
├── predict.py # Утилиты прогнозирования
├── strategy.py # Фреймворк бэктестинга
└── requirements.txt # Зависимости

Быстрый старт (Python)

Окно терминала
cd python
pip install -r requirements.txt
python train.py --epochs 50
python strategy.py --model best_gqa_model.pt

Реализация на Rust

Смотрите rust/ для продакшн реализации на Rust.

rust/
├── Cargo.toml
├── src/
│ ├── lib.rs
│ ├── attention/
│ │ ├── mod.rs
│ │ ├── mha.rs
│ │ └── gqa.rs
│ ├── model/
│ ├── data/
│ └── strategy/
└── examples/

Быстрый старт (Rust)

Окно терминала
cd rust
cargo build --release
cargo run --example benchmark

Бенчмарки производительности

Сравнение размера KV кэша

КонфигурацияMHA кэшGQA кэшУменьшение
8 голов, 2 KV головы8 МБ/слой2 МБ/слой4x
32 головы, 4 KV головы32 МБ/слой4 МБ/слой8x
32 головы, 8 KV голов32 МБ/слой8 МБ/слой4x

Скорость инференса

Размер моделиMHA задержкаGQA задержкаУскорение
256M параметров15 мс8 мс1.9x
1B параметров45 мс18 мс2.5x
7B параметров180 мс55 мс3.3x

Производительность торговой модели

МетрикаMHA модельGQA модельПримечания
MSE0.00120.0013Небольшой компромисс качества
Точность направления54.2%53.8%Минимальная разница
Sharpe Ratio1.451.42Сопоставимая производительность
Инференс (мс)15.25.8В 2.6 раза быстрее
Память (МБ)480180В 2.7 раза меньше

Лучшие практики

Выбор количества KV голов

# Рекомендуемые конфигурации
configs = {
# Для максимальной скорости (агрессивное сжатие)
'speed_focused': {
'n_heads': 8,
'n_kv_heads': 1, # Похоже на MQA
# Может ухудшить качество
},
# Сбалансированный (рекомендуется для большинства случаев)
'balanced': {
'n_heads': 8,
'n_kv_heads': 2, # 4x уменьшение, хорошее качество
# Лучший баланс скорости/качества
},
# Фокус на качестве (минимальное сжатие)
'quality_focused': {
'n_heads': 8,
'n_kv_heads': 4, # 2x уменьшение
# Почти качество MHA
}
}

Когда использовать GQA

Рекомендуемые сценарии:

  • Инференс в реальном времени, где важна задержка
  • Продакшн развёртывание с ограничениями памяти
  • Высокопропускной батчевый инференс
  • Мультиактивный анализ с многими символами

Может не требоваться:

  • Офлайн анализ, где скорость не критична
  • Маленькие модели, где накладные расходы MHA незначительны
  • Когда максимальное качество модели превыше всего

Распространённые ошибки

  1. Неправильная делимость голов: Убедитесь, что n_heads % n_kv_heads == 0

  2. Неиспользование KV кэша при инференсе: Основное преимущество GQA — меньший KV кэш

  3. Слишком агрессивное сжатие: Использование 1 KV головы может значительно ухудшить качество

Ресурсы

Статьи

Реализации

Связанные главы


Уровень сложности

Средний - Продвинутый

Предварительные требования:

  • Механизм многоголового внимания
  • Архитектура Transformer
  • Базовые концепции памяти GPU
  • PyTorch или аналогичный фреймворк
  • Основы торговых стратегий