Long Context Extension¶
Overview¶
This document covers techniques for extending the context length of LLMs. Standard Transformer attention requires memory and computation that grow quadratically with sequence length, so a variety of methods have been developed to handle long contexts efficiently.
Core Concepts¶
The Importance of Position Encoding¶
The Transformer has no built-in notion of token position, so position encoding is essential. For long contexts, the position encoding must also generalize to positions beyond the range seen during training.
RoPE (Rotary Position Embedding)¶
Position information is encoded as a rotation in the complex plane:
q_m = R_m * q
k_n = R_n * k
where R_theta is the 2D rotation matrix:
R_theta = [[cos(theta), -sin(theta)],
           [sin(theta),  cos(theta)]]
The attention score becomes:
a_{m,n} = Re[(R_m * q)^* (R_n * k)]
        = Re[q^* R_{n-m} k]
        = f(q, k, n-m)   # depends only on the relative position
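As a quick sanity check of the relative-position property, here is a minimal sketch (assuming a toy 2-dimensional head and a single frequency theta, not part of the original text): rotating the query by position m and the key by position n gives a dot product that depends only on n - m, so shifting both positions by the same offset leaves the score unchanged.

import math

def rotate(v, angle):
    """Rotate a 2D vector (list of two floats) by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return [c * v[0] - s * v[1], s * v[0] + c * v[1]]

theta = 0.1          # a single RoPE frequency
q, k = [1.0, 2.0], [0.5, -1.0]

# Score with absolute positions m = 3, n = 7 ...
m, n = 3, 7
q_m = rotate(q, m * theta)
k_n = rotate(k, n * theta)
score_1 = q_m[0] * k_n[0] + q_m[1] * k_n[1]

# ... equals the score with both positions shifted by 100:
q_m2 = rotate(q, (m + 100) * theta)
k_n2 = rotate(k, (n + 100) * theta)
score_2 = q_m2[0] * k_n2[0] + q_m2[1] * k_n2[1]

assert abs(score_1 - score_2) < 1e-9   # depends only on n - m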
ALiBi (Attention with Linear Biases)¶
A distance-based linear bias is added to the attention scores:
Attention(Q, K, V) = softmax(QK^T / sqrt(d) - m * |i-j|) V
where:
m = head-specific slope
|i-j| = position distance
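For intuition about the head-specific slope m: with a power-of-two number of heads, ALiBi uses a geometric sequence starting at 2^(-8/n), so for 8 heads the slopes are 1/2, 1/4, ..., 1/256. The sketch below (illustrative, mirroring the formula above for the causal case) computes the slopes and the bias matrix for a toy sequence.

import torch

n_heads = 8
# Geometric slope sequence: 2^(-8/n), 2^(-16/n), ... (here 1/2, 1/4, ..., 1/256)
slopes = torch.tensor([2 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])

# Bias for a causal toy sequence of length 5: -m * (i - j) for j <= i
seq_len = 5
pos = torch.arange(seq_len)
distance = (pos.unsqueeze(1) - pos.unsqueeze(0)).clamp(min=0).float()  # (5, 5), zero above the diagonal
bias = -slopes.view(n_heads, 1, 1) * distance                          # (8, 5, 5), added to QK^T / sqrt(d)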
Architecture Diagrams¶
RoPE Application Process¶
Query q, Key k (both: d-dimensional)
|
v
+---------------------+
| Split into pairs |
| [q1,q2], [q3,q4]... |
+----------+----------+
|
v
+---------------------+
| Apply 2D rotation |
| at position m |
| |
| [q1'] = [cos(m*t1) -sin(m*t1)] [q1]
| [q2'] [sin(m*t1) cos(m*t1)] [q2]
| |
| theta_i = 10000^(-2i/d)
+----------+----------+
|
v
Rotated Query q'_m
(similarly for Key k'_n)
Position Interpolation (PI)¶
Original RoPE: positions [0, 1, 2, ..., L]
For context extension to 4L:
+---------------------------------------------+
| Linear Interpolation (PI) |
| |
| New positions: [0, 0.25, 0.5, 0.75, 1, ...]|
| |
| theta'(pos) = theta(pos / scale) |
| where scale = target_len / original_len |
+---------------------------------------------+
Problem: Compresses high-frequency components
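The "compresses high-frequency components" problem can be seen directly: after interpolation by a factor s, the per-token phase increment shrinks by s in every dimension, so neighbouring tokens become harder to distinguish in the fast-rotating dimensions. A small sketch, assuming a head dimension of 64 and a 4x scale:

import torch

dim, base, scale = 64, 10000.0, 4.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # theta_i per dimension

# Phase difference between adjacent tokens (positions p and p+1)
delta_original = inv_freq                 # angle advanced per token
delta_interpolated = inv_freq / scale     # PI: positions are divided by the scale

# The fastest dimension originally advances 1.0 rad per token,
# but only 0.25 rad after 4x interpolation -- adjacent tokens look more alike.
print(delta_original[0].item(), delta_interpolated[0].item())  # 1.0  0.25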
YaRN (Yet another RoPE extensioN)¶
YaRN strategy: treat RoPE dimensions differently by frequency
RoPE dimensions (theta_i = base^(-2i/d), so low dimension indices rotate fastest, i.e., are high frequency):
+------------+------------+------------+
| High Freq  |   Medium   |  Low Freq  |
| (dim 0-16) | (dim 16-48)| (dim 48-64)|
+------------+------------+------------+
      |            |            |
      v            v            v
  No scaling    Blended ramp   Full interpolation
 (extrapolate) (NTK-by-parts)  (linear PI)
Formula (theta(d) = per-dimension angular frequency):
theta'(d) = theta(d) / r(d)
where
r(d) = 1                       if wavelength(d) < high_freq_wavelen   (high freq: preserve)
       1 + (s - 1) * gamma(d)  in between                             (linear ramp)
       s                       if wavelength(d) > low_freq_wavelen    (low freq: full PI)
s = scale factor, gamma(d) in [0, 1] = interpolation weight
Sliding Window Attention¶
Full Attention (O(n^2)):
+---+---+---+---+---+---+---+---+
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
+---+---+---+---+---+---+---+---+
Sliding Window (O(n*w)):
+---+---+---+---+---+---+---+---+
| 1 | 1 | 1 | 0 | 0 | 0 | 0 | 0 |
| 1 | 1 | 1 | 1 | 0 | 0 | 0 | 0 |
| 1 | 1 | 1 | 1 | 1 | 0 | 0 | 0 |
| 0 | 1 | 1 | 1 | 1 | 1 | 0 | 0 |
| 0 | 0 | 1 | 1 | 1 | 1 | 1 | 0 |
| 0 | 0 | 0 | 1 | 1 | 1 | 1 | 1 |
| 0 | 0 | 0 | 0 | 1 | 1 | 1 | 1 |
| 0 | 0 | 0 | 0 | 0 | 1 | 1 | 1 |
+---+---+---+---+---+---+---+---+
w = window size (e.g., 4096)
Multi-Scale Attention (Longformer style)¶
Combining local and global attention:
+---+---+---+---+---+---+---+---+
| G | G | G | G | G | G | G | G | <- Global tokens
| G | L | L | L | 0 | 0 | 0 | 0 | <- Local window
| G | L | L | L | L | 0 | 0 | 0 | for each
| G | L | L | L | L | L | 0 | 0 | position
| G | 0 | L | L | L | L | L | 0 |
| G | 0 | 0 | L | L | L | L | L |
| G | 0 | 0 | 0 | L | L | L | L |
| G | 0 | 0 | 0 | 0 | L | L | L |
+---+---+---+---+---+---+---+---+
G = Global attention (to/from special tokens)
L = Local sliding window attention
0 = Masked (no attention)
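A minimal sketch of building this combined mask with PyTorch, assuming the first n_global tokens are the global ones and a symmetric local window of radius window_radius (the function name and parameters are illustrative, not from any particular library). In practice the efficiency comes from block-sparse kernels; a dense boolean mask like this only illustrates the pattern.

import torch

def build_longformer_mask(seq_len: int, window_radius: int, n_global: int) -> torch.Tensor:
    """Boolean (seq_len, seq_len) mask: True where attention is allowed."""
    pos = torch.arange(seq_len)
    # Local band: |i - j| <= window_radius
    local = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() <= window_radius
    # Global tokens attend everywhere and are attended to from everywhere
    is_global = pos < n_global
    global_rows = is_global.unsqueeze(1).expand(seq_len, seq_len)
    global_cols = is_global.unsqueeze(0).expand(seq_len, seq_len)
    return local | global_rows | global_cols

mask = build_longformer_mask(seq_len=8, window_radius=2, n_global=1)
# Use as: scores.masked_fill(~mask, float('-inf')) before softmax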
Ring Attention (Distributed)¶
Ring Attention for very long sequences across devices:
Device 0 Device 1 Device 2
+--------+ +--------+ +--------+
| Q[0:n] | | Q[n:2n]| |Q[2n:3n]|
| K[0:n] | | K[n:2n]| |K[2n:3n]|
| V[0:n] | | V[n:2n]| |V[2n:3n]|
+---+----+ +---+----+ +---+----+
| | |
+-----------------+-----------------+
Ring Communication
+-----------------+-----------------+
| | |
v v v
Compute Compute Compute
Attention Attention Attention
(partial) (partial) (partial)
| | |
+-----------------+-----------------+
Accumulate
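A single-process sketch of the idea (not a distributed implementation): Q, K, V are split into blocks as if each lived on its own device, K/V blocks are "passed around the ring" one step at a time, and each Q block accumulates its partial attention with a numerically stable online softmax. The function name is illustrative, and for simplicity this matches full non-causal attention.

import torch

def ring_attention_sim(q, k, v, n_blocks):
    """Simulate ring attention on one process. q, k, v: (seq_len, d)."""
    d = q.shape[-1]
    q_blocks = q.chunk(n_blocks)
    kv_blocks = list(zip(k.chunk(n_blocks), v.chunk(n_blocks)))
    outputs = []
    for qb in q_blocks:                                   # each "device" holds one Q block
        m = torch.full((qb.shape[0], 1), float('-inf'))   # running max
        denom = torch.zeros(qb.shape[0], 1)               # running softmax denominator
        acc = torch.zeros(qb.shape[0], d)                 # running weighted sum of V
        for kb, vb in kv_blocks:                          # K/V blocks arrive one ring step at a time
            scores = qb @ kb.T / d ** 0.5                 # (q_block, kv_block)
            m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
            correction = torch.exp(m - m_new)             # rescale previous accumulators
            p = torch.exp(scores - m_new)
            denom = denom * correction + p.sum(dim=-1, keepdim=True)
            acc = acc * correction + p @ vb
            m = m_new
        outputs.append(acc / denom)
    return torch.cat(outputs)

# Matches full (non-causal) softmax attention:
q = torch.randn(16, 8)
k = torch.randn(16, 8)
v = torch.randn(16, 8)
full = torch.softmax(q @ k.T / 8 ** 0.5, dim=-1) @ v
assert torch.allclose(ring_attention_sim(q, k, v, n_blocks=4), full, atol=1e-5)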
LongRoPE Architecture¶
LongRoPE: Two-stage extension
Stage 1: Search optimal rescale factors
+----------------------------------------+
| Input: Base model (e.g., 4K context) |
| Target: Extend to 256K |
| |
| Search space: |
| - Lambda factors for each RoPE dim |
| - Non-uniform interpolation |
+----------------------------------------+
|
v
Stage 2: Fine-tune with progressive extension
+----------------------------------------+
| 4K -> 64K -> 128K -> 256K -> 2M |
| |
| Short context: aggressive interpolation|
| Long context: conservative |
+----------------------------------------+
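At inference time, the output of Stage 1 is essentially a set of per-dimension rescale factors applied to the RoPE frequencies. A hedged sketch of how such factors would be applied (the lambda values below are made up for illustration; the real factors come from LongRoPE's evolutionary search):

import torch

dim, base = 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Hypothetical per-dimension rescale factors found by the search
# (non-uniform: near 1.0 for high-frequency dims, larger for low-frequency dims)
lambda_factors = torch.linspace(1.0, 64.0, steps=dim // 2)

# Non-uniform interpolation: each dimension is stretched by its own factor
inv_freq_longrope = inv_freq / lambda_factors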
Comparison of Representative Techniques¶
| Technique | Max context | Fine-tuning | Extrapolation |
|---|---|---|---|
| Sinusoidal | Training length | Required | None |
| RoPE | Training length | Required | Limited |
| ALiBi | Unlimited | Not required | Good |
| PI (Position Interpolation) | ~4-8x | Required | Moderate |
| NTK-aware | ~4-8x | Required | Moderate |
| YaRN | ~16-32x | Required | Good |
| LongRoPE | 2M+ | Required | Very good |
Representative Models¶
| Model | Context length | Technique |
|---|---|---|
| GPT-4 Turbo | 128K | Undisclosed |
| Claude 3 | 200K | Undisclosed |
| Gemini 1.5 Pro | 1M+ | Undisclosed |
| Llama 3.1 | 128K | RoPE + extension |
| Mistral | 32K | Sliding Window |
| Yi-34B | 200K | YaRN |
| Command R | 128K | Undisclosed |
| Jamba 1.5 | 256K | Hybrid (SSM) |
Pros and Cons¶
Position Interpolation (PI)¶
Pros:
- Simple to implement
- Extension possible with little fine-tuning
Cons:
- Loses high-frequency information
- Limited extension ratio (4-8x)
ALiBi¶
Pros:
- Extrapolates without fine-tuning
- Simple to implement
Cons:
- Performance degradation on some tasks
- Weaker in-context learning than RoPE
YaRN¶
Pros:
- Supports large extension ratios
- Minimizes quality degradation
Cons:
- Requires hyperparameter tuning
- Requires fine-tuning
Sliding Window¶
Pros:
- Handles long sequences with O(n) memory
- Efficient for local context
Cons:
- Cannot directly attend to long-range dependencies
- Possible loss of global information
Code Examples¶
RoPE Implementation¶
import torch
import torch.nn as nn
import math
class RotaryPositionEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
# Precompute rotary embeddings
self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int):
positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
# [seq_len, dim/2] -> [seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
def forward(self, x, seq_len: int = None):
# x: (batch, n_heads, seq_len, head_dim)
if seq_len is None:
seq_len = x.shape[2]
return (
self.cos_cached[:seq_len],
self.sin_cached[:seq_len]
)
def rotate_half(x):
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
"""Apply rotary position embedding to queries and keys."""
# cos, sin: (seq_len, dim)
# q, k: (batch, n_heads, seq_len, head_dim)
cos = cos.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, dim)
sin = sin.unsqueeze(0).unsqueeze(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
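A short usage example for the module above (assuming the class and helper functions defined in the preceding block), applying the rotary embedding to random query/key tensors:

# Example usage of RotaryPositionEmbedding / apply_rotary_pos_emb
batch, n_heads, seq_len, head_dim = 2, 8, 128, 64
rope = RotaryPositionEmbedding(dim=head_dim)

q = torch.randn(batch, n_heads, seq_len, head_dim)
k = torch.randn(batch, n_heads, seq_len, head_dim)

cos, sin = rope(q)                              # (seq_len, head_dim) each
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape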
Position Interpolation¶
class ScaledRotaryEmbedding(RotaryPositionEmbedding):
"""RoPE with Position Interpolation for context extension."""
def __init__(
self,
dim: int,
max_seq_len: int = 8192,
base: float = 10000.0,
scaling_factor: float = 1.0 # target_len / original_len
):
self.scaling_factor = scaling_factor
super().__init__(dim, max_seq_len, base)
def _build_cache(self, seq_len: int):
# Scale positions for interpolation
positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
positions = positions / self.scaling_factor # Key difference
freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos())
self.register_buffer("sin_cached", emb.sin())
YaRN Implementation¶
class YaRNRotaryEmbedding(nn.Module):
"""
YaRN: Yet another RoPE extensioN
Combines:
- NTK-aware interpolation for high frequencies
- Linear interpolation for medium frequencies
- No scaling for low frequencies
"""
def __init__(
self,
dim: int,
max_seq_len: int = 8192,
base: float = 10000.0,
scale: float = 1.0,
original_max_seq_len: int = 4096,
beta_fast: float = 32.0,
beta_slow: float = 1.0,
):
super().__init__()
self.dim = dim
self.scale = scale
self.original_max_seq_len = original_max_seq_len
# Compute interpolation factors per dimension
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
# Find wavelength boundaries
low_freq_wavelen = original_max_seq_len / beta_slow
high_freq_wavelen = original_max_seq_len / beta_fast
wavelen = 2 * math.pi / inv_freq
# Compute per-dimension scaling
scaling_factors = torch.ones_like(inv_freq)
for i, wl in enumerate(wavelen):
if wl < high_freq_wavelen:
# High frequency: no scaling
scaling_factors[i] = 1.0
elif wl > low_freq_wavelen:
# Low frequency: full scaling
scaling_factors[i] = scale
else:
# Medium frequency: interpolate
smooth = (wl - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen)
scaling_factors[i] = 1.0 + (scale - 1.0) * smooth
self.register_buffer("inv_freq", inv_freq / scaling_factors)
self._build_cache(max_seq_len)
def _build_cache(self, seq_len: int):
positions = torch.arange(seq_len, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", positions, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
# Apply YaRN attention scaling
mscale = 0.1 * math.log(self.scale) + 1.0
self.register_buffer("cos_cached", emb.cos() * mscale)
self.register_buffer("sin_cached", emb.sin() * mscale)
def forward(self, x, seq_len: int = None):
if seq_len is None:
seq_len = x.shape[2]
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
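A brief usage note for the class above: the scale is the ratio of the target to the original context length, so extending a 4K base model to 32K uses scale = 32768 / 4096 = 8 (values chosen here purely for illustration).

# Example: extend a 4K-context base model to 32K with YaRN
yarn_rope = YaRNRotaryEmbedding(
    dim=64,
    max_seq_len=32768,
    scale=8.0,                   # 32768 / 4096
    original_max_seq_len=4096,
)
cos, sin = yarn_rope(torch.randn(1, 8, 1024, 64))   # (1024, 64) each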
Sliding Window Attention¶
class SlidingWindowAttention(nn.Module):
"""Attention with sliding window for efficient long sequence processing."""
def __init__(
self,
d_model: int,
n_heads: int,
window_size: int = 4096,
dropout: float = 0.1
):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.window_size = window_size
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def _create_sliding_window_mask(self, seq_len: int, device: torch.device):
"""Create causal sliding window attention mask."""
# Start with causal mask
mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
# Apply sliding window
for i in range(seq_len):
start = max(0, i - self.window_size + 1)
mask[i, :start] = 0
return mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
def forward(self, x):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# Apply sliding window mask
mask = self._create_sliding_window_mask(seq_len, x.device)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.W_o(context)
ALiBi Implementation¶
class ALiBiAttention(nn.Module):
"""Attention with Linear Biases for position encoding."""
def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
# Compute ALiBi slopes
self.register_buffer("slopes", self._get_alibi_slopes(n_heads))
def _get_alibi_slopes(self, n_heads: int):
"""Get ALiBi slopes for each attention head."""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * (ratio ** i) for i in range(n)]
if math.log2(n_heads).is_integer():
return torch.tensor(get_slopes_power_of_2(n_heads))
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
slopes_1 = get_slopes_power_of_2(closest_power_of_2)
slopes_2 = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:n_heads - closest_power_of_2]
return torch.tensor(slopes_1 + slopes_2)
def _get_alibi_bias(self, seq_len: int, device: torch.device):
"""Compute ALiBi bias matrix."""
# Distance matrix
positions = torch.arange(seq_len, device=device)
distance = positions.unsqueeze(0) - positions.unsqueeze(1) # (seq_len, seq_len)
distance = distance.abs().float()
# Apply slopes
alibi = distance.unsqueeze(0) * self.slopes.unsqueeze(-1).unsqueeze(-1).to(device)
return -alibi # Negative because we subtract from attention scores
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# Attention scores with ALiBi bias
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
scores = scores + self._get_alibi_bias(seq_len, x.device)
# Causal mask
if mask is None:
mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
return self.W_o(context)
References¶
- Su, J., et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv: https://arxiv.org/abs/2104.09864
- Press, O., et al. (2021). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Generalization." (ALiBi) arXiv: https://arxiv.org/abs/2108.12409
- Chen, S., et al. (2023). "Extending Context Window of Large Language Models via Positional Interpolation." (PI) arXiv: https://arxiv.org/abs/2306.15595
- Peng, B., et al. (2023). "YaRN: Efficient Context Window Extension of Large Language Models." arXiv: https://arxiv.org/abs/2309.00071
- Ding, Y., et al. (2024). "LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens." arXiv: https://arxiv.org/abs/2402.13753
- Liu, X., et al. (2023). "Scaling Laws of RoPE-based Extrapolation." arXiv: https://arxiv.org/abs/2310.05209
- Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv: https://arxiv.org/abs/2307.08691
- Liu, H., et al. (2023). "Ring Attention with Blockwise Transformers for Near-Infinite Context." arXiv: https://arxiv.org/abs/2310.01889