반응형
LLM이 긴 문서를 처리할 때 왜 느려질까요?
GPT-3 컨텍스트: 2,048 토큰
GPT-4 컨텍스트: 128,000 토큰
Llama 3: 1,000,000 토큰
2년 만에 500배 늘어났어요.
이게 가능해진 핵심 기술이 FlashAttention이에요.
Attention이 뭔가
LLM은 텍스트를 읽을 때 모든 토큰이 다른 모든 토큰과 얼마나 관련있는지 계산해요.
입력: "나는 사과를 먹었다"
각 토큰이 다른 토큰과의 관계 점수 계산:
"나는" ↔ "사과를": 0.3
"나는" ↔ "먹었다": 0.8
"사과를" ↔ "먹었다": 0.7
...
수식으로는 이래요.
Attention(Q, K, V) = softmax(QK^T / √d_k) × V
Q(Query): "지금 처리 중인 토큰"
K(Key): "비교할 모든 토큰"
V(Value): "실제 정보"
QK^T: 모든 토큰 쌍의 유사도 행렬
softmax: 유사도를 확률로 변환
× V: 유사도에 따라 정보 가중합산
문제 — Attention은 메모리 병목이다
GPU는 두 종류의 메모리가 있어요.
HBM (High Bandwidth Memory, 느린 메모리):
→ 용량 큼 (H100: 80GB)
→ 속도 느림 (대역폭 3.35TB/s)
→ 모델 가중치, KV Cache 저장
SRAM (빠른 메모리, L2 Cache):
→ 용량 매우 작음 (H100: 50MB)
→ 속도 매우 빠름 (대역폭 ~33TB/s, 10배 빠름)
→ 현재 계산 중인 데이터 저장
기존 Attention 계산 과정이에요.
1. Q, K를 HBM에서 SRAM으로 읽기
2. QK^T 행렬 계산 (N×N 크기)
3. 결과를 HBM에 저장 ← 느린 메모리에 씀!
4. HBM에서 다시 읽어서 Softmax 계산
5. 결과를 HBM에 저장 ← 또 씀!
6. HBM에서 읽어서 × V 계산
7. 최종 결과 HBM에 저장
HBM ↔ SRAM 왔다갔다를 6번이나 해요.
진짜 문제:
토큰 수 N이 늘어날수록:
중간 행렬(QK^T) 크기: N × N
→ 10K 토큰: 10,000 × 10,000 = 1억 개 값 저장
→ 100K 토큰: 100억 개 값 저장 → SRAM에 못 들어감!
→ HBM에 써야 함 → 엄청나게 느려짐
메모리 사용량이 O(N²)으로 늘어요. 컨텍스트 길이에 제곱으로 비례해서 느려져요.
FlashAttention의 핵심 아이디어 — Tiling
중간 행렬을 HBM에 저장하지 않고 SRAM 안에서 처리하는 거예요.
기본 아이디어:
전체 행렬을 한번에 계산하는 대신
작은 블록(타일) 단위로 나눠서 SRAM 안에서 처리
예시로 설명하면 이래요.
기존 방식 (10페이지 문서 요약):
1. 10페이지 전체를 화이트보드에 복사
(HBM에 N×N 행렬 저장)
2. 모든 문장 쌍을 비교
3. 요약 작성
FlashAttention 방식:
1. 1페이지씩 쪽지(SRAM)에 메모
2. 비교하고 중간 결과 업데이트
3. 쪽지 내용 지우고 다음 페이지
4. 화이트보드(HBM) 사용 최소화
실제 알고리즘:
for 블록 i in Q: # Q를 블록으로 나눔
for 블록 j in K, V: # K, V를 블록으로 나눔
# SRAM에서 계산
S_ij = Q_i × K_j^T # 유사도
P_ij = softmax(S_ij) # 확률
O_i += P_ij × V_j # 가중합
# HBM 접근: 처음 Q,K,V 읽기 + 최종 O 쓰기만!
# 중간 행렬 HBM 저장 없음
이게 되려면 Softmax를 블록 단위로 계산해야 하는데 이게 수학적으로 쉽지 않아요.
Online Softmax가 이걸 해결했어요.
일반 Softmax:
전체 값을 다 봐야 분모 계산 가능
→ N×N 행렬 전체 필요
Online Softmax:
블록 하나 볼 때마다 최댓값 추적
이전 블록 결과를 새 정보로 점진적 업데이트
→ 전체 행렬 없이도 정확한 결과 가능!
FlashAttention 버전별 발전
FlashAttention 1 (2022)
개선:
- HBM 접근 횟수 대폭 감소
- 메모리: O(N²) → O(N)
- 속도: 2~4배 향상
- 정확도: 기존과 동일 (근사 아님)
한계:
- GPU 사용률 35% 수준
- 배치 크기와 헤드 수 기준으로만 병렬화
FlashAttention 2 (2023)
개선:
- 시퀀스 길이 방향으로도 병렬화
- GPU 사용률 70%로 향상
- FA1 대비 2배 빠름
특히 좋은 점:
- 추론에서 배치 크기 작고 시퀀스 길 때 효율적
- Ampere (A100), Ada (RTX 4090) GPU 최적화
FlashAttention 3 (2024)
개선:
- H100 GPU 특성 완전 활용
- Producer-Consumer 비동기 처리
- Tensor Core와 데이터 이동 겹쳐서 실행
- FP8 지원으로 정확도 유지하면서 더 빠름
결과:
FA2 대비 1.5~2.0배 빠름
H100에서 740 TFLOPS (이론 최대의 75%)
FP8: 1.2 PFLOPS
실제로 얼마나 빨라지나
컨텍스트 길이별 속도 향상 (A100 기준):
2K 토큰: FA 2.0x 빠름
8K 토큰: FA 4.0x 빠름
32K 토큰: FA 6.0x 빠름
128K 토큰: FA 기본 Attention 불가능 → FA로만 가능
메모리:
2K: 기본 1.6GB → FA 0.2GB
8K: 기본 25GB → FA 0.8GB (기본은 OOM!)
32K: 기본 400GB → FA 3GB
FlashAttention 없으면 긴 컨텍스트 자체가 불가능해요.
SGLang에서 FlashAttention 활용
SGLang은 FA3를 기본으로 사용해요. 별도 설정 없이 자동으로 켜져 있어요.
# FA3 자동 적용 (기본값)
python -m sglang.launch_server \
--model-path meta-llama/Llama-3.3-70B-Instruct \
--port 30000
# FA 버전 확인
python -c "
import sglang
print(sglang.__version__)
# 서버 시작 로그에서 확인:
# INFO: Using FlashAttention-3 backend
"
긴 컨텍스트 서빙:
# 1M 컨텍스트 서빙 (FA 없으면 불가능)
python -m sglang.launch_server \
--model-path Qwen/Qwen3.5-9B-Instruct \
--context-length 1000000 \
--port 30000
Python에서 직접 사용하기
PyTorch 2.x (자동 적용):
import torch
import torch.nn.functional as F
# PyTorch 2.x에서 F.scaled_dot_product_attention은
# 자동으로 FlashAttention 사용
q = torch.randn(1, 8, 4096, 64).cuda().half()
k = torch.randn(1, 8, 4096, 64).cuda().half()
v = torch.randn(1, 8, 4096, 64).cuda().half()
# FlashAttention 자동 사용
output = F.scaled_dot_product_attention(q, k, v)
Hugging Face Transformers:
from transformers import AutoModelForCausalLM
# FA2 명시적 사용
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.3-70B-Instruct",
attn_implementation="flash_attention_2", # FA2
torch_dtype=torch.float16,
device_map="auto"
)
# FA3 (H100만)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.3-70B-Instruct",
attn_implementation="flash_attention_3", # FA3
torch_dtype=torch.float16,
device_map="auto"
)
직접 설치:
pip install flash-attn --no-build-isolation
# FA3 (H100만)
pip install flash-attn-3 --no-build-isolation
LLM 추론 최적화 전체 그림
FlashAttention이 전체 최적화 스택의 어디에 위치하는지 정리해요.
레이어 1: 모델 레벨
- 양자화 (INT4/INT8): 모델 크기 50% 감소
- 프루닝: 불필요한 가중치 제거
- 지식 증류: 작은 모델로 압축
레이어 2: 시스템 레벨
- FlashAttention: Attention 계산 자체를 빠르게 ← 여기
- KV Cache: 이전 계산 저장 재사용
- PagedAttention: KV Cache 메모리 효율화
- Continuous Batching: 여러 요청 동시 처리
레이어 3: 애플리케이션 레벨
- 프롬프트 캐싱: 반복 프롬프트 재사용
- 컨텍스트 압축: 불필요한 토큰 제거
- 라우팅: 모델 선택 최적화
세 레이어 전부 적용하면 비용 80% 이상 절감 가능해요.
반응형
'LLM' 카테고리의 다른 글
| SGLang B300 GPU (SM103)에서 Qwen3.5 서빙 — Attention Backend (0) | 2026.04.15 |
|---|---|
| SGLang Attention Backend 완전 비교 — Triton, FlashInfer, FA3, TRTLLM (0) | 2026.04.15 |
| vLLM, SGLang이 빠른 이유 — Continuous Batching 원리와 실전 (0) | 2026.04.15 |
| SLM 실전 가이드 — Gemma 4, Qwen3.5, Phi-4로 API 비용 95% 줄이는 법 (1) | 2026.04.15 |
| Qwen 3.5 완전 분석 — 397B 파라미터인데 왜 저렴하고 빠른가 (0) | 2026.04.15 |