Dev.log

Sentence BERT 본문

자연어처리

Sentence BERT

포켓몬빵 2022. 5. 11. 22:53

SBERT

SBERT(Sentence BERT)는 BERT의 임베딩 성능을 향상시킨 모델입니다. BERT로 부터 문장벡터를 얻을때는 BERT의 [CLS] 토큰의 출력 벡터를 문장 벡터로 간주하거나 각 task에 맞춰 모든 단어의 의미를 반영할건지 중요한 단어의 의미를 반영할건지에 따라 각각 모든 단어의 출력 벡터에 대해서 average pooling 과 max pooling을 수행하여 문장 벡터로 얻을 수 있습니다.

 

SBERT는 이와 같은 BERT의 문장 임베딩을 응용하여 BERT에 fine tunning을 진행합니다. SBERT는 크게 2가지 방법으로 학습이 진행된다고 할 수 있는데, NLI(Natural Language Inferencing) 문제와 같은 문장 쌍분류 테스크를 통해 fine tunning을 진행할 수 있고, STS(Semantic Textual Similarity)문제와 같이  문장 쌍으로 회귀 문제를 푸는 방식으로 Fine tuning을 진행 할 수 도 있습니다. 먼저 NLI(Natural Language Inferencing) Task의 경우 아래와 같이 두 개의 문장이 주어졌을때 수반 관계인지, 모순관계인지를 맞춰야 한다고 합시다.

 

문장 A 문장 B Label
A lady sits on a bench that is against a shopping mall. A person sits on the seat. Entailment
... ... ...

 

SBERT는 이런 NLI 데이터를 학습하기 위해  문장 A와 문장 B 각각을 BERT의 입력으로 넣고, average pooling 또는 max pooling을 통해서 각각에 대한 문장 임베딩 벡터를 얻습니다. 여기서 이를 각각 u와 v라고 할때, u벡터와 v벡터의 차이 벡터를 구합니다. 이 벡터는 수식으로 표현하면 |u-v|입니다. 그리고 이 세 가지 벡터를 연결(concatenation)합니다. 세미콜론(;)을 연결 기호로 한다면 연결된 벡터의 수식은 아래와 같이 나타 낼 수 있습니다.

만약 BERT의 문장 임베딩 벡터가 n 차원이라면 세 개의 벡터를 연결한 벡터 h의 차원은 3n이 됩니다. 그리고 이 벡터를 출력층으로 보내 Multi-class classification task를 풀도록 합니다. 즉, 분류하고자 하는 클래스의 개수가 k라면, 가중치 행렬 3n × k의 크기를 가지는 행렬 Wy을 곱한 후에 소프트맥스 함수를 통과시킵니다. 이를 수식으로 표현하면 아래와 같습니다.

SBERT를 통해 두 개의 문장으로부터 의미적 유사성을 구하는 STS(Semantic Textual Similarity)와 같은 Task를 학습할때, SBERT는 문장 A와 문장 B 각각을 BERT의 입력으로 넣고, average pooling 또는 max pooling을 통해서 각각에 대한 문장 임베딩 벡터를 얻습니다.

 

이를 각각 u와 v라고 하였을 때 이 두 벡터의 코사인 유사도를 구한뒤 해당 유사도와 레이블 유사도와의 평균 제곱 오차(Mean Squared Error, MSE)를 최소화하는 방식으로 학습합니다. 코사인 유사도의 값의 범위는 -1과 1사이이고 만일 레이블된 값의 범위가 이보 다 크면, 각 레이블된 값의 법위에 따라 학습전 레이블들의 값을 각 기준에 맞춰 나누어 범위를 줄인뒤 학습을 진행합니다.

Comments