콘텐츠로 이동

Attention Mechanism

논문 정보

항목 내용
제목 Attention Is All You Need
저자 Ashish Vaswani et al. (Google)
학회 NeurIPS 2017
링크 https://arxiv.org/abs/1706.03762
관련 논문 내용
Bahdanau Attention (2014) Neural Machine Translation by Jointly Learning to Align and Translate
Luong Attention (2015) Effective Approaches to Attention-based Neural Machine Translation

개요

문제 정의

Seq2Seq 모델의 병목 (Bottleneck) 문제:

Encoder                          Decoder
[x_1, x_2, ..., x_n] ---> [c] ---> [y_1, y_2, ..., y_m]
                          ^
                 전체 정보를 하나의
                 고정 벡터로 압축
  • 긴 시퀀스의 정보 손실
  • 입력 시퀀스의 특정 부분에 집중 불가

핵심 아이디어

"모든 입력을 동일하게 보지 말고, 관련 있는 부분에 집중하자"

\[\text{Attention}(Q, K, V) = \sum_i \alpha_i V_i\]

여기서 \(\alpha_i\)는 Query와 Key의 관련성(유사도)에 기반한 가중치.

Attention의 진화

Bahdanau Attention (Additive)

         h_1    h_2    h_3    h_4
          │      │      │      │
          ▼      ▼      ▼      ▼
        ┌──────────────────────┐
        │   Alignment Model    │
        │  (Additive/MLP)      │
        └──────────────────────┘
         α_1   α_2   α_3   α_4    (attention weights)
          │      │      │      │
          └──────┴──────┴──────┘
            Context Vector c

Score 함수:

\[e_{ij} = v_a^T \tanh(W_a s_{i-1} + U_a h_j)\]
기호 설명
\(s_{i-1}\) 디코더의 이전 은닉 상태 (Query)
\(h_j\) 인코더의 j번째 은닉 상태 (Key = Value)
\(W_a, U_a, v_a\) 학습 가능한 파라미터

Attention 가중치:

\[\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{n} \exp(e_{ik})}\]

Context Vector:

\[c_i = \sum_{j=1}^{n} \alpha_{ij} h_j\]

Luong Attention (Multiplicative)

더 단순하고 효율적인 점수 계산:

유형 Score 함수
Dot \(s_t^T h_s\)
General \(s_t^T W_a h_s\)
Concat \(v_a^T \tanh(W_a [s_t; h_s])\)

Dot Product가 가장 효율적이고 널리 사용됨.

Self-Attention

개념

자기 자신의 시퀀스 내에서 각 위치가 다른 모든 위치를 참조:

Input:  "The animal didn't cross the street because it was too tired"
                                              "it"이 무엇?

Self-Attention이 "it" -> "animal" 연결을 학습

Query, Key, Value

입력 시퀀스 \(X \in \mathbb{R}^{n \times d}\)에서:

\[Q = XW^Q, \quad K = XW^K, \quad V = XW^V\]
개념 역할 비유
Query (Q) 현재 위치에서 "무엇을 찾고 있는지" 검색어
Key (K) 각 위치가 "어떤 정보를 가지고 있는지" 문서 제목
Value (V) 실제로 전달할 정보 문서 내용

Scaled Dot-Product Attention

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]

Scaling 이유: \(d_k\)가 크면 dot product 값이 커져서 softmax가 극단적인 값으로 포화됨.

\[\text{Var}(q \cdot k) = d_k \cdot \text{Var}(q_i) \cdot \text{Var}(k_i) = d_k\]

\(\sqrt{d_k}\)로 나누면 분산이 1로 정규화됨.

계산 흐름

    Q (n×d_k)      K^T (d_k×n)         V (n×d_v)
        │               │                  │
        └───────┬───────┘                  │
                │                          │
                ▼                          │
            Q @ K^T                        │
           (n × n)                         │
                │                          │
                ▼                          │
            / sqrt(d_k)                    │
                │                          │
                ▼                          │
            Softmax                        │
           (n × n)                         │
                │                          │
                └──────────┬───────────────┘
                      @ V (n × d_v)
                      Output (n × d_v)

Multi-Head Attention

개념

단일 Attention 대신 여러 개의 Attention을 병렬로 수행:

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O\]
\[\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)\]

장점

장점 설명
다양한 관계 학습 각 head가 다른 관계에 집중
표현력 증가 여러 부분공간(subspace)에서 attention
안정적 학습 head 간 평균화 효과

예시

"The cat sat on the mat"

Head 1: 주어-동사 관계 학습 (cat -> sat)
Head 2: 명사-전치사 관계 학습 (mat -> on)
Head 3: 위치 관계 학습 (인접 토큰)
...

파라미터

파라미터 일반적 값 설명
\(d_{model}\) 512, 768, 1024 모델 차원
\(h\) 8, 12, 16 head 수
\(d_k = d_v\) \(d_{model}/h\) 각 head의 차원

