AI/LLM

Multi-Query Attention 설명

검정비니 2023. 10. 25. 21:43
728x90
반응형

효율적인 Inference

요약, 질의응답(Q&A), 검색 증강 생성 등 언어 작업에 효과적인 기술로 Transformer 아키텍처에 기반한 대규모 언어 모델(LLM)이 부상했다. 하지만 이러한 모델을 사용하려면 계산 비용이 매우 많이 들며, 주로 NVIDIA GPU와 같은 컴퓨팅 가속기를 통해 실행된다.

 

LLM에 대한 입력과 출력은 토큰 시퀀스(예: 단어)로 표현됩니다. 긴 시퀀스(즉, 컨텍스트 창이 긴)를 처리할 수 있는 LLM을 훈련하거나 미세 조정하는 것은 활발히 발전하고 있는 분야이다. 대부분의 OSS LLM 기본 모델은 2K 컨텍스트 창으로 사전 학습된다. 문서 요약이나 컨텍스트 기반 질문 답변과 같이 점점 더 많은 사용 사례에서 LLM이 처리하는 시퀀스 길이는 수천에서 수만 개의 토큰으로 상당히 커질 수 있다. 앞으로는 대부분의 LLM 사용 사례에서 긴 시퀀스 길이가 새로운 표준이 될 것으로 예상됩니다. 그러나 긴 시퀀스는 추론 비용에도 상당한 영향을 미친다. 

 

추론을 위한 시스템 성능은 다음과 같은 몇 가지 기술을 통해 모델을 변경하지 않고도 개선할 수 있습니다:

  • 추론 프로세스의 반복 사이에 계산된 상태 저장(KV 캐싱)
  • 추론 중에 여러 시퀀스를 함께 일괄 처리하여 계산 리소스를 재사용하고, 확장하여 동시 요청을 지속적으로 일괄 처리
  • 메모리 조각화를 줄이고 배치 크기를 최대화하기 위한 메모리 할당 진행

그러나 추론 성능을 개선하는 가장 효과적인 방법은 모델 아키텍처와 시스템 아키텍처를 공동으로 설계하는 것이다. 이 글에서는 추론 연산에 필요한 메모리 공간과 메모리 대역폭을 획기적으로 줄여주는 다중 쿼리 주의(MQA)라는 공동 기법을 중점적으로 다룬다.

공간 절약은 토큰 수에 비례하므로 긴 시퀀스에 특히 유용하다. MQA를 최적화하면 벤치마크에서 MQA를 사용하지 않고 공개적으로 사용 가능한 최상의 기준선에 비해 처리량이 11배 향상되고 지연 시간이 30% 감소할 수 있다.

 

Multi-Head Attention

Transformer 논문에서 처음 제기된 Multi-Head Attention은 입력을 Query, Key, Value로 각각 투영해서 계산되는 Self Attention을 병렬성을 높인 버전이다.

여기서 h는 연산에서 "헤드"의 수를 나타내고, S와 L은 각각 입력 및 출력 시퀀스 길이를 나타내며, d_k는 모델 아키텍처의 숨겨진 차원을 나타낸다.

아래는 Multi-Head Attention의 파이썬 예시 코드이다.

Q = torch.randn(N, h, S, d_k)
K = torch.randn(N, h, L, d_k)
V = torch.randn(N, h, L, d_k)


# <...>

logits = torch.matmul(Q, K.transpose(2, 3)) # Output shape [N, h, S, L]
softmax_out = torch.softmax(logits / math.sqrt(d_k), dim=-1) # Output shape [N, h, S, L]
attn_out = torch.matmul(softmax_out, V) # Output shape [N, h, S, d_k]

두 가지 시퀀스 길이가 있는데, 하나는 Q 값에 적용되는 길이이고 다른 하나는 K와 V 값 모두에 적용되는 길이이다.

