AI/PyTorch

텐서에서 Top-K 결과를 받아오는 방법

검정비니 2023. 10. 26. 17:29
728x90
반응형

최근 RAG(Retrieval Augmented Generation)과 같은 반환 기반의 솔루션들이 많이 사용되면서 정보 반환의 중요성이 더 강조되어지고 있다.

물론 벡터 데이터베이스 등의 다른 솔루션을 사용해서 문제를 해결할 수도 있으나, 상황에 따라 상위 K개의 결과를 반환하는 알고리즘 및 기능을 직접 구현해야 하기도 할 것이다.

FAISS나 ChromaDB도 좋은 툴이지만, 개인적으로는 torch.topk() 메소드를 사용하는 유사도 반환 방식 역시 알아둘 필요가 있다고 본다.

 

가장 큰 이유는, FAISS나 ChromaDB 등의 솔루션의 경우에는 상당한 양의 메모리를 소모한다.

더 빠른 결과의 반환을 위해 인덱싱을 하다보니 그만큼 많은 메모리를 사용하게 된다.

이는 다시 말해, 물리적 한계로 인해 유사도 계산을 할 수 있는 양이 한정되어 있다는 뜻이 된다.

 

캐글이나 데이콘 같은 컴페티션 등에서 유사도 기반의 RAG 솔루션을 구축할 때 더 많은 데이터를 사용하고 싶지만 FAISS 등의 메모리 한계로 인해 좌절되는 경험을 회피할 수 있는 가장 최적의 방법 중 하나가 바로 이 topk 메소드 기반 유사도 탐색을 직접 구현하는 것이다.

실제로 2023년 가을에 진행된 "Kaggle LLM Exam"이라는 대회에서 독보적으로 1등을 차지한 팀에서 사용한 RAG 기반 솔루션에서는 이 torch.topk를 활용한 커스텀 유사도 탐색 기법을 활용한다.

  # values에는 값, indexs에는 인덱스 값이 반환됨
  values, indexs = torch.topk(predict, k=k, dim=-1)

위의 예시 코드에서 사용하는 torch.topk() 메소드는 argmax와 같이 "최대 값"을 찾는 것이 아니라 가장 큰 k개의 값을 찾도록 도와준다.

이를 cosine 유사도 등과 병합하면 가장 높은 k개의 코사인 유사도를 가지는 청크를 찾는 retrieval 메소드를 구현할 수 있다.

 

아래는 예시 코드이다:

def cos_similarity_matrix(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8):
    """Calculates cosine similarities between tensor a and b."""
    sim_mt = torch.mm(a, b.transpose(0, 1))
    return sim_mt


def get_topk(embeddings_from, embeddings_to, topk=1000, bs=512):
    chunk = bs
    embeddings_chunks = embeddings_from.split(chunk)

    vals = []
    inds = []
    for idx in tqdm(range(len(embeddings_chunks))):
        cos_sim_chunk = cos_similarity_matrix(embeddings_chunks[idx].to(embeddings_to.device).half(), embeddings_to).float()

        cos_sim_chunk = torch.nan_to_num(cos_sim_chunk, nan=0.0)

        topk = min(topk, cos_sim_chunk.size(1))
        vals_chunk, inds_chunk = torch.topk(cos_sim_chunk, k=topk, dim=1)
        vals.append(vals_chunk[:, :].detach().cpu())
        inds.append(inds_chunk[:, :].detach().cpu())

        del vals_chunk
        del inds_chunk
        del cos_sim_chunk

    vals = torch.cat(vals).detach().cpu()
    inds = torch.cat(inds).detach().cpu()

    return inds, vals

코드 출처: https://www.kaggle.com/code/ybabakhin/1st-place-team-h2o-llm-studio

 

1st place. Team H2O LLM Studio

Explore and run machine learning code with Kaggle Notebooks | Using data from multiple data sources

www.kaggle.com

 

반응형