본문 바로가기

LLM

FlashAttention 완전 정리 — LLM이 긴 문서를 처리할 수 있는 진짜 이유

반응형

LLM이 긴 문서를 처리할 때 왜 느려질까요?

GPT-3 컨텍스트: 2,048 토큰
GPT-4 컨텍스트: 128,000 토큰
Llama 3: 1,000,000 토큰

2년 만에 500배 늘어났어요.
이게 가능해진 핵심 기술이 FlashAttention이에요.

Attention이 뭔가

LLM은 텍스트를 읽을 때 모든 토큰이 다른 모든 토큰과 얼마나 관련있는지 계산해요.

입력: "나는 사과를 먹었다"

각 토큰이 다른 토큰과의 관계 점수 계산:
"나는" ↔ "사과를": 0.3
"나는" ↔ "먹었다": 0.8
"사과를" ↔ "먹었다": 0.7
...

수식으로는 이래요.

Attention(Q, K, V) = softmax(QK^T / √d_k) × V

Q(Query):  "지금 처리 중인 토큰"
K(Key):    "비교할 모든 토큰"
V(Value):  "실제 정보"

QK^T:      모든 토큰 쌍의 유사도 행렬
softmax:   유사도를 확률로 변환
× V:       유사도에 따라 정보 가중합산

문제 — Attention은 메모리 병목이다

GPU는 두 종류의 메모리가 있어요.

HBM (High Bandwidth Memory, 느린 메모리):
→ 용량 큼 (H100: 80GB)
→ 속도 느림 (대역폭 3.35TB/s)
→ 모델 가중치, KV Cache 저장

SRAM (빠른 메모리, L2 Cache):
→ 용량 매우 작음 (H100: 50MB)
→ 속도 매우 빠름 (대역폭 ~33TB/s, 10배 빠름)
→ 현재 계산 중인 데이터 저장

기존 Attention 계산 과정이에요.

1. Q, K를 HBM에서 SRAM으로 읽기
2. QK^T 행렬 계산 (N×N 크기)
3. 결과를 HBM에 저장 ← 느린 메모리에 씀!
4. HBM에서 다시 읽어서 Softmax 계산
5. 결과를 HBM에 저장 ← 또 씀!
6. HBM에서 읽어서 × V 계산
7. 최종 결과 HBM에 저장

HBM ↔ SRAM 왔다갔다를 6번이나 해요.

진짜 문제:

토큰 수 N이 늘어날수록:
중간 행렬(QK^T) 크기: N × N
→ 10K 토큰: 10,000 × 10,000 = 1억 개 값 저장
→ 100K 토큰: 100억 개 값 저장 → SRAM에 못 들어감!
→ HBM에 써야 함 → 엄청나게 느려짐

메모리 사용량이 O(N²)으로 늘어요. 컨텍스트 길이에 제곱으로 비례해서 느려져요.


FlashAttention의 핵심 아이디어 — Tiling

중간 행렬을 HBM에 저장하지 않고 SRAM 안에서 처리하는 거예요.

기본 아이디어:

전체 행렬을 한번에 계산하는 대신
작은 블록(타일) 단위로 나눠서 SRAM 안에서 처리

예시로 설명하면 이래요.

기존 방식 (10페이지 문서 요약):
1. 10페이지 전체를 화이트보드에 복사
   (HBM에 N×N 행렬 저장)
2. 모든 문장 쌍을 비교
3. 요약 작성

FlashAttention 방식:
1. 1페이지씩 쪽지(SRAM)에 메모
2. 비교하고 중간 결과 업데이트
3. 쪽지 내용 지우고 다음 페이지
4. 화이트보드(HBM) 사용 최소화

실제 알고리즘:

for 블록 i in Q:          # Q를 블록으로 나눔
    for 블록 j in K, V:   # K, V를 블록으로 나눔
        # SRAM에서 계산
        S_ij = Q_i × K_j^T      # 유사도
        P_ij = softmax(S_ij)    # 확률
        O_i += P_ij × V_j       # 가중합

# HBM 접근: 처음 Q,K,V 읽기 + 최종 O 쓰기만!
# 중간 행렬 HBM 저장 없음

이게 되려면 Softmax를 블록 단위로 계산해야 하는데 이게 수학적으로 쉽지 않아요.

Online Softmax가 이걸 해결했어요.

일반 Softmax:
전체 값을 다 봐야 분모 계산 가능
→ N×N 행렬 전체 필요