추론하는 동안 일반적으로 네트워크에 한 번에 하나의 토큰씩 값을 점진적으로 공급하고(즉, S = 1) 지금까지 표시된 토큰에 대해 K와 V를 계산하는 증분 생성을 사용하게 된다(즉, 생성이 진행됨에 따라 L이 증가).

결과적으로 출력 시퀀스가 생성됨에 따라 K와 V는 점진적으로 증가하며, 일반적인 최적화 기법은 반복에 걸쳐 변경 가능한 KV 캐시를 사용하는 것이다. 그러면 멀티헤드 어텐션의 내부 루프는 다음과 같이 진행된다:

# Cached K and V values across iterations
K = torch.randn(N, h, ..., d_k)
V = torch.randn(N, h, ..., d_k)

# Single-step QKV values computed during sequence generation
Q_incr = torch.randn(N, h, 1, d_k)
K_incr = torch.randn(N, h, 1, d_k)
V_incr = torch.randn(N, h, 1, d_k)

# <...>

# Update KV-cache
K = torch.cat([K, K_incr], dim=-2)
V = torch.cat([V, V_incr], dim=-2)

# Compute attention (L is sequence length so far)
logits = torch.matmul(Q_incr, K.transpose(2, 3)) # Output shape [N, h, 1, L]
softmax_out = torch.softmax(logits / math.sqrt(d_k), dim=-1) # Output shape [N, h, 1, L]
attn_out = torch.matmul(softmax_out, V) # Output shape [N, h, 1, d_k]

 

Multi-Query Attention

https://arxiv.org/abs/1911.02150 논문에서 처음 제기된 멀티 쿼리 어텐션은 위의 멀티헤드 어텐션에서 쿼리만 h개의 헤드를 가지게 하고, 나머지 k와 v는 단일 헤드로 진행되도록 하는 구조이다.

# Cached K and V values across iterations
K = torch.randn(N, ..., d_k)
V = torch.randn(N, ..., d_k)

# Single-step QKV values computed during sequence generation
Q_incr = torch.randn(N, h, 1, d_k)
K_incr = torch.randn(N, 1, d_k)
V_incr = torch.randn(N, 1, d_k)

# <...>

# Update KV-cache
K = torch.cat([K, K_incr], dim=-2)
V = torch.cat([V, V_incr], dim=-2)

# Compute attention (L is sequence length so far)
# NB: K is broadcasted (repeated) out across Q's `h` dimension!
logits = torch.matmul(Q_incr, K.transpose(2, 3)) # Output shape [N, h, 1, L]
softmax_out = torch.softmax(logits / math.sqrt(d_k), dim=-1) # Output shape [N, h, 1, L]
# NB: V is broadcasted (repeated) out across softmax_out's `h` dimension!
attn_out = torch.matmul(softmax_out, V) # Output shape [N, h, 1, d_k]

이 멀티 쿼리 어텐션의 장점은 더 적은 연산량과 더 적은 메모리를 요구한다는 점이다.

즉, 더 빠른 inference와 더 적은 GPU 메모리가 요구되어진다는 것이다.

 

 

실제로 2023년 10월 현재 오픈소스 LLM 모델들 중 가장 뛰어난 성능을 보이고 있는 Zephyr-7b-alpha, Mistral-7B, Llama2-70B 모두 이 멀티 쿼리 어텐션을 통해 더 빠르고 가벼운 모델을 만들면서 더 큰 모델인 GPT-3.5보다 더 뛰어난 성능을 이루는데 성공하였다.

 

하드웨어적 효율성이 더 뛰어나면서, LLM 학습 시 기존 멀티헤드 어텐션 기반 모델과 비슷하거나 그 이상의 성능을 낼 수도 있기 때문에 충분히 가치가 있는 방식이라고 볼 수 있다.

 

반응형

'AI > LLM' 카테고리의 다른 글

LLM 연구의 주요 과제들  (0) 2023.10.14