PyTorch 구현

Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Args:
        Q: (batch, ..., seq_len, d_k)
        K: (batch, ..., seq_len, d_k)
        V: (batch, ..., seq_len, d_v)
        mask: (batch, ..., seq_len, seq_len) or broadcastable
    Returns:
        output: (batch, ..., seq_len, d_v)
        attention_weights: (batch, ..., seq_len, seq_len)
    """
    d_k = Q.size(-1)

    # Attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Masking (optional)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted sum
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

Multi-Head Attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections
        Q = self.W_q(Q)  # (batch, seq, d_model)
        K = self.W_k(K)
        V = self.W_v(V)

        # Split into heads: (batch, seq, d_model) -> (batch, heads, seq, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        # Final projection
        output = self.W_o(attn_output)

        return output, attn_weights

# 테스트
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(32, 100, 512)  # (batch, seq_len, d_model)
output, weights = mha(x, x, x)  # Self-attention
print(f"Output: {output.shape}")   # (32, 100, 512)
print(f"Weights: {weights.shape}") # (32, 8, 100, 100)

Attention Masking

def create_padding_mask(seq, pad_idx=0):
    """패딩 토큰 마스킹"""
    # seq: (batch, seq_len)
    # output: (batch, 1, 1, seq_len) for broadcasting
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)

def create_causal_mask(seq_len):
    """미래 토큰 마스킹 (decoder self-attention용)"""
    # output: (1, 1, seq_len, seq_len)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return (mask == 0).unsqueeze(0).unsqueeze(0)

# 예시
seq_len = 10
causal_mask = create_causal_mask(seq_len)
print(causal_mask.squeeze())
# tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
#         ...

PyTorch 내장 MHA

# PyTorch 내장 MultiheadAttention
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

x = torch.randn(32, 100, 512)

# Self-attention
output, attn_weights = mha(x, x, x)

# Cross-attention (encoder-decoder)
encoder_output = torch.randn(32, 50, 512)
decoder_input = torch.randn(32, 30, 512)
output, _ = mha(decoder_input, encoder_output, encoder_output)  # Q from decoder, K,V from encoder

Attention 변형

Cross-Attention

Encoder-Decoder 구조에서 사용:

class CrossAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)

    def forward(self, decoder_hidden, encoder_output, mask=None):
        # Q: decoder, K/V: encoder
        return self.mha(decoder_hidden, encoder_output, encoder_output, mask)

Efficient Attention 변형

변형 복잡도 설명
Vanilla O(n^2) 기본 Self-Attention
Sparse O(n sqrt(n)) 일부 위치만 attend
Linear O(n) Kernel trick 사용
Flash O(n^2) 메모리 O(n) IO-aware 최적화

Flash Attention

# PyTorch 2.0+ 내장
from torch.nn.functional import scaled_dot_product_attention

# 자동으로 Flash Attention 사용 (조건 충족 시)
output = scaled_dot_product_attention(Q, K, V, is_causal=True)

Attention 시각화

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, tokens_x, tokens_y=None):
    """
    attention_weights: (seq_len_q, seq_len_k)
    tokens_x: Key 토큰 리스트
    tokens_y: Query 토큰 리스트 (None이면 tokens_x 사용)
    """
    if tokens_y is None:
        tokens_y = tokens_x

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attention_weights.detach().numpy(),
        xticklabels=tokens_x,
        yticklabels=tokens_y,
        cmap='Blues',
        annot=True,
        fmt='.2f'
    )
    plt.xlabel('Key')
    plt.ylabel('Query')
    plt.title('Attention Weights')
    plt.tight_layout()
    plt.savefig('attention_viz.png', dpi=150)
    plt.show()

# 예시
tokens = ['The', 'cat', 'sat', 'on', 'mat']
attn = torch.softmax(torch.randn(5, 5), dim=-1)
visualize_attention(attn, tokens)

Attention의 의미

학습되는 것

관계 유형 예시
문법적 주어-동사, 관사-명사
의미적 대명사-선행사, 동의어
위치적 인접 토큰 관계
장거리 문장 간 참조

연구에 따르면 각 head가 다른 언어적 관계에 특화:

Head 1: 바로 다음 토큰 주목 (positional)
Head 2: 동사-목적어 관계 (syntactic)
Head 3: 대명사 해소 (coreference)
Head 4: 부정어 범위 (semantic)

관련 문서

주제 링크
딥러닝 기초 README.md
RNN/LSTM rnn-lstm.md
Transformer ../../architecture/transformer.md

참고

  • Vaswani, A. et al. (2017). "Attention Is All You Need"
  • Bahdanau, D. et al. (2014). "Neural Machine Translation by Jointly Learning to Align and Translate"
  • The Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
  • Attention? Attention! (Lilian Weng): https://lilianweng.github.io/posts/2018-06-24-attention/