Hybrid Architecture

Overview

A hybrid architecture combines the Transformer's attention mechanism with the efficiency of State Space Models (SSMs). By leveraging both the Transformer's strong in-context learning ability and the SSM's linear time complexity, it can process long sequences efficiently.

Core Concepts

Why Hybrid?

Transformers and SSMs have complementary characteristics:

+------------------------+------------------+--------------------+
| Property               | Transformer      | SSM (Mamba)        |
+------------------------+------------------+--------------------+
| Time complexity        | O(n^2)           | O(n)               |
| Memory (inference)     | O(n) KV cache    | O(1) fixed state   |
| In-context learning    | Strong           | Limited            |
| Long-range dependencies| Direct access    | Compressed state   |
| Information retrieval  | Exact            | Approximate        |
+------------------------+------------------+--------------------+

A hybrid combines the strengths of both architectures:

  • Most layers: SSM, for efficient processing
  • A few layers: Attention, for exact information retrieval
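
To make the memory row of the table concrete, here is a back-of-the-envelope sketch comparing a per-sequence KV cache against a fixed SSM state. All layer counts and dimensions are hypothetical placeholders, not the configuration of any model discussed later.

def kv_cache_bytes(n_attn_layers: int, n_kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    # One K and one V vector per token, per KV head, per attention layer.
    return 2 * n_attn_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

def ssm_state_bytes(n_ssm_layers: int, d_model: int, state_dim: int,
                    bytes_per_elem: int = 2) -> int:
    # Fixed-size recurrent state per SSM layer, independent of sequence length.
    return n_ssm_layers * d_model * state_dim * bytes_per_elem

# Hypothetical 32-layer models, batch size 1, fp16 elements.
for seq_len in (4_096, 65_536, 262_144):
    kv = kv_cache_bytes(32, n_kv_heads=8, head_dim=128, seq_len=seq_len)
    state = ssm_state_bytes(32, d_model=4096, state_dim=16)
    print(f"seq_len={seq_len:>7}: KV cache {kv / 2**30:6.2f} GiB, "
          f"SSM state {state / 2**20:6.2f} MiB")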

Block Placement Strategies

Interleaved:

[Mamba] -> [Mamba] -> [Mamba] -> [Attention] -> [Mamba] -> ...

Ratio-based:

Jamba: 1 Attention layer for every 8 layers (a 1:7 Attention:Mamba ratio)

Task-adaptive:

Adjust the Attention/Mamba ratio according to the task
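
The placement strategies above can be written down as a simple layer schedule. The helper below is illustrative rather than code from any of the models discussed here; the 1:7 default mirrors the Jamba ratio mentioned above.

from typing import List

def interleaved_schedule(n_layers: int, attention_every: int = 4) -> List[str]:
    """Interleaved: one attention layer every `attention_every` layers."""
    return ["attention" if (i + 1) % attention_every == 0 else "mamba"
            for i in range(n_layers)]

def ratio_schedule(n_layers: int, n_mamba: int = 7, n_attention: int = 1) -> List[str]:
    """Ratio-based: repeat `n_mamba` Mamba layers, then `n_attention` attention layers."""
    block = ["mamba"] * n_mamba + ["attention"] * n_attention
    return [block[i % len(block)] for i in range(n_layers)]

print(ratio_schedule(16))  # attention lands at layers 8 and 16 (1:7 ratio)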

Architecture Diagrams

Jamba Basic Block Structure

                     Input
                        |
                        v
    +-------------------+--------------------+
    |              Jamba Block               |
    |                                        |
    |   +--------+    or    +------------+   |
    |   | Mamba  |          | Attention  |   |
    |   | Layer  |          | Layer      |   |
    |   +---+----+          +-----+------+   |
    |       |                     |          |
    |       +----------+----------+          |
    |                  |                     |
    |       +----------v----------+          |
    |       |      MoE / MLP      |          |
    |       |   (Feed Forward)    |          |
    |       +----------+----------+          |
    |                  |                     |
    +------------------+---------------------+
                       |
                       v
                    Output

Jamba Overall Structure (Illustrative Example)

    +--------------------------------------------------+
    |                   Jamba Model                    |
    +--------------------------------------------------+
    |                                                  |
    |  Layer 1-7:   [Mamba] [Mamba] ... [Mamba] (7x)   |
    |  Layer 8:     [Attention + MoE]                  |
    |                                                  |
    |  Layer 9-15:  [Mamba] [Mamba] ... [Mamba] (7x)   |
    |  Layer 16:    [Attention + MoE]                  |
    |                                                  |
    |  ...                                             |
    |                                                  |
    |  Total: 72 layers                                |
    |  - 64 Mamba layers                               |
    |  - 8 Attention + MoE layers                      |
    |                                                  |
    +--------------------------------------------------+

    Layer ratio: 1:7 (Attention:Mamba)

Jamba 1.5 Detailed Structure

                         [Input Tokens]
                               |
                               v
                      +----------------+
                      |   Embedding    |
                      +-------+--------+
                              |
            +==========================================+
            |           REPEATED BLOCK (x9)            |
            |                                          |
            |  +------------------------------------+  |
            |  |         MAMBA BLOCK (x7)           |  |
            |  |  +------------------------------+  |  |
            |  |  |    LayerNorm                 |  |  |
            |  |  +-------------+----------------+  |  |
            |  |                |                   |  |
            |  |  +-------------v----------------+  |  |
            |  |  |      Mamba Layer             |  |  |
            |  |  |  (Selective SSM)             |  |  |
            |  |  +-------------+----------------+  |  |
            |  |                |                   |  |
            |  |  +-------------v----------------+  |  |
            |  |  |       MLP (Dense)            |  |  |
            |  |  +------------------------------+  |  |
            |  +------------------------------------+  |
            |                                          |
            |  +------------------------------------+  |
            |  |     ATTENTION + MoE BLOCK (x1)     |  |
            |  |  +------------------------------+  |  |
            |  |  |    LayerNorm                 |  |  |
            |  |  +-------------+----------------+  |  |
            |  |                |                   |  |
            |  |  +-------------v----------------+  |  |
            |  |  | Grouped-Query Attention      |  |  |
            |  |  | (GQA, 8 KV heads)            |  |  |
            |  |  +-------------+----------------+  |  |
            |  |                |                   |  |
            |  |  +-------------v----------------+  |  |
            |  |  |    MoE Layer                 |  |  |
            |  |  | (16 experts, top-2)          |  |  |
            |  |  +------------------------------+  |  |
            |  +------------------------------------+  |
            |                                          |
            +==========================================+
                              |
                              v
                      +----------------+
                      |   LayerNorm    |
                      +-------+--------+
                              |
                              v
                      +----------------+
                      |    LM Head     |
                      +----------------+
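
The hyperparameters called out in the diagram above (blocks of 7 Mamba layers plus 1 Attention + MoE layer, GQA with 8 KV heads, 16 experts with top-2 routing) can be collected into a single configuration object that model code could consume. The sketch below is illustrative; the hidden size, head count, and state dimension are assumptions, not the published Jamba 1.5 configuration.

from dataclasses import dataclass

@dataclass
class HybridModelConfig:
    # Values taken from the diagram above
    layers_per_block: int = 8      # 7 Mamba layers + 1 Attention + MoE layer
    n_blocks: int = 9              # number of repeated blocks
    n_kv_heads: int = 8            # GQA key/value heads
    num_experts: int = 16          # MoE experts
    top_k_experts: int = 2         # experts routed per token
    # Illustrative placeholders (not from the diagram)
    d_model: int = 4096
    n_heads: int = 32
    mamba_state_dim: int = 16

    @property
    def n_layers(self) -> int:
        return self.layers_per_block * self.n_blocks

    def uses_attention(self, layer_idx: int) -> bool:
        # The last layer of each block is the Attention + MoE layer.
        return (layer_idx + 1) % self.layers_per_block == 0

cfg = HybridModelConfig()
print(cfg.n_layers)                                     # 72
print([i for i in range(16) if cfg.uses_attention(i)])  # [7, 15]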

Zamba Structure (Shared Attention)

    Zamba: uses a shared attention block

    +----------------------------------------------------+
    |                                                    |
    |  Block 1:  [Mamba] -> [Shared Attention] -> [MLP]  |
    |  Block 2:  [Mamba] -> [Shared Attention] -> [MLP]  |
    |  Block 3:  [Mamba] -> [Shared Attention] -> [MLP]  |
    |  ...                                               |
    |                                                    |
    |  * Shared Attention: weights shared across blocks  |
    |  * Improves memory efficiency                      |
    |                                                    |
    +----------------------------------------------------+

StripedHyena Structure

    StripedHyena: alternating placement

    +--------------------------------------------+
    |  Layer 1:  [Hyena (Long Conv)]             |
    |  Layer 2:  [Attention]                     |
    |  Layer 3:  [Hyena (Long Conv)]             |
    |  Layer 4:  [Attention]                     |
    |  ...                                       |
    |                                            |
    |  * Alternates in a 1:1 ratio               |
    |  * Hyena: implicit long convolution        |
    +--------------------------------------------+
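
The "implicit long convolution" in the diagram refers to convolving each channel with a filter as long as the sequence; Hyena parameterizes these filters implicitly and evaluates the convolution with FFTs in O(n log n). The sketch below shows only the FFT-convolution core with an explicit random filter, not the full Hyena operator with its gating and implicit filter parameterization.

import torch

def causal_fft_conv(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal long convolution computed with FFTs in O(n log n).

    x: (batch, seq_len, channels) input sequence
    k: (seq_len, channels) one sequence-length filter per channel
    """
    seq_len = x.shape[1]
    fft_len = 2 * seq_len                      # zero-pad to avoid circular wrap-around
    x_f = torch.fft.rfft(x, n=fft_len, dim=1)  # (batch, fft_len//2 + 1, channels)
    k_f = torch.fft.rfft(k, n=fft_len, dim=0)  # (fft_len//2 + 1, channels)
    y = torch.fft.irfft(x_f * k_f.unsqueeze(0), n=fft_len, dim=1)
    return y[:, :seq_len, :]                   # y[t] depends only on x[:t+1]

x = torch.randn(2, 1024, 64)
k = torch.randn(1024, 64) * 0.01
print(causal_fft_conv(x, k).shape)  # torch.Size([2, 1024, 64])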

Representative Models

+-----------------+--------------+---------------+--------------------------+---------+
| Model           | Total Params | Active Params | Architecture             | Context |
+-----------------+--------------+---------------+--------------------------+---------+
| Jamba 1.5 Large | 398B         | 94B           | Mamba + Attention + MoE  | 256K    |
| Jamba 1.5 Mini  | 52B          | 12B           | Mamba + Attention + MoE  | 256K    |
| Jamba (original)| 52B          | 12B           | Mamba + Attention + MoE  | 256K    |
| Zamba 7B        | 7.2B         | 7.2B          | Mamba + Shared Attention | 4K      |
| StripedHyena 7B | 7B           | 7B            | Hyena + Attention        | 128K    |
+-----------------+--------------+---------------+--------------------------+---------+

Strengths and Weaknesses

Strengths

  1. Efficient long context: the SSM's linear complexity handles long sequences
  2. Strong retrieval: attention layers allow exact information lookup
  3. Memory efficiency: most layers are SSMs, so the KV cache stays small
  4. Flexible design: the Attention/Mamba ratio can be tuned per task

Weaknesses

  1. Design complexity: finding the optimal ratio and placement is hard
  2. Training difficulty: the interaction between the two architectures must be understood
  3. Early stage: research is still actively evolving
  4. Immature tooling: limited ability to reuse Transformer-specific optimizations

Code Examples

Jamba-style Hybrid Block

import torch
import torch.nn as nn
from typing import Optional

# Assume MambaBlock and MultiHeadAttention are defined elsewhere
# (the module paths below are illustrative placeholders):
# from ssm import MambaBlock
# from transformer import MultiHeadAttention
# MoELayer is also required; a minimal sketch is provided after this code block.

class HybridBlock(nn.Module):
    """Single hybrid block that can be either Mamba or Attention"""
    def __init__(
        self,
        d_model: int,
        use_attention: bool = False,
        n_heads: int = 8,
        d_ff: Optional[int] = None,
        state_dim: int = 16,
        dropout: float = 0.1
    ):
        super().__init__()
        self.use_attention = use_attention
        d_ff = d_ff or d_model * 4

        self.norm1 = nn.LayerNorm(d_model)

        if use_attention:
            self.layer = MultiHeadAttention(d_model, n_heads, dropout)
        else:
            self.layer = MambaBlock(d_model, state_dim=state_dim)

        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        # Pre-norm architecture
        if self.use_attention:
            x = x + self.layer(self.norm1(x), mask)
        else:
            x = x + self.layer(self.norm1(x))

        x = x + self.ffn(self.norm2(x))
        return x


class JambaModel(nn.Module):
    """
    Jamba-style hybrid model.

    Uses a 1:7 ratio of Attention:Mamba layers.
    Every 8th layer uses Attention + MoE.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 4096,
        n_layers: int = 72,
        n_heads: int = 32,
        d_ff: int = 14336,
        attention_frequency: int = 8,  # 1 attention per 8 layers
        use_moe: bool = True,
        num_experts: int = 16,
        top_k_experts: int = 2,
        state_dim: int = 16,
        dropout: float = 0.1
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)

        self.layers = nn.ModuleList()
        for i in range(n_layers):
            use_attention = ((i + 1) % attention_frequency == 0)

            if use_attention and use_moe:
                # Attention + MoE layer
                layer = HybridAttentionMoEBlock(
                    d_model, n_heads, d_ff,
                    num_experts, top_k_experts, dropout
                )
            elif use_attention:
                # Attention + Dense FFN
                layer = HybridBlock(
                    d_model, use_attention=True,
                    n_heads=n_heads, d_ff=d_ff, dropout=dropout
                )
            else:
                # Mamba + Dense FFN
                layer = HybridBlock(
                    d_model, use_attention=False,
                    state_dim=state_dim, d_ff=d_ff, dropout=dropout
                )

            self.layers.append(layer)

        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x, attention_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        return logits


class HybridAttentionMoEBlock(nn.Module):
    """Attention block with MoE FFN (for every 8th layer in Jamba)"""
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        d_ff: int,
        num_experts: int,
        top_k: int,
        dropout: float = 0.1
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)

        self.norm2 = nn.LayerNorm(d_model)
        self.moe = MoELayer(d_model, d_ff, num_experts, top_k, dropout)

    def forward(self, x, mask=None):
        x = x + self.attention(self.norm1(x), mask)

        moe_out, router_logits = self.moe(self.norm2(x))
        x = x + moe_out

        return x
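
The HybridAttentionMoEBlock above relies on a MoELayer that is not defined in this section. The following is a minimal top-k routed MoE sketch so the example can run end to end; it is illustrative only and omits the capacity limits, load-balancing losses, and expert parallelism that production implementations use.

import torch
import torch.nn as nn

class MoELayer(nn.Module):
    """Minimal top-k routed mixture-of-experts FFN (illustrative sketch)."""
    def __init__(self, d_model: int, d_ff: int, num_experts: int, top_k: int,
                 dropout: float = 0.1):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout),
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        batch, seq_len, d_model = x.shape
        flat = x.reshape(-1, d_model)                      # (tokens, d_model)
        router_logits = self.router(flat)                  # (tokens, num_experts)
        weights = router_logits.softmax(dim=-1)
        top_w, top_idx = weights.topk(self.top_k, dim=-1)  # (tokens, top_k)
        top_w = top_w / top_w.sum(dim=-1, keepdim=True)    # renormalize selected experts

        out = torch.zeros_like(flat)
        for expert_id, expert in enumerate(self.experts):
            # Tokens that routed to this expert (and in which of their top-k slots).
            token_ids, slot_ids = (top_idx == expert_id).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            expert_out = expert(flat[token_ids])
            out[token_ids] += top_w[token_ids, slot_ids].unsqueeze(-1) * expert_out

        return out.reshape(batch, seq_len, d_model), router_logits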

Zamba-style Model (Shared Attention)

class ZambaModel(nn.Module):
    """
    Zamba-style model with shared attention.

    All blocks share the same attention layer weights,
    reducing memory footprint significantly.
    """
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 3072,
        n_layers: int = 76,
        n_heads: int = 24,
        state_dim: int = 64,
        dropout: float = 0.1
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)

        # Shared attention (same weights for all layers)
        self.shared_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.attn_norm = nn.LayerNorm(d_model)

        # Individual Mamba blocks
        self.mamba_blocks = nn.ModuleList([
            MambaBlock(d_model, state_dim=state_dim)
            for _ in range(n_layers)
        ])
        self.mamba_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])

        # Individual MLPs
        self.mlps = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.GELU(),
                nn.Linear(d_model * 4, d_model),
                nn.Dropout(dropout)
            )
            for _ in range(n_layers)
        ])
        self.mlp_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)

        for i in range(len(self.mamba_blocks)):
            # Mamba processing
            x = x + self.mamba_blocks[i](self.mamba_norms[i](x))

            # Shared attention
            x = x + self.shared_attention(self.attn_norm(x), attention_mask)

            # MLP
            x = x + self.mlps[i](self.mlp_norms[i](x))

        x = self.final_norm(x)
        return self.lm_head(x)
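
A short usage sketch illustrating why sharing helps: the shared attention module contributes its parameters once, no matter how many blocks reuse it. The dimensions below are arbitrarily small so the check runs quickly, and it assumes the MambaBlock and MultiHeadAttention stand-ins mentioned at the top of the code examples are available.

# Illustrative parameter-count check (tiny dimensions chosen arbitrarily).
model = ZambaModel(vocab_size=32_000, d_model=256, n_layers=12, n_heads=8)

total = sum(p.numel() for p in model.parameters())
shared = sum(p.numel() for p in model.shared_attention.parameters())
print(f"total: {total / 1e6:.1f}M params, shared attention: {shared / 1e6:.2f}M params")
# With per-block attention, the attention parameters would be duplicated n_layers times.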

Design Considerations

Choosing the Attention/Mamba Ratio

+-------------+----------------------------+-----------------------------------------+
| Ratio       | Characteristics            | Suitable for                            |
+-------------+----------------------------+-----------------------------------------+
| 1:7 (Jamba) | Efficiency first           | Long context, inference-cost sensitive  |
| 1:3         | Balanced                   | General use                             |
| 1:1         | Retrieval capability first | In-context learning matters             |
+-------------+----------------------------+-----------------------------------------+

MoE Integration Strategy

  • All layers: parameter-efficient but harder to train
  • Attention layers only: the Jamba approach, stable
  • Mamba layers only: experimental, still being researched

References

  1. Lieber, O., et al. (2024). "Jamba: A Hybrid Transformer-Mamba Language Model." arXiv: https://arxiv.org/abs/2403.19887
  2. AI21 Labs. (2024). "Jamba 1.5: Hybrid Transformer-Mamba Models at Scale." https://www.ai21.com/blog/announcing-jamba-1-5/
  3. Glorioso, P., et al. (2024). "Zamba: A Compact 7B SSM Hybrid Model." arXiv: https://arxiv.org/abs/2405.18712
  4. Poli, M., et al. (2023). "Hyena Hierarchy: Towards Larger Convolutional Language Models." arXiv: https://arxiv.org/abs/2302.10866
  5. Together AI. (2023). "StripedHyena: Moving Beyond Transformers with Hybrid Signal Processing Models." https://www.together.ai/blog/stripedhyena-7b
  6. Waleffe, R., et al. (2024). "An Empirical Study of Mamba-based Language Models." arXiv: https://arxiv.org/abs/2406.07887