최근 RAG(Retrieval Augmented Generation)과 같은 반환 기반의 솔루션들이 많이 사용되면서 정보 반환의 중요성이 더 강조되어지고 있다.
물론 벡터 데이터베이스 등의 다른 솔루션을 사용해서 문제를 해결할 수도 있으나, 상황에 따라 상위 K개의 결과를 반환하는 알고리즘 및 기능을 직접 구현해야 하기도 할 것이다.
FAISS나 ChromaDB도 좋은 툴이지만, 개인적으로는 torch.topk() 메소드를 사용하는 유사도 반환 방식 역시 알아둘 필요가 있다고 본다.
가장 큰 이유는, FAISS나 ChromaDB 등의 솔루션의 경우에는 상당한 양의 메모리를 소모한다.
더 빠른 결과의 반환을 위해 인덱싱을 하다보니 그만큼 많은 메모리를 사용하게 된다.
이는 다시 말해, 물리적 한계로 인해 유사도 계산을 할 수 있는 양이 한정되어 있다는 뜻이 된다.
캐글이나 데이콘 같은 컴페티션 등에서 유사도 기반의 RAG 솔루션을 구축할 때 더 많은 데이터를 사용하고 싶지만 FAISS 등의 메모리 한계로 인해 좌절되는 경험을 회피할 수 있는 가장 최적의 방법 중 하나가 바로 이 topk 메소드 기반 유사도 탐색을 직접 구현하는 것이다.
실제로 2023년 가을에 진행된 "Kaggle LLM Exam"이라는 대회에서 독보적으로 1등을 차지한 팀에서 사용한 RAG 기반 솔루션에서는 이 torch.topk를 활용한 커스텀 유사도 탐색 기법을 활용한다.
# values에는 값, indexs에는 인덱스 값이 반환됨
values, indexs = torch.topk(predict, k=k, dim=-1)
위의 예시 코드에서 사용하는 torch.topk() 메소드는 argmax와 같이 "최대 값"을 찾는 것이 아니라 가장 큰 k개의 값을 찾도록 도와준다.
이를 cosine 유사도 등과 병합하면 가장 높은 k개의 코사인 유사도를 가지는 청크를 찾는 retrieval 메소드를 구현할 수 있다.
아래는 예시 코드이다:
def cos_similarity_matrix(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
"""Calculates cosine similarities between tensor a and b."""
sim_mt = torch.mm(a, b.transpose(0, 1))
return sim_mt
def get_topk(embeddings_from, embeddings_to, topk=1000, bs=512):
chunk = bs
embeddings_chunks = embeddings_from.split(chunk)
vals = []
inds = []
for idx in tqdm(range(len(embeddings_chunks))):
cos_sim_chunk = cos_similarity_matrix(embeddings_chunks[idx].to(embeddings_to.device).half(), embeddings_to).float()
cos_sim_chunk = torch.nan_to_num(cos_sim_chunk, nan=0.0)
topk = min(topk, cos_sim_chunk.size(1))
vals_chunk, inds_chunk = torch.topk(cos_sim_chunk, k=topk, dim=1)
vals.append(vals_chunk[:, :].detach().cpu())
inds.append(inds_chunk[:, :].detach().cpu())
del vals_chunk
del inds_chunk
del cos_sim_chunk
vals = torch.cat(vals).detach().cpu()
inds = torch.cat(inds).detach().cpu()
return inds, vals
코드 출처: https://www.kaggle.com/code/ybabakhin/1st-place-team-h2o-llm-studio
1st place. Team H2O LLM Studio
Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources
www.kaggle.com
'AI > PyTorch' 카테고리의 다른 글
cuDNN benchmark 활성화를 통한 최적의 알고리즘 선택 (0) | 2023.10.14 |
---|---|
PyTorch에서 이미지 데이터에 대해 normalize를 할 때, mean=[0.485, 0.456, 0.406]과 std=[0.229, 0.224, 0.225]를 쓰는 이유는? (2) | 2023.10.09 |
Softmax 결과의 총합이 1 이하로 나오는 경우 (0) | 2022.07.18 |
Torch.mm과 Torch.matmul 차이점 (0) | 2022.03.16 |