Chapter 47: Cross-Attention for Multi-Asset Trading
This chapter explores Cross-Attention mechanisms for modeling relationships between multiple financial assets simultaneously. Unlike traditional single-asset forecasting, cross-attention enables the model to capture inter-asset dependencies, correlations, and lead-lag relationships that are crucial for portfolio management and multi-asset trading strategies.
Contents
- Introduction to Cross-Attention
- Cross-Attention Architecture
- Mathematical Foundation
- Data Representation
- Practical Examples
- Rust Implementation
- Python Implementation
- Best Practices
- Resources
Introduction to Cross-Attention
Cross-attention is an attention mechanism where queries come from one sequence (or asset) while keys and values come from another. In multi-asset trading, this allows each asset to “attend to” other assets, learning which assets provide predictive information for others.
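As a minimal sketch of the mechanism (assuming PyTorch ≥ 2.0; in a full model the inputs would first pass through learned Q/K/V projections, as in the modules later in this chapter):

```python
import torch
import torch.nn.functional as F

T, d = 32, 64                 # sequence length, model dimension
btc = torch.randn(1, T, d)    # illustrative "query" asset
eth = torch.randn(1, T, d)    # illustrative "key/value" asset

# Self-attention: Q, K, V all come from the same sequence
self_out = F.scaled_dot_product_attention(btc, btc, btc)

# Cross-attention: Q from BTC, K and V from ETH.
# BTC "asks the questions"; ETH "answers" them.
cross_out = F.scaled_dot_product_attention(btc, eth, eth)
print(cross_out.shape)  # torch.Size([1, 32, 64])
```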
Why Cross-Attention for Multi-Asset Trading?
Traditional approaches treat each asset independently:
```
Asset A → Model_A → Prediction_A
Asset B → Model_B → Prediction_B
Asset C → Model_C → Prediction_C
```

Cross-attention models all assets jointly:
```
┌─────────────────────────────────────────────────┐
│            Cross-Attention Network              │
│                                                 │
│     Asset A  ←→  Asset B  ←→  Asset C           │
│        ↑           ↑           ↑                │
│        └───────────┴───────────┘                │
│          Bidirectional attention                │
│                                                 │
│                     ↓                           │
│   [Prediction_A, Prediction_B, Prediction_C]    │
└─────────────────────────────────────────────────┘
```

Key insight: Financial markets are interconnected. When Bitcoin moves, Ethereum often follows. When oil prices rise, airline stocks typically fall. Cross-attention explicitly models these dependencies.
Key Advantages
- Inter-Asset Dependency Learning
  - Captures correlations between different asset classes
  - Models lead-lag relationships (e.g., BTC leading altcoins)
  - Learns time-varying relationships
- Attention-Based Interpretability
  - Attention weights reveal which assets influence predictions
  - Visualize cross-asset information flow
  - Identify market leaders and followers
- Portfolio-Level Optimization
  - Optimize the Sharpe ratio directly instead of individual predictions
  - Learn optimal asset allocation weights
  - Account for diversification benefits
- Adaptive Regime Detection
  - Attention patterns change during different market regimes
  - Detect correlation breakdowns during crises
  - Adapt to structural market changes
Comparison with Other Approaches
| Feature | Single-Asset LSTM | Multi-Asset RNN | Self-Attention | Cross-Attention |
|---|---|---|---|---|
| Inter-asset modeling | ✗ | Limited | Implicit | ✓ Explicit |
| Bidirectional influence | ✗ | ✗ | ✓ | ✓ |
| Asymmetric relationships | ✗ | ✗ | ✗ | ✓ |
| Lead-lag detection | ✗ | ✗ | Limited | ✓ |
| Interpretable | ✗ | ✗ | ✓ | ✓ |
| Portfolio optimization | ✗ | ✗ | ✗ | ✓ |
Cross-Attention Architecture
```
┌──────────────────────────────────────────────────────────────────┐
│                 CROSS-ATTENTION MULTI-ASSET MODEL                │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐           │
│  │   BTC   │   │   ETH   │   │   SOL   │   │  AAPL   │           │
│  │ (Query) │   │ (Query) │   │ (Query) │   │ (Query) │           │
│  └────┬────┘   └────┬────┘   └────┬────┘   └────┬────┘           │
│       │             │             │             │                │
│       ▼             ▼             ▼             ▼                │
│  ┌────────────────────────────────────────────────────┐          │
│  │              Token Embedding Layer                 │          │
│  │       (1D-CNN or Linear projection per asset)      │          │
│  └────────────────────────┬───────────────────────────┘          │
│                           │                                      │
│                           ▼                                      │
│  ┌────────────────────────────────────────────────────┐          │
│  │             Temporal Self-Attention                │          │
│  │    (Model temporal patterns within each asset)     │          │
│  └────────────────────────┬───────────────────────────┘          │
│                           │                                      │
│                           ▼                                      │
│  ┌────────────────────────────────────────────────────┐          │
│  │            Cross-Asset Cross-Attention             │          │
│  │                                                    │          │
│  │  Q(BTC) attends to K,V(ETH), K,V(SOL), K,V(AAPL)   │          │
│  │  Q(ETH) attends to K,V(BTC), K,V(SOL), K,V(AAPL)   │          │
│  │  ...                                               │          │
│  │                                                    │          │
│  │  Learns: "BTC leads ETH with weight 0.7"           │          │
│  │          "ETH leads SOL with weight 0.5"           │          │
│  └────────────────────────┬───────────────────────────┘          │
│                           │                                      │
│                           ▼                                      │
│  ┌────────────────────────────────────────────────────┐          │
│  │             Encoder Stack (N layers)               │          │
│  │     Temporal Attention + Cross-Asset Attention     │          │
│  └────────────────────────┬───────────────────────────┘          │
│                           │                                      │
│                           ▼                                      │
│  ┌────────────────────────────────────────────────────┐          │
│  │                 Prediction Heads                   │          │
│  │    • Returns prediction (regression)               │          │
│  │    • Direction prediction (classification)         │          │
│  │    • Portfolio weights (softmax/tanh)              │          │
│  └────────────────────────────────────────────────────┘          │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘
```

Query-Key-Value Mechanism
In cross-attention, one asset generates queries while other assets provide keys and values:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAssetAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_assets):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = math.sqrt(self.head_dim)

        # Separate projections for each role
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

    def forward(self, query_asset, key_value_assets):
        """
        Args:
            query_asset: [batch, seq_len, d_model] - Asset to predict
            key_value_assets: [batch, n_other_assets, seq_len, d_model]

        Returns:
            context: [batch, seq_len, d_model] - Attended representation
            attention: [batch, n_heads, seq_len, n_other_assets]
        """
        batch, seq_len, d_model = query_asset.shape
        n_other = key_value_assets.shape[1]

        # Project queries from the target asset
        Q = self.query_proj(query_asset)

        # Project keys and values from the other assets
        K = self.key_proj(key_value_assets.view(-1, seq_len, d_model))
        V = self.value_proj(key_value_assets.view(-1, seq_len, d_model))

        # Reshape for multi-head attention
        Q = Q.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch, n_other, seq_len, self.n_heads, self.head_dim)
        V = V.view(batch, n_other, seq_len, self.n_heads, self.head_dim)

        # Cross-attention: each query position attends to the other assets.
        # Simplified here: attend only to the last timestep of each other asset.
        K_last = K[:, :, -1, :, :].transpose(1, 2)  # [batch, n_heads, n_other, head_dim]
        V_last = V[:, :, -1, :, :].transpose(1, 2)

        # Attention scores
        scores = torch.matmul(Q, K_last.transpose(-2, -1)) / self.scale
        attention = F.softmax(scores, dim=-1)

        # Weighted values
        context = torch.matmul(attention, V_last)
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

        return self.output_proj(context), attention
```
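A quick usage sketch for the module above (batch and sequence sizes are illustrative):

```python
attn = CrossAssetAttention(d_model=64, n_heads=4, n_assets=4)
btc = torch.randn(8, 32, 64)          # [batch, seq_len, d_model] - query asset
others = torch.randn(8, 3, 32, 64)    # [batch, n_other_assets, seq_len, d_model]

context, weights = attn(btc, others)
print(context.shape)  # torch.Size([8, 32, 64])
print(weights.shape)  # torch.Size([8, 4, 32, 3]) - per head, timestep, other asset
```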
Multi-Head Cross-Attention
Multiple attention heads capture different types of cross-asset relationships:
```python
class MultiHeadCrossAttention(nn.Module):
    """
    Multi-head cross-attention whose heads can specialize in:
    - Correlation-based relationships
    - Lead-lag relationships
    - Volatility spillover
    - Sector/industry groupings
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x_query, x_key_value, mask=None):
        """
        Args:
            x_query: [batch, n_query_assets, seq_len, d_model]
            x_key_value: [batch, n_kv_assets, seq_len, d_model]

        Returns:
            output: [batch, n_query_assets, seq_len, d_model]
            attention: [batch, n_heads, n_query_assets, n_kv_assets]
        """
        batch, n_q, seq_len, d_model = x_query.shape
        n_kv = x_key_value.shape[1]

        # Pool the temporal dimension for cross-asset attention
        q = x_query.mean(dim=2)      # [batch, n_q, d_model]
        k = x_key_value.mean(dim=2)  # [batch, n_kv, d_model]
        v = x_key_value.mean(dim=2)

        # Project
        Q = self.W_q(q).view(batch, n_q, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(k).view(batch, n_kv, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(v).view(batch, n_kv, self.n_heads, self.head_dim).transpose(1, 2)

        # Attention scores: [batch, n_heads, n_q, n_kv]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention = F.softmax(scores, dim=-1)
        attention = self.dropout(attention)

        # Weighted sum: [batch, n_heads, n_q, head_dim]
        context = torch.matmul(attention, V)

        # Reshape and project
        context = context.transpose(1, 2).contiguous().view(batch, n_q, d_model)
        output = self.W_o(context)

        # Broadcast back across the sequence length (residual + norm)
        output = output.unsqueeze(2).expand(-1, -1, seq_len, -1)
        output = self.layer_norm(x_query + output)

        return output, attention
```
Temporal Cross-Attention
Captures lead-lag relationships across time:
```python
class TemporalCrossAttention(nn.Module):
    """
    Cross-attention that considers temporal shifts between assets.

    Example: BTC at time t-1 predicts ETH at time t.
    """

    def __init__(self, d_model, n_heads, max_lag=5):
        super().__init__()
        self.max_lag = max_lag
        self.attention = MultiHeadCrossAttention(d_model, n_heads)

        # Learnable lag weights
        self.lag_weights = nn.Parameter(torch.ones(max_lag + 1) / (max_lag + 1))

    def forward(self, x_query, x_key_value):
        """
        Args:
            x_query: [batch, n_q, seq_len, d_model]
            x_key_value: [batch, n_kv, seq_len, d_model]

        Returns:
            output: Attended representation with temporal alignment
            attention: Cross-asset attention weights per lag
        """
        outputs = []
        attentions = []

        # Compute attention at each lag
        for lag in range(self.max_lag + 1):
            if lag == 0:
                kv_lagged = x_key_value
            else:
                # Shift key/value backward by `lag` steps (zero-pad the front)
                kv_lagged = F.pad(x_key_value[:, :, :-lag], (0, 0, lag, 0))

            out, attn = self.attention(x_query, kv_lagged)
            outputs.append(out)
            attentions.append(attn)

        # Weighted combination across lags
        lag_weights = F.softmax(self.lag_weights, dim=0)
        output = sum(w * o for w, o in zip(lag_weights, outputs))

        return output, torch.stack(attentions, dim=1)
```
Hierarchical Cross-Attention
Models relationships at multiple levels (assets, sectors, markets):
```python
class HierarchicalCrossAttention(nn.Module):
    """
    Three-level hierarchy:
    1. Asset level: Individual asset relationships
    2. Sector level: Sector/industry relationships
    3. Market level: Cross-market relationships (crypto vs stocks)
    """

    def __init__(self, d_model, n_heads, sector_mapping, market_mapping):
        super().__init__()
        self.sector_mapping = sector_mapping  # asset_id -> sector_id
        self.market_mapping = market_mapping  # asset_id -> market_id

        # Asset-level attention
        self.asset_attention = MultiHeadCrossAttention(d_model, n_heads)

        # Sector-level attention
        self.sector_attention = MultiHeadCrossAttention(d_model, n_heads // 2)

        # Market-level attention
        self.market_attention = MultiHeadCrossAttention(d_model, n_heads // 4)

        # Combine hierarchies
        self.combine = nn.Linear(d_model * 3, d_model)

    def forward(self, x):
        """
        Args:
            x: [batch, n_assets, seq_len, d_model]

        Returns:
            output: Hierarchically attended representation
        """
        # Asset-level cross-attention
        asset_out, _ = self.asset_attention(x, x)

        # Aggregate to sectors (the aggregation/broadcast helpers pool and
        # re-expand representations using the mappings; omitted for brevity)
        sector_repr = self._aggregate_to_sectors(x)
        sector_out, _ = self.sector_attention(sector_repr, sector_repr)
        sector_out = self._broadcast_from_sectors(sector_out, x.shape)

        # Aggregate to markets
        market_repr = self._aggregate_to_markets(x)
        market_out, _ = self.market_attention(market_repr, market_repr)
        market_out = self._broadcast_from_markets(market_out, x.shape)

        # Combine all levels
        combined = torch.cat([asset_out, sector_out, market_out], dim=-1)
        return self.combine(combined)
```
Mathematical Foundation
Attention Score Computation
The attention score between query asset $i$ and key asset $j$ is:
$$\text{Attention}(Q_i, K_j, V_j) = \text{softmax}\left(\frac{Q_i K_j^T}{\sqrt{d_k}}\right) V_j$$
Where:
- $Q_i \in \mathbb{R}^{T \times d_k}$ - Query representations for asset $i$
- $K_j \in \mathbb{R}^{T \times d_k}$ - Key representations for asset $j$
- $V_j \in \mathbb{R}^{T \times d_v}$ - Value representations for asset $j$
- $d_k$ - Dimension of keys (scaling factor)
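A quick numeric check of the formula with random tensors (shapes follow the definitions above):

```python
import torch

T, d_k, d_v = 5, 8, 8
Q_i = torch.randn(T, d_k)   # queries from asset i
K_j = torch.randn(T, d_k)   # keys from asset j
V_j = torch.randn(T, d_v)   # values from asset j

scores = Q_i @ K_j.T / d_k ** 0.5          # [T, T]
weights = torch.softmax(scores, dim=-1)    # each row sums to 1
context = weights @ V_j                    # [T, d_v]
print(weights.sum(dim=-1))                 # tensor([1., 1., 1., 1., 1.])
```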
Cross-Attention vs Self-Attention
| Aspect | Self-Attention | Cross-Attention |
|---|---|---|
| Q, K, V source | Same sequence | Q from one, K/V from another |
| Use case | Temporal patterns | Inter-asset relationships |
| Symmetry | Symmetric | Can be asymmetric |
| Complexity | $O(T^2)$ | $O(T^2 \cdot N)$ for N assets |
Scaled Dot-Product Attention
For multi-asset scenarios with $N$ assets:
$$\text{MultiAssetAttention}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$
Where each head $i$ computes:
$$\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$$
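Shape-wise, the split into heads and the final Concat are inverse reshapes; a quick check (sizes are illustrative):

```python
import torch

batch, T, d_model, h = 2, 16, 64, 4
head_dim = d_model // h

x = torch.randn(batch, T, d_model)
heads = x.view(batch, T, h, head_dim).transpose(1, 2)                # [batch, h, T, head_dim]
merged = heads.transpose(1, 2).contiguous().view(batch, T, d_model)  # Concat(head_1, ..., head_h)
assert torch.equal(x, merged)  # splitting and concatenating round-trips exactly
```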
Data Representation
Multi-Asset Feature Engineering
```python
import numpy as np
import pandas as pd


def create_multi_asset_features(df_dict: dict, lookback: int = 100) -> np.ndarray:
    """
    Create a per-timestep feature tensor for multiple assets.
    (Windowing into lookback-length sequences happens in a later step.)

    Args:
        df_dict: Dictionary mapping asset symbol to a DataFrame with OHLCV
        lookback: Number of historical timesteps used downstream

    Returns:
        features: [time, n_assets, n_features]
    """
    features = []

    for symbol, df in df_dict.items():
        asset_features = []

        # Price features
        asset_features.append(np.log(df['close'] / df['close'].shift(1)))  # Log returns
        asset_features.append((df['close'] - df['open']) / df['open'])     # Intraday return
        asset_features.append((df['high'] - df['low']) / df['close'])      # Range

        # Volume features
        asset_features.append(df['volume'] / df['volume'].rolling(20).mean())  # Relative volume

        # Technical indicators (compute_rsi is defined below; compute_macd is analogous)
        asset_features.append(compute_rsi(df['close'], 14))
        asset_features.append(compute_macd(df['close']))

        features.append(np.column_stack(asset_features))

    return np.stack(features, axis=1)  # [time, n_assets, n_features]
```
Data from Stock Markets

```python
import yfinance as yf


def fetch_stock_data(symbols: list, start: str, end: str) -> dict:
    """
    Fetch stock data from Yahoo Finance.

    Args:
        symbols: List of stock symbols (e.g., ['AAPL', 'GOOGL', 'MSFT'])
        start: Start date (YYYY-MM-DD)
        end: End date (YYYY-MM-DD)

    Returns:
        Dictionary mapping symbol to DataFrame
    """
    data = {}

    for symbol in symbols:
        ticker = yf.Ticker(symbol)
        df = ticker.history(start=start, end=end, interval='1h')
        df.columns = df.columns.str.lower()
        data[symbol] = df

    return data


# Example usage
stock_symbols = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'NVDA']
stock_data = fetch_stock_data(stock_symbols, '2023-01-01', '2024-01-01')
```

Data from Cryptocurrency Markets (Bybit)
```python
import pandas as pd
import requests


class BybitDataLoader:
    """Load cryptocurrency data from the Bybit exchange."""

    BASE_URL = "https://api.bybit.com/v5/market/kline"

    def __init__(self):
        self.session = requests.Session()

    def fetch_klines(
        self,
        symbol: str,
        interval: str = "60",  # 60 minutes = 1 hour
        limit: int = 1000
    ) -> pd.DataFrame:
        """
        Fetch kline/candlestick data from Bybit.

        Args:
            symbol: Trading pair (e.g., 'BTCUSDT')
            interval: Kline interval (1, 3, 5, 15, 30, 60, 120, 240, 360, 720, D, W, M)
            limit: Number of candles (max 1000)

        Returns:
            DataFrame with OHLCV data
        """
        params = {
            'category': 'linear',
            'symbol': symbol,
            'interval': interval,
            'limit': limit
        }

        response = self.session.get(self.BASE_URL, params=params)
        data = response.json()

        if data['retCode'] != 0:
            raise Exception(f"API Error: {data['retMsg']}")

        klines = data['result']['list']

        df = pd.DataFrame(klines, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume', 'turnover'
        ])

        df['timestamp'] = pd.to_datetime(df['timestamp'].astype(int), unit='ms')
        for col in ['open', 'high', 'low', 'close', 'volume', 'turnover']:
            df[col] = df[col].astype(float)

        return df.sort_values('timestamp').reset_index(drop=True)

    def fetch_multi_asset(self, symbols: list, **kwargs) -> dict:
        """Fetch data for multiple assets."""
        return {symbol: self.fetch_klines(symbol, **kwargs) for symbol in symbols}


# Example usage
loader = BybitDataLoader()
crypto_symbols = ['BTCUSDT', 'ETHUSDT', 'SOLUSDT', 'AVAXUSDT', 'DOTUSDT']
crypto_data = loader.fetch_multi_asset(crypto_symbols, interval='60', limit=1000)
```
Practical Examples
01: Data Preparation
```python
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple


def prepare_cross_attention_data(
    asset_data: Dict[str, pd.DataFrame],
    lookback: int = 168,  # 7 days of hourly data
    horizon: int = 24,    # 24 hours ahead
    features: List[str] = ['log_return', 'volume_ratio', 'volatility', 'rsi']
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Prepare data for a cross-attention multi-asset model.

    Returns:
        X: [n_samples, n_assets, lookback, n_features]
        y: [n_samples, n_assets] - Future returns
        symbols: List of asset symbols
    """
    symbols = list(asset_data.keys())

    # Compute features for each asset
    processed = {}
    for symbol, df in asset_data.items():
        feat = pd.DataFrame(index=df.index)

        feat['log_return'] = np.log(df['close'] / df['close'].shift(1))
        feat['volume_ratio'] = df['volume'] / df['volume'].rolling(20).mean()
        feat['volatility'] = feat['log_return'].rolling(20).std()
        feat['rsi'] = compute_rsi(df['close'], 14)

        processed[symbol] = feat

    # Align timestamps across assets
    common_idx = processed[symbols[0]].index
    for symbol in symbols[1:]:
        common_idx = common_idx.intersection(processed[symbol].index)

    # Create sequences
    X, y = [], []
    for i in range(lookback, len(common_idx) - horizon):
        x_sample = []
        y_sample = []

        for symbol in symbols:
            df = processed[symbol].loc[common_idx]
            x_sample.append(df.iloc[i - lookback:i][features].values)
            y_sample.append(df.iloc[i + horizon]['log_return'])

        X.append(np.stack(x_sample, axis=0))
        y.append(np.array(y_sample))

    return np.array(X), np.array(y), symbols


def compute_rsi(prices: pd.Series, period: int = 14) -> pd.Series:
    """Compute the Relative Strength Index."""
    delta = prices.diff()
    gain = delta.where(delta > 0, 0).rolling(window=period).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
    rs = gain / loss
    return 100 - (100 / (1 + rs))
```
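Assuming the crypto_data dictionary from the Bybit loader above, a usage sketch (the printed shapes follow from five assets and the four default features):

```python
X, y, symbols = prepare_cross_attention_data(crypto_data, lookback=168, horizon=24)
print(X.shape)  # (n_samples, 5, 168, 4)
print(y.shape)  # (n_samples, 5)
```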
02: Cross-Attention Model
See python/model.py for the complete implementation.
```python
# python/model.py (simplified)

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttentionMultiAsset(nn.Module):
    """
    Cross-attention model for multi-asset prediction.

    Features:
    - Temporal self-attention within each asset
    - Cross-asset attention between all pairs
    - Multi-head attention for diverse relationships
    - Flexible output: regression, classification, or portfolio weights
    """

    def __init__(
        self,
        n_assets: int,
        n_features: int,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        dropout: float = 0.1,
        output_type: str = 'regression'
    ):
        super().__init__()
        self.n_assets = n_assets
        self.output_type = output_type

        # Embedding
        self.input_proj = nn.Linear(n_features, d_model)
        self.pos_encoding = PositionalEncoding(d_model, dropout)

        # Encoder layers
        self.layers = nn.ModuleList([
            CrossAttentionLayer(d_model, n_heads, dropout)
            for _ in range(n_layers)
        ])

        # Output head
        if output_type == 'regression':
            self.output_head = nn.Linear(d_model, 1)
        elif output_type == 'classification':
            self.output_head = nn.Linear(d_model, 3)  # Down, Neutral, Up
        elif output_type == 'portfolio':
            self.output_head = nn.Linear(d_model, 1)

    def forward(self, x, return_attention=False):
        """
        Args:
            x: [batch, n_assets, seq_len, n_features]

        Returns:
            predictions: [batch, n_assets] or [batch, n_assets, n_classes]
            attention: Optional attention weights
        """
        batch, n_assets, seq_len, n_features = x.shape

        # Embed each asset
        x = self.input_proj(x)  # [batch, n_assets, seq_len, d_model]

        # Add positional encoding per asset (stacked, to avoid in-place writes)
        x = torch.stack(
            [self.pos_encoding(x[:, a]) for a in range(n_assets)], dim=1
        )

        # Apply encoder layers
        attentions = []
        for layer in self.layers:
            x, attn = layer(x, return_attention)
            if return_attention:
                attentions.append(attn)

        # Pool the temporal dimension: take the last timestep
        x = x[:, :, -1, :]  # [batch, n_assets, d_model]

        # Output
        if self.output_type == 'portfolio':
            logits = self.output_head(x).squeeze(-1)  # [batch, n_assets]
            output = F.softmax(logits, dim=-1)        # Portfolio weights
        elif self.output_type == 'classification':
            output = self.output_head(x)              # [batch, n_assets, 3]
        else:
            output = self.output_head(x).squeeze(-1)  # [batch, n_assets]

        if return_attention:
            return output, attentions
        return output
```
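The simplified model references two helpers defined in the full python/model.py: PositionalEncoding and CrossAttentionLayer. Minimal sketches consistent with the shapes used above (assumptions, not necessarily the repository's exact code):

```python
class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017)."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [batch, seq_len, d_model]
        return self.dropout(x + self.pe[:, :x.size(1)])


class CrossAttentionLayer(nn.Module):
    """One encoder block: temporal self-attention per asset, then cross-asset attention."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.temporal = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.cross = MultiHeadCrossAttention(d_model, n_heads, dropout)  # defined earlier

    def forward(self, x, return_attention=False):
        # x: [batch, n_assets, seq_len, d_model]
        batch, n_assets, seq_len, d_model = x.shape

        # Temporal self-attention within each asset (fold assets into the batch)
        xt = x.reshape(batch * n_assets, seq_len, d_model)
        attn_out, _ = self.temporal(xt, xt, xt)
        x = self.norm(xt + attn_out).view(batch, n_assets, seq_len, d_model)

        # Cross-asset attention: every asset attends to every asset
        x, cross_attn = self.cross(x, x)

        attn = {'cross_asset': cross_attn} if return_attention else None
        return x, attn
```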
03: Model Training

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_cross_attention_model(
    model: nn.Module,
    train_data: tuple,
    val_data: tuple,
    epochs: int = 100,
    batch_size: int = 32,
    lr: float = 0.001,
    device: str = 'cuda'
):
    """
    Train a cross-attention model.

    Args:
        model: CrossAttentionMultiAsset model
        train_data: (X_train, y_train)
        val_data: (X_val, y_val)
    """
    X_train, y_train = train_data
    X_val, y_val = val_data

    # Create data loaders
    train_dataset = TensorDataset(
        torch.FloatTensor(X_train),
        torch.FloatTensor(y_train)
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Loss and optimizer
    if model.output_type == 'regression':
        criterion = nn.MSELoss()
    elif model.output_type == 'classification':
        criterion = nn.CrossEntropyLoss()
    else:  # portfolio: maximize mean portfolio return by minimizing its negative
        criterion = lambda pred, ret: -torch.mean(torch.sum(pred * ret, dim=-1))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    model = model.to(device)
    best_val_loss = float('inf')

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            predictions = model(batch_x)

            if model.output_type == 'classification':
                # Reshape for cross-entropy (binary up/down labels;
                # the Neutral class is unused in this simplified script)
                predictions = predictions.view(-1, 3)
                batch_y = (batch_y > 0).long().view(-1)

            loss = criterion(predictions, batch_y)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            val_x = torch.FloatTensor(X_val).to(device)
            val_y = torch.FloatTensor(y_val).to(device)
            val_pred = model(val_x)

            if model.output_type == 'classification':
                val_pred = val_pred.view(-1, 3)
                val_y = (val_y > 0).long().view(-1)

            val_loss = criterion(val_pred, val_y).item()

        scheduler.step(val_loss)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pt')

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss / len(train_loader):.6f}, "
                  f"Val Loss = {val_loss:.6f}")

    return model
```
04: Multi-Asset Prediction

```python
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def predict_and_visualize(
    model,
    X: np.ndarray,
    symbols: list,
    device: str = 'cuda'
):
    """
    Make predictions and visualize attention patterns.
    """
    model.eval()
    model = model.to(device)

    with torch.no_grad():
        x = torch.FloatTensor(X).to(device)
        predictions, attentions = model(x, return_attention=True)

    predictions = predictions.cpu().numpy()

    # Visualize cross-asset attention
    if attentions:
        # Attention from the last encoder layer: [batch, n_heads, n_assets, n_assets]
        cross_attn = attentions[-1]['cross_asset']
        avg_attn = cross_attn.mean(dim=[0, 1]).cpu().numpy()

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            avg_attn,
            xticklabels=symbols,
            yticklabels=symbols,
            annot=True,
            fmt='.2f',
            cmap='Blues'
        )
        plt.title('Cross-Asset Attention Weights')
        plt.xlabel('Key (Source Asset)')
        plt.ylabel('Query (Target Asset)')
        plt.tight_layout()
        plt.savefig('cross_attention_heatmap.png', dpi=150)
        plt.close()

    return predictions


def analyze_lead_lag_relationships(
    model,
    X: np.ndarray,
    symbols: list
):
    """
    Analyze which assets lead/lag others based on attention patterns.
    """
    model.eval()

    with torch.no_grad():
        _, attentions = model(torch.FloatTensor(X), return_attention=True)

    # Extract cross-asset attention, averaged over batch and heads
    cross_attn = attentions[-1]['cross_asset'].mean(dim=[0, 1]).numpy()

    # Compute influence scores
    influence = {}
    for i, symbol in enumerate(symbols):
        influence[symbol] = {
            'as_leader': cross_attn[:, i].mean(),    # How much others attend to this asset
            'as_follower': cross_attn[i, :].mean()   # How much this asset attends to others
        }

    # Rank by leadership
    leaders = sorted(
        influence.items(),
        key=lambda x: x[1]['as_leader'],
        reverse=True
    )

    print("\nAsset Leadership Ranking:")
    print("-" * 40)
    for symbol, scores in leaders:
        print(f"{symbol}: Leader={scores['as_leader']:.3f}, "
              f"Follower={scores['as_follower']:.3f}")

    return influence
```
05: Portfolio Backtesting

```python
from typing import Dict, List

import numpy as np
import pandas as pd
import torch


class CrossAttentionBacktest:
    """
    Backtest a cross-attention portfolio strategy.
    """

    def __init__(
        self,
        model,
        initial_capital: float = 100000,
        transaction_cost: float = 0.001,
        rebalance_freq: int = 24  # Hours
    ):
        self.model = model
        self.initial_capital = initial_capital
        self.transaction_cost = transaction_cost
        self.rebalance_freq = rebalance_freq

    def run(
        self,
        X: np.ndarray,
        returns: np.ndarray,
        timestamps: pd.DatetimeIndex
    ) -> pd.DataFrame:
        """
        Run the backtest on test data.

        Args:
            X: [n_samples, n_assets, lookback, n_features]
            returns: [n_samples, n_assets] - Actual future returns
            timestamps: DatetimeIndex for results

        Returns:
            DataFrame with portfolio metrics over time
        """
        self.model.eval()
        n_samples, n_assets, _, _ = X.shape

        capital = self.initial_capital
        positions = np.zeros(n_assets)

        results = []

        for i in range(0, n_samples, self.rebalance_freq):
            # Get model predictions (portfolio weights)
            with torch.no_grad():
                x = torch.FloatTensor(X[i:i + 1])
                weights = self.model(x).numpy().flatten()

            # Normalize weights
            if self.model.output_type == 'regression':
                # Convert return predictions to weights
                weights = np.clip(weights, -1, 1)
                weights = weights / (np.abs(weights).sum() + 1e-8)

            # Transaction costs are based on the change from the old positions
            position_change = np.abs(weights - positions).sum()
            costs = position_change * self.transaction_cost * capital

            # Rebalance now, so the new weights earn this period's returns
            positions = weights

            # Apply the period's returns
            period_returns = returns[i:min(i + self.rebalance_freq, n_samples)]

            for j, ret in enumerate(period_returns):
                portfolio_return = np.sum(positions * ret)
                capital = capital * (1 + portfolio_return)

                if j == 0:
                    capital -= costs

                results.append({
                    'timestamp': timestamps[i + j] if i + j < len(timestamps) else None,
                    'capital': capital,
                    'return': portfolio_return,
                    'positions': positions.copy(),
                    'weights': weights.copy()
                })

        return pd.DataFrame(results)

    def compute_metrics(self, results: pd.DataFrame) -> Dict:
        """Compute performance metrics."""
        returns = results['return'].values

        # Sharpe Ratio (annualized for hourly data)
        sharpe = np.sqrt(365 * 24) * returns.mean() / (returns.std() + 1e-8)

        # Sortino Ratio
        downside = returns[returns < 0]
        sortino = np.sqrt(365 * 24) * returns.mean() / (downside.std() + 1e-8)

        # Maximum Drawdown
        cumulative = (1 + returns).cumprod()
        running_max = np.maximum.accumulate(cumulative)
        drawdown = (cumulative - running_max) / running_max
        max_drawdown = drawdown.min()

        # Total Return
        total_return = (results['capital'].iloc[-1] / self.initial_capital - 1) * 100

        return {
            'total_return': total_return,
            'sharpe_ratio': sharpe,
            'sortino_ratio': sortino,
            'max_drawdown': max_drawdown * 100,
            'volatility': returns.std() * np.sqrt(365 * 24) * 100,
            'win_rate': (returns > 0).mean() * 100
        }
```
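Tying the pieces together, a usage sketch (X_test, y_test, and timestamps stand for a hypothetical held-out split produced by the data-preparation step):

```python
backtest = CrossAttentionBacktest(model, initial_capital=100_000, transaction_cost=0.001)
results = backtest.run(X_test, y_test, timestamps)
metrics = backtest.compute_metrics(results)
print(f"Sharpe: {metrics['sharpe_ratio']:.2f}, Max DD: {metrics['max_drawdown']:.1f}%")
```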
Rust Implementation
See rust/ for the complete Rust implementation using the candle ML framework.

```
rust/
├── Cargo.toml
├── README.md
├── src/
│   ├── lib.rs                  # Library exports
│   ├── model/                  # Model implementation
│   │   ├── mod.rs
│   │   ├── attention.rs        # Cross-attention layers
│   │   ├── embedding.rs        # Token embeddings
│   │   └── cross_attention.rs  # Main model
│   ├── data/                   # Data handling
│   │   ├── mod.rs
│   │   ├── bybit.rs            # Bybit API client
│   │   ├── features.rs         # Feature engineering
│   │   └── dataset.rs          # Training dataset
│   └── strategy/               # Trading strategy
│       ├── mod.rs
│       ├── signals.rs          # Signal generation
│       └── backtest.rs         # Backtesting engine
└── examples/
    ├── fetch_data.rs           # Download data from Bybit
    ├── train.rs                # Train the model
    └── backtest.rs             # Run backtest
```

Quick Start (Rust)
```bash
# Navigate to the Rust project
cd rust

# Fetch data from Bybit
cargo run --example fetch_data -- --symbols BTCUSDT,ETHUSDT,SOLUSDT,AVAXUSDT

# Train the model
cargo run --release --example train -- --epochs 50 --batch-size 32

# Run the backtest
cargo run --release --example backtest -- --start 2024-01-01 --end 2024-12-31
```

Python Implementation
See python/ for Python implementation.
```
python/
├── __init__.py
├── model.py            # Cross-attention model
├── data.py             # Data loading (Bybit + Yahoo Finance)
├── features.py         # Feature engineering
├── train.py            # Training script
├── backtest.py         # Backtesting utilities
├── requirements.txt    # Dependencies
└── examples/
    ├── 01_data_preparation.py
    ├── 02_model_training.py
    ├── 03_prediction.py
    └── 04_backtesting.py
```

Quick Start (Python)
```bash
# Install dependencies
pip install -r requirements.txt

# Run the examples in order
python examples/01_data_preparation.py
python examples/02_model_training.py
python examples/03_prediction.py
python examples/04_backtesting.py
```

Best Practices
When to Use Cross-Attention
Good use cases:
- Trading correlated asset classes (crypto, tech stocks, commodities)
- Portfolio optimization across multiple assets
- Detecting lead-lag relationships
- Multi-asset risk management
Not ideal for:
- Single asset prediction (use simpler models)
- Very short-term prediction (latency concerns)
- Uncorrelated assets (cross-attention won’t help)
Hyperparameter Recommendations
| Parameter | Recommended | Notes |
|---|---|---|
| `d_model` | 64-128 | Match computational budget |
| `n_heads` | 4-8 | More heads for more assets |
| `n_layers` | 2-4 | Deeper for complex relationships |
| `dropout` | 0.1-0.2 | Higher for small datasets |
| `lookback` | 168 (7 days hourly) | Match prediction horizon |
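As a configuration sketch, instantiating the model from example 02 with mid-range values from this table (the asset and feature counts match the earlier data-preparation example):

```python
model = CrossAttentionMultiAsset(
    n_assets=5,              # e.g., the five Bybit symbols used earlier
    n_features=4,            # log_return, volume_ratio, volatility, rsi
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1,
    output_type='portfolio'  # or 'regression' / 'classification'
)
```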
Common Pitfalls
- Correlation collapse: all attention concentrates on one dominant asset
  - Solution: use dropout and attention regularization (see the sketch after this list)
- Overfitting cross-asset patterns: the model memorizes spurious correlations
  - Solution: more data, a simpler model, stronger regularization
- Ignoring regime changes: cross-asset relationships change over time
  - Solution: rolling training windows, regime detection
- Computational cost: $O(N^2 \cdot T^2)$ for $N$ assets and $T$ timesteps
  - Solution: sparse attention, efficient implementations
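One way to implement the attention regularization mentioned under correlation collapse is an entropy bonus that penalizes near-one-hot attention rows; the penalty form and its weight below are assumptions for illustration, not the chapter's reference code:

```python
import torch

def attention_entropy_penalty(attention: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    attention: [batch, n_heads, n_query_assets, n_kv_assets]; rows sum to 1.
    Returns the negative mean entropy, so adding it to the loss pushes
    attention away from collapsed (one-hot) distributions.
    """
    entropy = -(attention * (attention + eps).log()).sum(dim=-1)
    return -entropy.mean()

# Hypothetical usage inside the training loop:
# loss = task_loss + 0.01 * attention_entropy_penalty(cross_attn)  # 0.01 is illustrative
```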
Resources
Papers
- Portfolio Transformer for Attention-Based Asset Allocation — End-to-end portfolio optimization with attention
- Attention-Based Ensemble Learning for Portfolio Optimisation — MASAAT framework with multi-agent attention
- Large-scale Time-Varying Portfolio Optimisation using Graph Attention Networks — GAT-based portfolio management
- Attention Is All You Need — Original Transformer paper
Related Chapters
- Chapter 26: Temporal Fusion Transformers — Multi-horizon forecasting
- Chapter 43: Stockformer Multivariate — Cross-ticker attention
- Chapter 44: ProbSparse Attention — Efficient attention mechanisms
- Chapter 46: Temporal Attention Networks — Temporal attention
Difficulty Level
Advanced
Prerequisites:
- Transformer architecture and attention mechanisms
- Multi-asset portfolio theory
- Time series forecasting
- PyTorch or Rust ML libraries