AI/Transformers

HuggingFace Transformers에서 Zero-shot Classification의 원리

검정비니 2024. 3. 6. 12:46
728x90
반응형

HuggingFace Transformers 라이브러리에서는 다양한 파이프라인들을 제공한다.

그 중 하나로 ZeroShotClassificationPipeline 이라는 것이 있는데, 미리 학습되지 않은 라벨들에 대해서 분류를 하는 기법이다.

 

알다시피, 모델 학습에는 다양한 공수와 리소스가 소요되기 때문에 현업에서는 필요에 따라 ZeroShot 혹은 FewShot 모델들을 잘 활용해야 하는 경우가 많이 있다.

 

GPT-4와 같은 생성형 모델에서는 LLM의 성능을 통해 이러한 Zero Shot Classification을 수행하게 되는데, BERT와 같은 임베딩 모델 기반 환경에서는 Zero Shot Classification을 어떻게 다루게 될까?

 

허깅페이스 트랜스포머 소스코드 (transformers/src/transformers/pipelines/zero_shot_classification.py)를 살펴보면서 그 원리를 이해해보고자 한다.

 

기본적으로 전처리를 위한 preprocess 메소드와 토크나이징을 수행하는 _parse_and_tokenize 메소드는 아래와 같다:

    def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
        sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)

        for i, (candidate_label, sequence_pair) in enumerate(zip(candidate_labels, sequence_pairs)):
            model_input = self._parse_and_tokenize([sequence_pair])

            yield {
                "candidate_label": candidate_label,
                "sequence": sequences[0],
                "is_last": i == len(candidate_labels) - 1,
                **model_input,
            }


	...


    def _parse_and_tokenize(
        self, sequence_pairs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.ONLY_FIRST, **kwargs
    ):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
        return_tensors = self.framework
        if self.tokenizer.pad_token is None:
            # Override for tokenizers not supporting padding
            logger.error(
                "Tokenizer was not supporting padding necessary for zero-shot, attempting to use "
                " `pad_token=eos_token`"
            )
            self.tokenizer.pad_token = self.tokenizer.eos_token
        try:
            inputs = self.tokenizer(
                sequence_pairs,
                add_special_tokens=add_special_tokens,
                return_tensors=return_tensors,
                padding=padding,
                truncation=truncation,
            )
        except Exception as e:
            if "too short" in str(e):
                # tokenizers might yell that we want to truncate
                # to a value that is not even reached by the input.
                # In that case we don't want to truncate.
                # It seems there's not a really better way to catch that
                # exception.

                inputs = self.tokenizer(
                    sequence_pairs,
                    add_special_tokens=add_special_tokens,
                    return_tensors=return_tensors,
                    padding=padding,
                    truncation=TruncationStrategy.DO_NOT_TRUNCATE,
                )
            else:
                raise e

        return inputs

 

우선 전처리 과정에서는 입력 텍스트와 대상 라벨들을 sequence_pairs라는 형태로 변환한 후, _parse_and_tokenize 메소드에서 토크나이징을 거치게 된다. 이렇게 문장 pair 한번에 배치 형태로 토크나이징하는 형식을 진행하게 되면 tokenizer는 두 문장을 하나의 긴 토큰 배열로 생성하면서, 각 문장의 구분을 위해 sentence mask들을 생성하게 된다.

이 sentence mask는 간단히 설명하자면 [0, 0, 0, 0, 0, 1, 1, 1, 1]과 같은 포맷으로 생성되며, 0에 해당되는 토큰들이 첫번째 문장에 해당하고 1에 해당하는 토큰들이 두번째 문장에 해당하게 된다. (위의 값은 순전히 예시를 위한 값으로, 실제 값은 다르게 나올 수 있다)

 

아래 스크린샷은 2개의 sequence_pair에 대해서 토크나이징한 결과다.

tokenize the sequence_pairs

보다시피, "token_type_ids"라는 항목을 보면 두 문장을 구분하기 위해서 token_type_id를 0과 1로 구분하게 된다 (0이 첫번째 문장, 1이 두번째 문장의 토큰임을 나타낸다).

이와 같은 방식으로 Zero Shot Classification을 위해 단일 input text를 여러개의 paired text로 생성하고, 각각을 tokenizing한 뒤, 배치 형식으로 추론을 하게 되는 것이다.

 

그 후, _forward 메소드를 보게 되면 이 토크나이징된 input 데이터를 model에게 넘겨서 순전파(forward propagation) 과정을 진행하게 된다.

    def _forward(self, inputs):
        candidate_label = inputs["candidate_label"]
        sequence = inputs["sequence"]
        model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
        # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
        model_forward = self.model.forward if self.framework == "pt" else self.model.call
        if "use_cache" in inspect.signature(model_forward).parameters.keys():
            model_inputs["use_cache"] = False
        outputs = self.model(**model_inputs)

        model_outputs = {
            "candidate_label": candidate_label,
            "sequence": sequence,
            "is_last": inputs["is_last"],
            **outputs,
        }
        return model_outputs

 

 

대표적인 언어 임베딩 모델들 중 하나인 BERT의 경우, 이런식으로 2개의 문장을 한번에 입력받아서 두 문장이 서로 이어지는 문장인지에 대해서 평가하는 "Next Sentence Prediction"을 과정을 학습하기도 하는데, 이 Zero shot classification은 이 방법을 응용한 것이다.

 

각 대상 라벨과 텍스트를 연속되는 문장의 형태로 만든 뒤, 어느 라벨이 텍스트의 다음 문장으로서 가장 자연스러운지에 대한 평가를 마치 Next Sentece Prediction을 수행하는 것처럼 진행하는 것이다.

 

이후 후처리를 담당하는 postprocess 메소드를 보면 각 sentence_pair에 대한 결과 logit에 대해 softmax를 적용해서 결과값을 로짓이 아닌 확률값으로 변경하는 것을 볼 수가 있다.

    def postprocess(self, model_outputs, multi_label=False):
        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
        sequences = [outputs["sequence"] for outputs in model_outputs]
        logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
        N = logits.shape[0]
        n = len(candidate_labels)
        num_sequences = N // n
        reshaped_outputs = logits.reshape((num_sequences, n, -1))

        if multi_label or len(candidate_labels) == 1:
            # softmax over the entailment vs. contradiction dim for each label independently
            entailment_id = self.entailment_id
            contradiction_id = -1 if entailment_id == 0 else 0
            entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
            scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
            scores = scores[..., 1]
        else:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., self.entailment_id]
            scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)

        top_inds = list(reversed(scores[0].argsort()))
        return {
            "sequence": sequences[0],
            "labels": [candidate_labels[i] for i in top_inds],
            "scores": scores[0, top_inds].tolist(),
        }

 

참고자료:

https://huggingface.co/tasks/zero-shot-classification

https://github.com/huggingface/transformers

 

반응형

'AI > Transformers' 카테고리의 다른 글

Transformers Decoder의 "past_key_values"에 대하여  (0) 2023.10.26