Online Softmax:
블록 하나 볼 때마다 최댓값 추적
이전 블록 결과를 새 정보로 점진적 업데이트
→ 전체 행렬 없이도 정확한 결과 가능!

FlashAttention 버전별 발전

FlashAttention 1 (2022)

개선:
- HBM 접근 횟수 대폭 감소
- 메모리: O(N²) → O(N)
- 속도: 2~4배 향상
- 정확도: 기존과 동일 (근사 아님)

한계:
- GPU 사용률 35% 수준
- 배치 크기와 헤드 수 기준으로만 병렬화

FlashAttention 2 (2023)

개선:
- 시퀀스 길이 방향으로도 병렬화
- GPU 사용률 70%로 향상
- FA1 대비 2배 빠름

특히 좋은 점:
- 추론에서 배치 크기 작고 시퀀스 길 때 효율적
- Ampere (A100), Ada (RTX 4090) GPU 최적화

FlashAttention 3 (2024)

개선:
- H100 GPU 특성 완전 활용
- Producer-Consumer 비동기 처리
- Tensor Core와 데이터 이동 겹쳐서 실행
- FP8 지원으로 정확도 유지하면서 더 빠름

결과:
FA2 대비 1.5~2.0배 빠름
H100에서 740 TFLOPS (이론 최대의 75%)
FP8: 1.2 PFLOPS

실제로 얼마나 빨라지나

컨텍스트 길이별 속도 향상 (A100 기준):

2K 토큰:  FA 2.0x 빠름
8K 토큰:  FA 4.0x 빠름
32K 토큰: FA 6.0x 빠름
128K 토큰: FA 기본 Attention 불가능 → FA로만 가능

메모리:
2K:   기본 1.6GB → FA 0.2GB
8K:   기본 25GB → FA 0.8GB (기본은 OOM!)
32K:  기본 400GB → FA 3GB

FlashAttention 없으면 긴 컨텍스트 자체가 불가능해요.


SGLang에서 FlashAttention 활용

SGLang은 FA3를 기본으로 사용해요. 별도 설정 없이 자동으로 켜져 있어요.

# FA3 자동 적용 (기본값)
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.3-70B-Instruct \
  --port 30000

# FA 버전 확인
python -c "
import sglang
print(sglang.__version__)
# 서버 시작 로그에서 확인:
# INFO: Using FlashAttention-3 backend
"

긴 컨텍스트 서빙:

# 1M 컨텍스트 서빙 (FA 없으면 불가능)
python -m sglang.launch_server \
  --model-path Qwen/Qwen3.5-9B-Instruct \
  --context-length 1000000 \
  --port 30000

Python에서 직접 사용하기

PyTorch 2.x (자동 적용):

import torch
import torch.nn.functional as F

# PyTorch 2.x에서 F.scaled_dot_product_attention은
# 자동으로 FlashAttention 사용
q = torch.randn(1, 8, 4096, 64).cuda().half()
k = torch.randn(1, 8, 4096, 64).cuda().half()
v = torch.randn(1, 8, 4096, 64).cuda().half()

# FlashAttention 자동 사용
output = F.scaled_dot_product_attention(q, k, v)

Hugging Face Transformers:

from transformers import AutoModelForCausalLM

# FA2 명시적 사용
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    attn_implementation="flash_attention_2",  # FA2
    torch_dtype=torch.float16,
    device_map="auto"
)

# FA3 (H100만)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    attn_implementation="flash_attention_3",  # FA3
    torch_dtype=torch.float16,
    device_map="auto"
)

직접 설치:

pip install flash-attn --no-build-isolation
# FA3 (H100만)
pip install flash-attn-3 --no-build-isolation

LLM 추론 최적화 전체 그림

FlashAttention이 전체 최적화 스택의 어디에 위치하는지 정리해요.

레이어 1: 모델 레벨
- 양자화 (INT4/INT8): 모델 크기 50% 감소
- 프루닝: 불필요한 가중치 제거
- 지식 증류: 작은 모델로 압축

레이어 2: 시스템 레벨
- FlashAttention: Attention 계산 자체를 빠르게 ← 여기
- KV Cache: 이전 계산 저장 재사용
- PagedAttention: KV Cache 메모리 효율화
- Continuous Batching: 여러 요청 동시 처리

레이어 3: 애플리케이션 레벨
- 프롬프트 캐싱: 반복 프롬프트 재사용
- 컨텍스트 압축: 불필요한 토큰 제거
- 라우팅: 모델 선택 최적화

세 레이어 전부 적용하면 비용 80% 이상 절감 가능해요.


 

반응형