Mixture of Experts (MoE)

Overview

Mixture of Experts (MoE) is a sparse-activation technique that scales up a model's total parameter count while keeping the per-token compute cost roughly constant. For each input, only a subset of "experts" is activated, which enables efficient scaling.

Core Concepts

Sparse Activation

In a conventional dense model, every parameter is used for every input. In an MoE model, only the experts selected by a router are activated for a given token.

Dense Model:    All parameters activated for every input
MoE Model:      Only K experts (out of N) activated per token

Router (Gating)

The router decides which experts each token is sent to.

Basic formulation:

G(x) = softmax(W_g * x)          # Gating scores
TopK(G(x))                        # Select top K experts
y = sum_{i in TopK} G(x)_i * E_i(x)  # Weighted sum of expert outputs
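
A minimal numeric sketch of this gating step in PyTorch (the dimensions and the stand-in experts below are illustrative assumptions, not part of any specific model):

import torch
import torch.nn.functional as F

d_model, num_experts, top_k = 8, 4, 2
x = torch.randn(d_model)                      # a single token embedding
W_g = torch.randn(num_experts, d_model)       # gating weight matrix

gate_scores = F.softmax(W_g @ x, dim=-1)      # G(x): one score per expert
top_vals, top_idx = torch.topk(gate_scores, top_k)

# Stand-in "experts" for illustration; in a real model each E_i is an FFN
experts = [lambda t, i=i: t * (i + 1) for i in range(num_experts)]
y = sum(w * experts[i](x) for w, i in zip(top_vals, top_idx.tolist()))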

Load Balancing

An auxiliary loss discourages tokens from concentrating on a small number of experts:

L_balance = alpha * N * sum_i(f_i * P_i)

where:
  f_i = fraction of tokens routed to expert i
  P_i = average routing probability to expert i
  N = number of experts
  alpha = balance coefficient
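
For intuition, a tiny worked example (the routing statistics are made-up numbers): with N = 4 and alpha = 0.01, perfectly uniform routing gives f_i = P_i = 0.25, so sum_i(f_i * P_i) = 0.25 and L_balance = 0.01 * 4 * 0.25 = 0.01. A skewed routing such as f = (0.5, 0.3, 0.15, 0.05), P = (0.45, 0.30, 0.15, 0.10) gives sum_i(f_i * P_i) = 0.3425 and L_balance = 0.01 * 4 * 0.3425 ≈ 0.0137, so imbalance is penalized.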

Architecture Diagrams

MoE Layer Structure

                         Input Token
                              |
                              v
                    +-------------------+
                    |      Router       |
                    | (Linear + Softmax)|
                    +-------------------+
                              |
              +-------+-------+-------+-------+
              |       |       |       |       |
        [0.4] v  [0.1]v  [0.3]v  [0.2]v       |
        +-----+  +-----+  +-----+  +-----+    |
        | E_0 |  | E_1 |  | E_2 |  | E_3 |    |
        +-----+  +-----+  +-----+  +-----+    |
        (FFN)    (FFN)    (FFN)    (FFN)      |
           |        X        |        X       |
           |     (not       |     (not       |
           |   selected)    |   selected)    |
           v                v                 |
        +-----+          +-----+              |
        |0.4*y|          |0.3*y|              |
        +-----+          +-----+              |
              \          /                    |
               \        /                     |
                v      v                      |
              +---------+                     |
              |   Sum   | <-------------------+
              +---------+      (residual)
                    |
                    v
                 Output

Transformer + MoE Integration

              +---------------------------+
              |     Transformer Block     |
              |                           |
              |  +---------------------+  |
              |  |   Self-Attention    |  |
              |  +----------+----------+  |
              |             |             |
              |  +----------v----------+  |
              |  |   Layer Norm        |  |
              |  +----------+----------+  |
              |             |             |
              |  +----------v----------+  |
              |  |     MoE Layer       |  |  <-- Replace FFN
              |  |  +--------------+   |  |
              |  |  |   Router     |   |  |
              |  |  +------+-------+   |  |
              |  |         |           |  |
              |  |   [E0][E1]...[En]   |  |
              |  |     (Experts)       |  |
              |  +---------+-----------+  |
              |            |              |
              |  +---------v-----------+  |
              |  |    Layer Norm       |  |
              |  +---------------------+  |
              +---------------------------+
                        x N layers
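
A minimal sketch of this integration in PyTorch, assuming the MoELayer class implemented in the Code Examples section below (layer sizes are illustrative, and the block follows the post-norm layout of the diagram):

import torch.nn as nn

class MoETransformerBlock(nn.Module):
    """Transformer block with the dense FFN replaced by an MoE layer (sketch)."""
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, num_experts=8, top_k=2):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.moe = MoELayer(d_model, d_ff, num_experts, top_k)  # replaces the FFN
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)        # residual + norm around attention
        moe_out, router_logits = self.moe(x)
        x = self.norm2(x + moe_out)         # residual + norm around the MoE layer
        return x, router_logits             # logits feed the load-balancing loss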

DeepSeek MoE Structure (Fine-grained + Shared Experts)

                    Input Token
                         |
            +------------+-------------+
            |                          |
            v                          v
      +-----------+              +-----------+
      |  Shared   |              |   Router  |
      |  Expert   |              +-----+-----+
      | (Always   |                    |
      |  Active)  |         +----------+-----------+
      +-----------+         |                      |
            |               v                      v
            |         +-----------+          +-----------+
            |         |  Routed   |          |  Routed   |
            |         | Expert 1  |   ...    | Expert K  |
            |         | (if       |          | (if       |
            |         | selected) |          | selected) |
            |         +-----------+          +-----------+
            |               |                      |
            +---------------+----------------------+
                            |
                            v
                      [Combine & Sum]
                            |
                            v
                         Output

Representative Models

Model               Total Params   Active Params   Experts   Active per Token
Mixtral 8x7B        46.7B          ~13B            8         2
Mixtral 8x22B       141B           ~39B            8         2
DeepSeek-V2         236B           21B             160       6
DeepSeek-V3         671B           37B             256       8
Grok-1              314B           ~86B            8         2
Switch Transformer  1.6T           ~1.6B           2048      1
Llama 4 Scout       109B           17B             16        1
Llama 4 Maverick    400B           17B             128       1

(For the DeepSeek models, the expert counts refer to routed experts; their shared experts are additional and always active.)

Pros and Cons

Pros

  1. Efficient scaling: low compute cost relative to the total parameter count
  2. Specialization: each expert can specialize in particular patterns or domains
  3. Training speed: faster convergence than comparable dense models
  4. Flexible capacity: model capacity can be grown by adjusting the number of experts

Cons

  1. Memory footprint: all experts must be kept in memory even though only a few are active per token
  2. Load imbalance: tokens may concentrate on a small number of experts
  3. Communication overhead: tokens must be moved between experts (devices) during distributed training
  4. Training instability: router training can be unstable

Code Examples

Basic MoE Layer Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Single expert (typically an FFN)"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Router(nn.Module):
    """Top-K router for selecting experts"""
    def __init__(self, d_model: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        logits = self.gate(x)  # (batch, seq_len, num_experts)

        # Top-K selection
        top_k_logits, top_k_indices = torch.topk(logits, self.top_k, dim=-1)
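        # Softmax is applied over only the K selected logits, so the gate
        # weights of the chosen experts are renormalized to sum to 1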
        top_k_weights = F.softmax(top_k_logits, dim=-1)

        return top_k_weights, top_k_indices, logits


class MoELayer(nn.Module):
    """Mixture of Experts layer"""
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int,
        top_k: int = 2,
        dropout: float = 0.1
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        self.router = Router(d_model, num_experts, top_k)
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff, dropout) for _ in range(num_experts)
        ])

    def forward(self, x):
        batch_size, seq_len, d_model = x.shape

        # Get routing weights and indices
        weights, indices, router_logits = self.router(x)
        # weights: (batch, seq_len, top_k)
        # indices: (batch, seq_len, top_k)

        # Flatten for processing
        x_flat = x.view(-1, d_model)  # (batch * seq_len, d_model)
        weights_flat = weights.view(-1, self.top_k)
        indices_flat = indices.view(-1, self.top_k)

        # Initialize output
        output = torch.zeros_like(x_flat)

        # Process each expert
        for i, expert in enumerate(self.experts):
            # Find tokens routed to this expert
            mask = (indices_flat == i).any(dim=-1)
            if mask.sum() == 0:
                continue

            # Get expert output
            expert_input = x_flat[mask]
            expert_output = expert(expert_input)

            # Weight by routing score
            expert_weights = weights_flat[mask]
            expert_indices = indices_flat[mask]

            # Find which top_k slot this expert corresponds to
            slot_mask = (expert_indices == i).float()
            combined_weight = (expert_weights * slot_mask).sum(dim=-1, keepdim=True)

            output[mask] += combined_weight * expert_output

        output = output.view(batch_size, seq_len, d_model)

        # Return output and router logits (for aux loss)
        return output, router_logits

    def load_balance_loss(self, router_logits):
        """Auxiliary loss for load balancing"""
        # router_logits: (batch, seq_len, num_experts)
        num_tokens = router_logits.shape[0] * router_logits.shape[1]

        # Routing probabilities
        routing_probs = F.softmax(router_logits, dim=-1)
        avg_probs = routing_probs.mean(dim=[0, 1])  # (num_experts,)

        # Fraction of tokens per expert
        expert_indices = router_logits.argmax(dim=-1)  # (batch, seq_len)
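        # (Simplification: f_i is computed from the top-1 assignment here;
        #  a top-k variant would count every expert each token is routed to)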
        expert_mask = F.one_hot(expert_indices, self.num_experts).float()
        tokens_per_expert = expert_mask.sum(dim=[0, 1]) / num_tokens

        # Balance loss
        loss = self.num_experts * (tokens_per_expert * avg_probs).sum()
        return loss
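
A short usage sketch (batch size, sequence length, and the 0.01 coefficient are illustrative assumptions):

moe = MoELayer(d_model=512, d_ff=2048, num_experts=8, top_k=2)
x = torch.randn(4, 128, 512)                   # (batch, seq_len, d_model)

output, router_logits = moe(x)                 # output: (4, 128, 512)
aux_loss = moe.load_balance_loss(router_logits)

# During training the auxiliary loss is added to the task loss with a small weight:
# total_loss = task_loss + 0.01 * aux_loss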

DeepSeek-Style Fine-grained MoE

class DeepSeekMoE(nn.Module):
    """
    DeepSeek-style MoE with:
    - Fine-grained experts (more experts, smaller each)
    - Shared experts (always active)
    """
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_routed_experts: int,
        num_shared_experts: int,
        top_k: int,
        dropout: float = 0.1
    ):
        super().__init__()
        self.num_routed = num_routed_experts
        self.num_shared = num_shared_experts
        self.top_k = top_k

        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            Expert(d_model, d_ff // num_shared_experts, dropout)
            for _ in range(num_shared_experts)
        ])

        # Routed experts
        expert_ff_dim = d_ff // num_routed_experts
        self.routed_experts = nn.ModuleList([
            Expert(d_model, expert_ff_dim, dropout)
            for _ in range(num_routed_experts)
        ])

        self.router = Router(d_model, num_routed_experts, top_k)

    def forward(self, x):
        # Shared expert output (always computed)
        shared_output = sum(exp(x) for exp in self.shared_experts)

        # Routed expert output
        weights, indices, router_logits = self.router(x)
        routed_output = self._route_tokens(x, weights, indices)

        return shared_output + routed_output, router_logits

    def _route_tokens(self, x, weights, indices):
        # Similar to MoELayer implementation
        batch_size, seq_len, d_model = x.shape
        output = torch.zeros_like(x)

        for i, expert in enumerate(self.routed_experts):
            mask = (indices == i).any(dim=-1)
            if mask.sum() == 0:
                continue

            expert_input = x[mask]
            expert_output = expert(expert_input)

            expert_weights = weights[mask]
            expert_indices = indices[mask]
            slot_mask = (expert_indices == i).float()
            combined_weight = (expert_weights * slot_mask).sum(dim=-1, keepdim=True)

            output[mask] += combined_weight * expert_output

        return output
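
A quick shape check for the DeepSeek-style layer (the configuration below is an illustrative assumption, not DeepSeek's actual sizes):

moe = DeepSeekMoE(d_model=512, d_ff=2048,
                  num_routed_experts=16, num_shared_experts=2, top_k=4)
x = torch.randn(2, 64, 512)
out, router_logits = moe(x)
print(out.shape)   # torch.Size([2, 64, 512])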

MoE Variants

Technique        Description                                      Example Models
Top-K Routing    Each token selects the top K experts             Mixtral, GPT-4 (reported)
Expert Choice    Each expert selects its own tokens               Expert Choice Routing (Zhou et al., 2022)
Soft MoE         Soft weighted mix over all experts               Soft MoE (2023)
Fine-grained     Many small experts instead of a few large ones   DeepSeek
Shared Experts   Some experts are always active                   DeepSeek
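
As a rough sketch of the Expert Choice idea (all sizes below are illustrative assumptions): instead of each token picking its top experts, each expert picks its own top-C tokens from the batch, which makes the load balanced by construction.

import torch
import torch.nn.functional as F

num_tokens, d_model, num_experts, capacity = 16, 8, 4, 4
x = torch.randn(num_tokens, d_model)
w_gate = torch.randn(d_model, num_experts)

scores = F.softmax(x @ w_gate, dim=-1)             # (num_tokens, num_experts)
# Each expert (column) keeps its top-`capacity` tokens by score
top_scores, token_idx = torch.topk(scores.t(), capacity, dim=-1)
# token_idx[e] lists the tokens processed by expert e; a token may be
# picked by several experts or dropped entirely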

References

  1. Shazeer, N., et al. (2017). "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer." arXiv: https://arxiv.org/abs/1701.06538
  2. Fedus, W., et al. (2022). "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity." arXiv: https://arxiv.org/abs/2101.03961
  3. Lepikhin, D., et al. (2020). "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding." arXiv: https://arxiv.org/abs/2006.16668
  4. Jiang, A., et al. (2024). "Mixtral of Experts." arXiv: https://arxiv.org/abs/2401.04088
  5. DeepSeek-AI. (2024). "DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Models." arXiv: https://arxiv.org/abs/2401.06066
  6. DeepSeek-AI. (2024). "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model." arXiv: https://arxiv.org/abs/2405.04434
  7. Puigcerver, J., et al. (2023). "From Sparse to Soft Mixtures of Experts." arXiv: https://arxiv.org/abs/2308.00951