
[ Natural Language Processing ] BERT

예진또이(애덤스미스 아님) 2023. 9. 12. 02:24

1. The BERT Model

  • A language model based on the Transformer encoder, first proposed in a 2018 Google paper
  • Pre-trained on unlabeled data, then fine-tuned for a specific downstream task
  • Emphasizes deep bidirectionality as the key point that differentiates it from earlier models
  • Achieves SOTA on 11 major NLP tasks by adding only a single output layer on top of the pre-trained BERT model
 

1-1. BERT Model Overview

  • Pre-training of language models had been studied extensively before BERT and was already producing strong results
  • These approaches stood out on sentence-level tasks, which aim to predict the relationship between two sentences by analyzing them as a whole
  • They also performed well on token-level tasks (named entity recognition, QA, etc.)
  • Two ways to apply pre-trained representations to downstream tasks (a small sketch contrasting them follows this list)
    • feature-based approach
      • ELMo is the representative example
      • Uses the pre-trained representations as additional features for a task-specific model
    • fine-tuning approach
      • Introduces a minimal number of task-specific parameters and learns the downstream task by fine-tuning all of the pre-trained parameters
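To make the contrast concrete, here is a minimal sketch of the two transfer styles. The encoder and classifier below are toy stand-ins (plain linear layers with hypothetical names), not a real pre-trained model; the only point is which parameters the optimizer is allowed to update.

import torch
import torch.nn as nn

# Toy stand-ins for a pre-trained encoder and a task-specific head (illustration only)
encoder = nn.Linear(768, 768)
classifier = nn.Linear(768, 2)

# Feature-based (ELMo-style): freeze the pre-trained representations and
# train only the task-specific parameters on top of them.
for p in encoder.parameters():
    p.requires_grad = False
feature_based_optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

# Fine-tuning (BERT-style): every pre-trained parameter is updated together
# with the small task-specific head.
for p in encoder.parameters():
    p.requires_grad = True
fine_tuning_optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(classifier.parameters()), lr=2e-5
)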

Conventionally, unidirectional language models were used to learn language representations; the BERT authors argue that this unidirectionality limits the power of the pre-trained representations.

 

1-2. BERT Model Architecture

  • Consists of a pre-training part and a fine-tuning part
  • During pre-training, the parameters are learned from the unlabeled data of various pre-training tasks; the pre-trained model is then fine-tuned on the labeled data of each downstream task
  • A stack of bidirectional Transformer encoders (multi-layer bidirectional Transformer encoder)
    • BERT base: 110M (about 110 million) parameters
    • BERT large: 340M (about 340 million) parameters
  • BERT base was sized to match OpenAI's GPT in parameter count so that the two could be compared directly (a quick check is sketched below)
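As a quick sanity check of the 110M figure, one can sum the parameter count of the public bert-base-uncased checkpoint. This is just a sketch: it assumes the Hugging Face transformers library (installed later in this post) and downloads the weights on first run.

from transformers import BertModel

# Load the public BERT-base checkpoint and count its parameters
bert_base = BertModel.from_pretrained('bert-base-uncased')
num_params = sum(p.numel() for p in bert_base.parameters())
print(f'{num_params:,}')  # around 110 million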

1-3. BERT Input/Output

  • The input is the sum of three embedding vectors: Token, Segment, and Position embeddings
  • The first token of every input sequence is the [CLS] token; the final hidden state corresponding to [CLS] aggregates the representation of the whole sequence and is used for classification tasks
  • An input sequence can be a pair of sentences; the two sentences are separated by a [SEP] token, and a segment embedding marks whether each token belongs to sentence A or sentence B
  • Token Embeddings use WordPiece embeddings
  • The input representation of each token is the element-wise sum of its corresponding token, segment, and position embeddings (see the sketch below)
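A minimal sketch of how the three embeddings are combined, with sizes matching BERT base (30,522 is the WordPiece vocabulary size). The real model additionally applies LayerNorm and dropout, as the KoBERT printout later in this post shows.

import torch
import torch.nn as nn

vocab_size, max_len, hidden = 30522, 512, 768
token_emb = nn.Embedding(vocab_size, hidden)     # WordPiece token embeddings
segment_emb = nn.Embedding(2, hidden)            # sentence A / sentence B
position_emb = nn.Embedding(max_len, hidden)     # position embeddings

input_ids = torch.tensor([[101, 7592, 2088, 102]])            # [CLS] ... [SEP] (illustrative ids)
token_type_ids = torch.tensor([[0, 0, 0, 0]])                 # all tokens belong to sentence A
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)   # 0, 1, 2, 3

# The input representation is the element-wise sum of the three embeddings
input_repr = token_emb(input_ids) + segment_emb(token_type_ids) + position_emb(position_ids)
print(input_repr.shape)  # torch.Size([1, 4, 768])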

1-4. BERT Pre-training

  • MLM: Masked Language Modeling
    • A fixed proportion of the input tokens is masked, and the model is trained to predict the masked tokens
    • Some of the input word tokens are replaced with the [MASK] token during training
    • This creates a mismatch between pre-training and fine-tuning, because the [MASK] token never appears during fine-tuning
    • To mitigate this, the selected tokens receive additional processing (sketched after this list)
      • 80% of the time: replace the token with the [MASK] token
      • 10% of the time: replace the token with a random word
      • 10% of the time: keep the original token unchanged
  • NSP: Next Sentence Prediction
    • Many NLP downstream tasks (QA, NLI, etc.) hinge on understanding the relationship between two sentences
    • When choosing sentences A and B, 50% of the time B is the sentence that actually follows A, and the other 50% of the time B is a random sentence from the corpus
  • Pre-training requires a large amount of data; the corpus was built from BooksCorpus (about 800M words) and English Wikipedia (about 2,500M words)
  • Only the body text of Wikipedia was used; lists, tables, and headers were ignored
  • A document-level corpus is far better suited than a sentence-level corpus for extracting long contiguous sequences
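A rough sketch of how the two pre-training objectives build their training examples. This is plain illustrative Python over made-up token lists, not the actual BERT data pipeline.

import random

def mask_tokens(tokens, vocab, mask_prob=0.15):
    # MLM: select roughly 15% of the tokens, then apply the 80/10/10 rule to each selected token
    masked, labels = list(tokens), {}
    for i, tok in enumerate(tokens):
        if random.random() < mask_prob:
            labels[i] = tok                        # only selected positions get a prediction target
            r = random.random()
            if r < 0.8:
                masked[i] = '[MASK]'               # 80%: replace with [MASK]
            elif r < 0.9:
                masked[i] = random.choice(vocab)   # 10%: replace with a random word
            # else: 10%: keep the original token unchanged
    return masked, labels

def make_nsp_pair(sentences, idx):
    # NSP: 50% of the time B is the true next sentence (IsNext), otherwise a random one (NotNext)
    if random.random() < 0.5:
        return sentences[idx], sentences[idx + 1], 'IsNext'
    return sentences[idx], random.choice(sentences), 'NotNext'

print(mask_tokens(['my', 'dog', 'is', 'hairy'], vocab=['cat', 'apple', 'runs']))
print(make_nsp_pair(['the man went to the store', 'he bought a gallon of milk', 'penguins cannot fly'], 0))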

2. BERT Summary

  • ELMo introduced the pre-training perspective, GPT-1 applied it to the Transformer architecture and showed that Transformers are effective for pre-training, and BERT improved on this by making the model deeply bidirectional
  • With its Deep Bidirectional Model, BERT achieved SOTA on every evaluated NLP task using the same pre-trained model
  • A pre-trained model makes it possible to reach strong performance with relatively few resources
  • The trade-off is that pre-training itself requires a great deal of time and compute
 

3. Building a Simple Answer Ranking Model

Source: https://github.com/songys/Chatbot_data

import urllib.request
import pandas as pd

urllib.request.urlretrieve("https://raw.githubusercontent.com/songys/Chatbot_data/master/ChatbotData.csv", filename='ChatBotData.csv')
-------------------------------------------------------------------------------
# Result
('ChatBotData.csv', <http.client.HTTPMessage at 0x7a82599df5b0>)

--------------------------------------------------------------------------------
train_dataset = pd.read_csv('ChatBotData.csv')
print(len(train_dataset))
--------------------------------------------------------------------------------
# Result
11823

----------------------------------------------------------------------------------
train_dataset

# Check the dataset for missing values
train_dataset.replace('', float('NaN'), inplace=True)
print(train_dataset.isnull().values.any())
-------------------------------------------------------
# Result
False

--------------------------------------------------------
# Remove duplicates from the dataset
train_dataset = train_dataset.drop_duplicates(['Q']).reset_index(drop=True)
print(len(train_dataset))
--------------------------------------------------------
# Result
11662

--------------------------------------------------------
train_dataset = train_dataset.drop_duplicates(['A']).reset_index(drop=True)
print(len(train_dataset))
----------------------------------------------------------
# Result
7731

----------------------------------------------------------
import matplotlib.pyplot as plt

question_list = list(train_dataset['Q'])
answer_list = list(train_dataset['A'])

print('질문의 최대 길이: ', max(len(question) for question in question_list))
print('질문의 평균 길이: ', sum(map(len, question_list))/len(question_list))

plt.hist([len(question) for question in question_list], bins=50)
plt.xlabel('length of samples')
plt.ylabel('number of samples')
plt.show()

# Result: histogram of question lengths (plot not shown here)

print('답변의 최대 길이: ', max(len(answer) for answer in answer_list))
print('답변의 평균 길이: ', sum(map(len, answer_list))/len(answer_list))

plt.hist([len(answer) for answer in answer_list], bins=50)
plt.xlabel('length of samples')
plt.ylabel('number of samples')
plt.show()

# Result: histogram of answer lengths (plot not shown here)

import random

print(f'question 개수: {len(question_list)}')
print(f'answer 개수: {len(answer_list)}')
-----------------------------------------------
# Result
question 개수: 7731
answer 개수: 7731

-----------------------------------------------
response_candidates = random.sample(answer_list, 500)

response_candidates[:10]
------------------------------------------------
# Result
['마음에서 놓아주세요.',
 '헤어짐는 순간은 악몽과 같아요.',
 '자신을 더 사랑해주세요.',
 '그 사람이 좋아하는 걸 알아보세요.',
 '맞는 말이에요.',
 '좋은 기억들만 남았길 바랄게요.',
 '좋아하지 않는다고 생각해서 좋아해지지 않는다면요.',
 '허전한 마음이 들겠어요.',
 '천천히 마음을 정리하는게 필요하겠어요.',
 '제대로 된 표현을 했는지 생각해보세요.']
 
 ------------------------------------------------------

Loading the KoBERT-Transformers Model

A BERT model pre-trained on Korean data, released by SKTBrain

KoBERT : https://github.com/SKTBrain/KoBERT

KoBERT-Transformers: https://github.com/monologg/KoBERT-Transformers

 


!pip install kobert-transformers
------------------------------------------
# Result
Collecting kobert-transformers
  Downloading kobert_transformers-0.5.1-py3-none-any.whl (12 kB)
Requirement already satisfied: torch>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from kobert-transformers) (2.0.1+cu118)
Collecting transformers<5,>=3 (from kobert-transformers)
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.2/7.2 MB 49.3 MB/s eta 0:00:00
Collecting sentencepiece>=0.1.91 (from kobert-transformers)
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 64.2 MB/s eta 0:00:00
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.1.0->kobert-transformers) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.1.0->kobert-transformers) (4.7.1)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.1.0->kobert-transformers) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.1.0->kobert-transformers) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.1.0->kobert-transformers) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.1.0->kobert-transformers) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.1.0->kobert-transformers) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.1.0->kobert-transformers) (16.0.6)
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers<5,>=3->kobert-transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 268.8/268.8 kB 28.2 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers<5,>=3->kobert-transformers) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers<5,>=3->kobert-transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers<5,>=3->kobert-transformers) (6.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers<5,>=3->kobert-transformers) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers<5,>=3->kobert-transformers) (2.27.1)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers<5,>=3->kobert-transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 94.5 MB/s eta 0:00:00
Collecting safetensors>=0.3.1 (from transformers<5,>=3->kobert-transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 72.1 MB/s eta 0:00:00
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers<5,>=3->kobert-transformers) (4.65.0)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers<5,>=3->kobert-transformers) (2023.6.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.1.0->kobert-transformers) (2.1.3)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers<5,>=3->kobert-transformers) (1.26.16)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers<5,>=3->kobert-transformers) (2023.5.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers<5,>=3->kobert-transformers) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers<5,>=3->kobert-transformers) (3.4)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.1.0->kobert-transformers) (1.3.0)
Installing collected packages: tokenizers, sentencepiece, safetensors, huggingface-hub, transformers, kobert-transformers
Successfully installed huggingface-hub-0.16.4 kobert-transformers-0.5.1 safetensors-0.3.1 sentencepiece-0.1.99 tokenizers-0.13.3 transformers-4.30.2
------------------------------------------
import torch
from kobert_transformers import get_kobert_model, get_distilkobert_model

model = get_kobert_model()

# Result

model.eval()
-------------------------------------
# Result
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(8002, 768, padding_idx=1)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
# Toy inputs: two sequences of length 3 (attention_mask 0 marks a padded position)
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
output = model(input_ids, attention_mask, token_type_ids)
output

----------------------------------------------------------
# Result
BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.2461,  0.2428,  0.2590,  ..., -0.4861, -0.0731,  0.0756],
         [-0.2478,  0.2420,  0.2552,  ..., -0.4877, -0.0727,  0.0754],
         [-0.2472,  0.2420,  0.2561,  ..., -0.4874, -0.0733,  0.0765]],

        [[ 0.0768, -0.1234,  0.1534,  ..., -0.2518, -0.2571,  0.1602],
         [-0.2419, -0.2821,  0.1962,  ..., -0.0172, -0.2960,  0.3679],
         [ 0.0911, -0.1437,  0.3412,  ...,  0.2526, -0.1780,  0.2619]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.0903, -0.0444,  0.1579,  ...,  0.1010, -0.0819,  0.0529],
        [ 0.0742, -0.0116, -0.6845,  ...,  0.0024, -0.0447,  0.0122]],
       grad_fn=<TanhBackward0>), hidden_states=None, past_key_values=None, attentions=None, cross_attentions=None)
       
-----------------------------------------------------
output[0]
-----------------------------------------------------
# Result
tensor([[[-0.2461,  0.2428,  0.2590,  ..., -0.4861, -0.0731,  0.0756],
         [-0.2478,  0.2420,  0.2552,  ..., -0.4877, -0.0727,  0.0754],
         [-0.2472,  0.2420,  0.2561,  ..., -0.4874, -0.0733,  0.0765]],

        [[ 0.0768, -0.1234,  0.1534,  ..., -0.2518, -0.2571,  0.1602],
         [-0.2419, -0.2821,  0.1962,  ..., -0.0172, -0.2960,  0.3679],
         [ 0.0911, -0.1437,  0.3412,  ...,  0.2526, -0.1780,  0.2619]]],
       grad_fn=<NativeLayerNormBackward0>)

-----------------------------------------------------
from kobert_transformers import get_tokenizer

tokenizer = get_tokenizer()

tokenizer.tokenize('[CLS] 한국어 모델을 공유합니다. [SEP]')
---------------------------------------------------------
# Result
['[CLS]', '▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.', '[SEP]']

--------------------------------------------------------
tokenizer.convert_tokens_to_ids(['[CLS]', '▁한국', '어', '▁모델', '을', '▁공유', '합니다', '.', '[SEP]'])
--------------------------------------------------------
# Result
[2, 4958, 6855, 2046, 7088, 1050, 7843, 54, 3]

--------------------------------------------------------
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def get_cls_token(sentence):
    model.eval()
    tokenized_sent = tokenizer(
        sentence,
        return_tensors='pt',
        truncation=True,
        add_special_tokens=True,
        max_length=128
    )

    input_ids = tokenized_sent['input_ids']
    attention_mask = tokenized_sent['attention_mask']
    token_type_ids = tokenized_sent['token_type_ids']

    with torch.no_grad():
        output = model(input_ids, attention_mask, token_type_ids)

    cls_output = output[1]  # pooler_output: the [CLS] hidden state passed through a dense + tanh layer
    cls_token = cls_output.detach().cpu().numpy()

    return cls_token
    
    
def predict(query, candidates):
    candidates_cls = []

    for cand in candidates:
        cand_cls = get_cls_token(cand)
        candidates_cls.append(cand_cls)

    candidates_cls = np.array(candidates_cls).squeeze(axis=1)

    query_cls = get_cls_token(query)
    similarity_list = cosine_similarity(query_cls, candidates_cls)

    target_idx = np.argmax(similarity_list)
    return candidates[target_idx]
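get_cls_token (and therefore predict) runs a separate forward pass for every sentence, so scoring 500 candidates means 500 passes. A rough sketch of a batched alternative, assuming the same model and tokenizer objects defined above (get_cls_tokens_batch is a hypothetical helper, not part of kobert-transformers):

def get_cls_tokens_batch(sentences):
    # Hypothetical batched variant of get_cls_token: one padded forward pass for all sentences
    model.eval()
    tokenized = tokenizer(
        sentences,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128
    )
    with torch.no_grad():
        output = model(**tokenized)
    return output[1].detach().cpu().numpy()  # (len(sentences), 768) pooler outputs

Inside predict, the per-candidate loop could then be replaced with a single candidates_cls = get_cls_tokens_batch(candidates) call.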
    
query = '너 요즘 바뻐?'
query_cls_hidden = get_cls_token(query)
print(query_cls_hidden)
print(query_cls_hidden.shape)
--------------------------------------------------------------------------------
# Result (768-dimensional [CLS] vector; middle values elided)
[[ 1.97712686e-02 -4.20350656e-02 -8.08150396e-02  1.43184792e-02
  -8.87244284e-01  9.76347804e-01 -3.12679619e-01  1.08309001e-01
  ...
  -9.90602791e-01  1.55075684e-01 -7.11186007e-02 -2.03615054e-02]]
(1, 768)
-------------------------------------------------------------------------
sample_query = '너 요즘 바뻐?'
sample_candidates = ['바쁘면 가라', '아니 별로 안바뻐', '3인조 여성 그룹', '오늘은 여기까지']

predicted_answer = predict(sample_query, sample_candidates)
print(f'결과: {predicted_answer}')
-------------------------------------------------------------------------
# Result
결과: 아니 별로 안바뻐
---------------------------------------------------------------------------
user_query = '나 오늘 헤어졌어'
predicted_answer = predict(user_query, response_candidates)
print(f'결과: {predicted_answer}')
-------------------------------------------------------------------------------
# Result
결과: 겨울에는 귤 먹으면서 집에 있는게 최고죠

--------------------------------------------------------------------------------
response_candidates = random.sample(answer_list, 100)

user_query = '나 오늘 너무 힘들어'
predicted_answer = predict(user_query, response_candidates)
print(f'결과: {predicted_answer}')
------------------------------------------------------------------------------------
# Result
결과: 썸도 좋지요.

-------------------------------------------------------------------------------------
# Keep answering until the user enters an empty sentence
while True:
    sentence = input('하고싶은 말을 입력하세요: ')
    if len(sentence) == 0:
        break
    predicted_answer = predict(sentence, response_candidates)
    print(predicted_answer)
    print('\n')
-------------------------------------------------------------------------------------
# Result
하고싶은 말을 입력하세요: 야 너 뭐해?
책임을 지기 싫은 건지 생각해보세요.


하고싶은 말을 입력하세요: 그만해
카페라도 가서 쉬다 오면 어떨까요

 
