
Text Similarity (2)

  • Duplicate / similarity detection using the Kaggle 'Quora question pairs' dataset
  • Reference: 텐서플로2와 머신러닝으로 시작하는 자연어처리, p. 230
  • The data can be pulled with the Kaggle API, but here it was downloaded after logging in with a Google account (a minimal API sketch is shown below)
  • Workflow: data preprocessing and train/evaluation split, modeling (XGBoost), then prediction
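For reference, fetching the same files with the Kaggle API would look roughly like this: a minimal sketch, assuming the kaggle package is installed and an API token is stored in ~/.kaggle/kaggle.json; the target path simply mirrors the one used below.

!pip install kaggle
# download the competition files (they arrive as zip archives and still need to be extracted)
!kaggle competitions download -c quora-question-pairs -p C:/Users/silan/Python/quora-question-pairs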
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

%matplotlib inline
train_data = pd.read_csv('C:/Users/silan/Python/quora-question-pairs/train.csv/'+'train.csv')  # train.csv was extracted into a folder of the same name, hence the nested path
train_data.head()
id qid1 qid2 question1 question2 is_duplicate
0 0 1 2 What is the step by step guide to invest in sh... What is the step by step guide to invest in sh... 0
1 1 3 4 What is the story of Kohinoor (Koh-i-Noor) Dia... What would happen if the Indian government sto... 0
2 2 5 6 How can I increase the speed of my internet co... How can Internet speed be increased by hacking... 0
3 3 7 8 Why am I mentally very lonely? How can I solve... Find the remainder when [math]23^{24}[/math] i... 0
4 4 9 10 Which one dissolve in water quikly sugar, salt... Which fish would survive in salt water? 0
print('File sizes: ')

for file in os.listdir('C:/Users/silan/Python/quora-question-pairs/'):
    if 'csv' in file and 'zip' not in file:
        print(file.ljust(30) + str(round(os.path.getsize('C:/Users/silan/Python/quora-question-pairs/'+file)/1000000, 2)) + 'MB')
File sizes: 
test.csv                      314.02MB
train.csv                     0.0MB
print('Total number of training pairs: {}'.format(len(train_data)))
Total number of training pairs: 404290
train_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 404290 entries, 0 to 404289
Data columns (total 6 columns):
 #   Column        Non-Null Count   Dtype 
---  ------        --------------   ----- 
 0   id            404290 non-null  int64 
 1   qid1          404290 non-null  int64 
 2   qid2          404290 non-null  int64 
 3   question1     404289 non-null  object
 4   question2     404288 non-null  object
 5   is_duplicate  404290 non-null  int64 
dtypes: int64(4), object(2)
memory usage: 18.5+ MB
# Check how often questions repeat: stack question1 and question2 into one Series

train_set = pd.Series(train_data['question1'].tolist() +
                      train_data['question2'].tolist()).astype(str)

train_set.head()
0    What is the step by step guide to invest in sh...
1    What is the story of Kohinoor (Koh-i-Noor) Dia...
2    How can I increase the speed of my internet co...
3    Why am I mentally very lonely? How can I solve...
4    Which one dissolve in water quikly sugar, salt...
dtype: object
print('Number of unique questions: {}'.format(len(np.unique(train_set))))
print('Number of repeated questions: {}'.format(np.sum(train_set.value_counts() > 1))) # questions that occur 2 or more times
Number of unique questions: 537361
Number of repeated questions: 111873

Visualization: checking outliers and the number of duplicate questions

  • figsize: (width, height)
  • bins: number of histogram buckets
  • range: range of x values
  • alpha: transparency of the plot color
  • color: plot color
  • label: legend label
plt.figure(figsize=(20, 10))

# histogram on linear scale
plt.subplot(211)
hist, bins, _ = plt.hist(train_set.value_counts(), bins=50)
plt.yscale('log')

# histogram on log scale:
# use non-equal bin sizes so they look equal on a log scale
logbins = np.logspace(np.log10(bins[0]), np.log10(bins[-1]), len(bins))
plt.subplot(212)
plt.hist(train_set.value_counts(), bins=logbins, alpha=0.5, color='r', label='word')
plt.yscale('log')
plt.title('Log-histogram of question appearance counts')
plt.xlabel('# of occurrences of question')
plt.ylabel('# of questions')

plt.show()

image

plt.figure(figsize=(12,6))
plt.boxplot([train_set.value_counts()], labels=['counts'], showmeans=True)

image

Preprocessing: balancing the label counts

  • Of the duplicate / non-duplicate pairs, the non-duplicate class (is_duplicate = 0) is much larger,
  • so the larger class is downsampled before the analysis continues
import re
import json

train_pos_data = train_data.loc[train_data['is_duplicate'] == 1]
train_neg_data = train_data.loc[train_data['is_duplicate'] == 0]

class_difference = len(train_neg_data) - len(train_pos_data)
sample_frac = 1- (class_difference / len(train_neg_data))

train_neg_data = train_neg_data.sample(frac=sample_frac)  # downsample the negative class

print(f'Number of duplicate pairs: {len(train_pos_data)}')
print(f'Number of non-duplicate pairs: {len(train_neg_data)}')
Number of duplicate pairs: 149263
Number of non-duplicate pairs: 149263
train_data = pd.concat([train_neg_data, train_pos_data])
train_data.head()
id qid1 qid2 question1 question2 is_duplicate
187868 187868 286229 79093 What is the corporate culture like at Demandwa... What is the corporate culture like at Quantum?... 0
184405 184405 281752 281753 How do I fix a location unavailable error in t... Does the Android Device Manager app use the in... 0
370455 370455 501024 501025 What are some things new employees should know... What are some things new employees should know... 0
332792 332792 366013 306645 Can I get pregnant 2 days after ovulation? Can you get pregnant a day after ovulation? 0
130994 130994 210113 210114 How is characterization used in "The Great Gat... How is carelessness portrayed in "The Great Ga... 0
# Text cleaning filter.
# Note: '[a-z]+' matches runs of lowercase letters, so the substitution strips them,
# leaving only capitals, digits and punctuation before lower() is applied
# (the book applies a punctuation filter at this step instead).
change_filter = re.compile('[a-z]+')

# str() also converts the few NaN questions seen in train_data.info() to the string 'nan'
q1 = [str(s) for s in train_data['question1']]
q2 = [str(s) for s in train_data['question2']]

filtered_q1 = list()
filtered_q2 = list()

for q in q1:
    filtered_q1.append(re.sub(change_filter, "", q).lower())

for q in q2:
    filtered_q2.append(re.sub(change_filter, "", q).lower())
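A quick check of what this filter produces on one sample question (illustrative only; it reuses the change_filter defined above):

sample = "What is the step by step guide to invest?"
print(re.sub(change_filter, "", sample).lower())
# prints 'w' followed by spaces and '?': only the capital letter and punctuation survive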

Building the full word dictionary

!pip install tensorflow
Successfully installed absl-py-1.2.0 astunparse-1.6.3 flatbuffers-22.9.24 gast-0.4.0 google-auth-oauthlib-0.4.6 google-pasta-0.2.0 keras-2.10.0 keras-preprocessing-1.1.2 libclang-14.0.6 oauthlib-3.2.1 opt-einsum-3.3.0 requests-oauthlib-1.3.1 tensorboard-2.10.1 tensorboard-data-server-0.6.1 tensorboard-plugin-wit-1.8.1 tensorflow-2.10.0 tensorflow-estimator-2.10.0 tensorflow-io-gcs-filesystem-0.27.0 termcolor-2.0.1
import tensorflow as tf

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

tokenizer = Tokenizer()
tokenizer.fit_on_texts(filtered_q1 + filtered_q2)

question1_sequence = tokenizer.texts_to_sequences(filtered_q1)
question2_sequence = tokenizer.texts_to_sequences(filtered_q2)

# Pad all sequences to the same length

MAX_SEQUENCE_LENGTH = 31  # set to roughly the average sequence length; adjustable

q1_data = pad_sequences(question1_sequence, maxlen=MAX_SEQUENCE_LENGTH, padding='post')
q2_data = pad_sequences(question2_sequence, maxlen=MAX_SEQUENCE_LENGTH, padding='post')
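The comment above describes 31 as roughly the average length; a minimal sketch of how that could be checked from the tokenized sequences (illustrative only, and the exact number depends on the filtering applied above):

# inspect sequence lengths to help choose MAX_SEQUENCE_LENGTH
seq_lens = [len(s) for s in question1_sequence + question2_sequence]
print('mean length: {:.1f}, 95th percentile: {}'.format(np.mean(seq_lens), int(np.percentile(seq_lens, 95))))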
# Check the word dictionary and the labels

word_vocab = tokenizer.word_index
word_vocab["<PAD>"] = 0

labels = np.array(train_data['is_duplicate'],dtype=int)

print(f'Shape of question1 data: {q1_data.shape}')
print(f'Shape of question2 data: {q2_data.shape}')
print(f'Shape of label: {labels.shape}')
print(f'Words in vocabulary: {len(word_vocab)}')
Shape of question1 data: (298526, 31)
Shape of question2 data: (298526, 31)
Shape of label: (298526,)
Words in vocabulary: 11870
# Save the preprocessed data

data_configs = {}
data_configs['vocab'] = word_vocab
data_configs['vocab_size'] = len(word_vocab)

TRAIN_Q1_DATA = 'q1_train.npy'
TRAIN_Q2_DATA = 'q2_train.npy'
TRAIN_LABEL_DATA = 'label_train.npy'
DATA_CONFIGS = 'data_configs.json'  # saved with json.dump below, so use a .json extension


path = 'C:/Users/silan/Python/quora-question-pairs/'
np.save(open(path + TRAIN_Q1_DATA, 'wb'), q1_data)
np.save(open(path + TRAIN_Q2_DATA, 'wb'), q2_data)
np.save(open(path + TRAIN_LABEL_DATA, 'wb'), labels)

json.dump(data_configs, open(path + DATA_CONFIGS, 'w'))

Preprocessing the TEST data

path = 'C:/Users/silan/Python/quora-question-pairs/'
test_data = pd.read_csv(path +'test.csv', encoding = 'utf-8')
valid_ids = [type(x) == int for x in test_data.test_id]
test_data = test_data[valid_ids].drop_duplicates()  # keep rows with integer ids and drop duplicate rows

change_filter = re.compile('[a-z]+')

q1 = [str(s) for s in test_data['question1']]
q2 = [str(s) for s in test_data['question2']]

filtered_q1 = list()
filtered_q2 = list()

for q in q1:
    filtered_q1.append(re.sub(change_filter,"", q).lower())

for q in q2:
    filtered_q2.append(re.sub(change_filter,"", q).lower())

# Tokenization
# Note: this fits a new Tokenizer on the test questions only; reusing the tokenizer
# fitted on the training data would keep word indices consistent between train and test.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(filtered_q1 + filtered_q2)

test_data_question1_sequence = tokenizer.texts_to_sequences(filtered_q1)
test_data_question2_sequence = tokenizer.texts_to_sequences(filtered_q2)

# Pad all sequences to the same length

MAX_SEQUENCE_LENGTH = 31  # same length as used for the training data

q1_data = pad_sequences(test_data_question1_sequence, maxlen=MAX_SEQUENCE_LENGTH, padding='post')
q2_data = pad_sequences(test_data_question2_sequence, maxlen=MAX_SEQUENCE_LENGTH, padding='post')

test_ids = np.array(test_data['test_id'])

print(f'Shape of question1 data: {q1_data.shape}')
print(f'Shape of question2 data: {q2_data.shape}')
print(f'Shape of Ids: {test_ids.shape}')

TEST_Q1_DATA = 'q1_test.npy'
TEST_Q2_DATA = 'q2_test.npy'
TEST_ID_DATA = 'test_ids.npy'

path = 'C:/Users/silan/Python/quora-question-pairs/'
np.save(open(path + TEST_Q1_DATA, 'wb'), q1_data)
np.save(open(path + TEST_Q2_DATA, 'wb'), q2_data)
np.save(open(path + TEST_ID_DATA, 'wb'), test_ids)
Shape of question1 data: (2345796, 31)
Shape of question2 data: (2345796, 31)
Shape of Ids: (2345796,)

Modeling (XGBoost)

  • Uses the boosting ensemble technique
  • What is an ensemble?
  • Unlike bagging, which trains several models independently and combines their predictions with equal weight (e.g. by averaging),
  • boosting combines the results sequentially: after each round, the examples that were predicted incorrectly are given more weight for the next round (a small sketch of this reweighting idea follows this list)

  • XGBoost: tree boosting combined with gradient-based loss optimization and parallel processing, which keeps training time short
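A minimal sketch of the reweighting idea behind boosting, as an AdaBoost-style toy example on synthetic data with scikit-learn decision stumps. It is for intuition only; XGBoost itself grows trees on gradients of the loss rather than by explicit sample reweighting.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

# toy binary classification data with labels in {-1, +1}
rng = np.random.RandomState(0)
X = rng.rand(200, 2)
y = np.where(X[:, 0] + X[:, 1] > 1, 1, -1)

n_rounds = 5
weights = np.full(len(X), 1 / len(X))   # start with uniform sample weights
stumps, alphas = [], []

for _ in range(n_rounds):
    stump = DecisionTreeClassifier(max_depth=1)
    stump.fit(X, y, sample_weight=weights)
    pred = stump.predict(X)

    err = np.sum(weights[pred != y]) / np.sum(weights)    # weighted error of this round
    alpha = 0.5 * np.log((1 - err) / max(err, 1e-10))     # weight of this round's model
    weights *= np.exp(-alpha * y * pred)                  # increase weights of misclassified samples
    weights /= weights.sum()

    stumps.append(stump)
    alphas.append(alpha)

# final prediction: weighted vote of the sequentially trained models
agg = sum(a * s.predict(X) for a, s in zip(alphas, stumps))
print('training accuracy:', np.mean(np.sign(agg) == y))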
# Load the saved preprocessed files and build question pairs

train_q1_data = np.load(open(path + TRAIN_Q1_DATA, 'rb'))
train_q2_data = np.load(open(path + TRAIN_Q2_DATA, 'rb'))
train_labels = np.load(open(path + TRAIN_LABEL_DATA, 'rb'))

train_input = np.stack((train_q1_data, train_q2_data), axis=1)  # each row: [[q1 vector], [q2 vector]]

print(train_input)
[[[ 1  7  3 ...  0  0  0]
  [ 1 16  3 ...  0  0  0]]

 [[ 3  2  5 ...  0  0  0]
  [ 7  5  7 ...  0  0  0]]

 [[ 1  6 18 ...  0  0  0]
  [ 1 12 18 ...  0  0  0]]

 ...

 [[ 3  0  0 ...  0  0  0]
  [ 3  0  0 ...  0  0  0]]

 [[ 1 19  0 ...  0  0  0]
  [ 1 19 13 ...  0  0  0]]

 [[ 7  0  0 ...  0  0  0]
  [ 2  0  0 ...  0  0  0]]]
print(train_input.shape)
(298526, 2, 31)

Converting data in various formats to an XGBoost DMatrix

  • https://xgboost.readthedocs.io/en/stable/python/python_intro.html (a short sketch follows the install step below)
!pip install xgboost
Collecting xgboost
  Downloading xgboost-1.6.2-py3-none-win_amd64.whl (125.4 MB)
Requirement already satisfied: scipy in c:\anaconda\lib\site-packages (from xgboost) (1.7.3)
Requirement already satisfied: numpy in c:\anaconda\lib\site-packages (from xgboost) (1.21.5)
Installing collected packages: xgboost
Successfully installed xgboost-1.6.2
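As a small illustration of the DMatrix constructor accepting more than one input format (toy data, not this dataset; see the docs link above):

import numpy as np
import pandas as pd
import xgboost as xgb

X_np = np.random.rand(5, 3)
y = np.array([0, 1, 0, 1, 1])

d_from_numpy = xgb.DMatrix(X_np, label=y)    # from a NumPy array

X_df = pd.DataFrame(X_np, columns=['f0', 'f1', 'f2'])
d_from_pandas = xgb.DMatrix(X_df, label=y)   # from a pandas DataFrame (column names become feature names)

print(d_from_numpy.num_row(), d_from_pandas.num_col())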
train_labels.shape
(298526,)
train_input.shape
(298526, 2, 31)
from sklearn.model_selection import train_test_split

train_input, eval_input, train_label, eval_label = train_test_split(train_input,
                                                                    train_labels, 
                                                                     test_size=0.2, 
                                                                     random_state=4242)
# random_state fixes the sampled split for reproducibility; any value can be used

# summing over axis 1 merges the two questions of each pair into a single feature vector

import xgboost as xgb

train_data = xgb.DMatrix(train_input.sum(axis=1), label=train_label)
eval_data = xgb.DMatrix(eval_input.sum(axis=1), label=eval_label)

data_list = [(train_data, 'train'), (eval_data, 'valid')]
data_list
[(<xgboost.core.DMatrix at 0x24893cb18b0>, 'train'),
 (<xgboost.core.DMatrix at 0x24893cb1d30>, 'valid')]

Parameter settings

1) Objective function: binary logistic
2) Evaluation metric: root mean squared error (RMSE)
3) Number of boosting rounds: num_boost_round=1000
4) early_stopping_rounds: stop training early if the validation error has not improved for N rounds

params = {}
params['objective'] = 'binary:logistic'
params['eval_metric'] = 'rmse'

params
{'objective': 'binary:logistic', 'eval_metric': 'rmse'}
num_round = 1000
bst = xgb.train(params, train_data, num_round)  # first run: a fixed 1000 rounds, no validation monitoring
preds = bst.predict(eval_data)
bst
<xgboost.core.Booster at 0x24893b25c40>
# second run: monitor the validation set and stop early once it stops improving
bst = xgb.train(params, train_data, num_boost_round=1000, evals=data_list,
                early_stopping_rounds=10)
[0]	train-rmse:0.49072	valid-rmse:0.49080
[1]	train-rmse:0.48529	valid-rmse:0.48549
[2]	train-rmse:0.48186	valid-rmse:0.48230
[3]	train-rmse:0.47984	valid-rmse:0.48041
[4]	train-rmse:0.47780	valid-rmse:0.47833
[5]	train-rmse:0.47673	valid-rmse:0.47735
[6]	train-rmse:0.47604	valid-rmse:0.47670
[7]	train-rmse:0.47493	valid-rmse:0.47567
[8]	train-rmse:0.47463	valid-rmse:0.47547
[9]	train-rmse:0.47358	valid-rmse:0.47457
[10]	train-rmse:0.47309	valid-rmse:0.47415
[11]	train-rmse:0.47263	valid-rmse:0.47375
[12]	train-rmse:0.47196	valid-rmse:0.47315
[13]	train-rmse:0.47177	valid-rmse:0.47300
[14]	train-rmse:0.47109	valid-rmse:0.47240
[15]	train-rmse:0.47095	valid-rmse:0.47230
[16]	train-rmse:0.47052	valid-rmse:0.47191
[17]	train-rmse:0.47039	valid-rmse:0.47181
[18]	train-rmse:0.47024	valid-rmse:0.47169
[19]	train-rmse:0.46963	valid-rmse:0.47117
[20]	train-rmse:0.46926	valid-rmse:0.47086
[21]	train-rmse:0.46850	valid-rmse:0.47021
[22]	train-rmse:0.46839	valid-rmse:0.47014
[23]	train-rmse:0.46757	valid-rmse:0.46937
[24]	train-rmse:0.46753	valid-rmse:0.46934
[25]	train-rmse:0.46746	valid-rmse:0.46933
[26]	train-rmse:0.46706	valid-rmse:0.46898
[27]	train-rmse:0.46699	valid-rmse:0.46895
[28]	train-rmse:0.46671	valid-rmse:0.46873
[29]	train-rmse:0.46663	valid-rmse:0.46869
[30]	train-rmse:0.46655	valid-rmse:0.46868
[31]	train-rmse:0.46635	valid-rmse:0.46859
[32]	train-rmse:0.46599	valid-rmse:0.46833
[33]	train-rmse:0.46592	valid-rmse:0.46828
[34]	train-rmse:0.46579	valid-rmse:0.46821
[35]	train-rmse:0.46577	valid-rmse:0.46821
[36]	train-rmse:0.46526	valid-rmse:0.46780
[37]	train-rmse:0.46485	valid-rmse:0.46747
[38]	train-rmse:0.46480	valid-rmse:0.46744
[39]	train-rmse:0.46443	valid-rmse:0.46714
[40]	train-rmse:0.46439	valid-rmse:0.46713
[41]	train-rmse:0.46436	valid-rmse:0.46713
[42]	train-rmse:0.46417	valid-rmse:0.46702
[43]	train-rmse:0.46393	valid-rmse:0.46687
[44]	train-rmse:0.46367	valid-rmse:0.46669
[45]	train-rmse:0.46364	valid-rmse:0.46667
[46]	train-rmse:0.46338	valid-rmse:0.46652
[47]	train-rmse:0.46331	valid-rmse:0.46649
[48]	train-rmse:0.46329	valid-rmse:0.46648
[49]	train-rmse:0.46311	valid-rmse:0.46638
[50]	train-rmse:0.46309	valid-rmse:0.46638
[51]	train-rmse:0.46308	valid-rmse:0.46637
[52]	train-rmse:0.46275	valid-rmse:0.46614
[53]	train-rmse:0.46271	valid-rmse:0.46611
[54]	train-rmse:0.46265	valid-rmse:0.46610
[55]	train-rmse:0.46249	valid-rmse:0.46605
[56]	train-rmse:0.46228	valid-rmse:0.46592
[57]	train-rmse:0.46221	valid-rmse:0.46590
[58]	train-rmse:0.46219	valid-rmse:0.46591
[59]	train-rmse:0.46188	valid-rmse:0.46564
[60]	train-rmse:0.46154	valid-rmse:0.46545
[61]	train-rmse:0.46135	valid-rmse:0.46534
[62]	train-rmse:0.46092	valid-rmse:0.46494
[63]	train-rmse:0.46091	valid-rmse:0.46493
[64]	train-rmse:0.46090	valid-rmse:0.46493
[65]	train-rmse:0.46088	valid-rmse:0.46492
[66]	train-rmse:0.46086	valid-rmse:0.46490
[67]	train-rmse:0.46082	valid-rmse:0.46488
[68]	train-rmse:0.46073	valid-rmse:0.46480
[69]	train-rmse:0.46068	valid-rmse:0.46478
[70]	train-rmse:0.46064	valid-rmse:0.46477
[71]	train-rmse:0.46037	valid-rmse:0.46460
[72]	train-rmse:0.46023	valid-rmse:0.46457
[73]	train-rmse:0.46012	valid-rmse:0.46454
[74]	train-rmse:0.46001	valid-rmse:0.46448
[75]	train-rmse:0.46000	valid-rmse:0.46448
[76]	train-rmse:0.45999	valid-rmse:0.46447
[77]	train-rmse:0.45996	valid-rmse:0.46447
[78]	train-rmse:0.45989	valid-rmse:0.46439
[79]	train-rmse:0.45988	valid-rmse:0.46439
[80]	train-rmse:0.45975	valid-rmse:0.46434
[81]	train-rmse:0.45940	valid-rmse:0.46406
[82]	train-rmse:0.45906	valid-rmse:0.46383
[83]	train-rmse:0.45877	valid-rmse:0.46361
[84]	train-rmse:0.45874	valid-rmse:0.46360
[85]	train-rmse:0.45870	valid-rmse:0.46357
[86]	train-rmse:0.45852	valid-rmse:0.46348
[87]	train-rmse:0.45834	valid-rmse:0.46341
[88]	train-rmse:0.45805	valid-rmse:0.46319
[89]	train-rmse:0.45795	valid-rmse:0.46318
[90]	train-rmse:0.45791	valid-rmse:0.46318
[91]	train-rmse:0.45790	valid-rmse:0.46318
[92]	train-rmse:0.45771	valid-rmse:0.46305
[93]	train-rmse:0.45749	valid-rmse:0.46290
[94]	train-rmse:0.45732	valid-rmse:0.46281
[95]	train-rmse:0.45710	valid-rmse:0.46269
[96]	train-rmse:0.45696	valid-rmse:0.46262
[97]	train-rmse:0.45694	valid-rmse:0.46261
[98]	train-rmse:0.45687	valid-rmse:0.46261
[99]	train-rmse:0.45684	valid-rmse:0.46261
[100]	train-rmse:0.45683	valid-rmse:0.46260
[101]	train-rmse:0.45668	valid-rmse:0.46251
[102]	train-rmse:0.45666	valid-rmse:0.46251
[103]	train-rmse:0.45659	valid-rmse:0.46248
[104]	train-rmse:0.45649	valid-rmse:0.46244
[105]	train-rmse:0.45635	valid-rmse:0.46237
[106]	train-rmse:0.45615	valid-rmse:0.46228
[107]	train-rmse:0.45606	valid-rmse:0.46227
[108]	train-rmse:0.45594	valid-rmse:0.46222
[109]	train-rmse:0.45587	valid-rmse:0.46221
[110]	train-rmse:0.45584	valid-rmse:0.46221
[111]	train-rmse:0.45574	valid-rmse:0.46218
[112]	train-rmse:0.45556	valid-rmse:0.46211
[113]	train-rmse:0.45545	valid-rmse:0.46206
[114]	train-rmse:0.45531	valid-rmse:0.46202
[115]	train-rmse:0.45512	valid-rmse:0.46189
[116]	train-rmse:0.45503	valid-rmse:0.46189
[117]	train-rmse:0.45498	valid-rmse:0.46190
[118]	train-rmse:0.45484	valid-rmse:0.46180
[119]	train-rmse:0.45476	valid-rmse:0.46177
[120]	train-rmse:0.45475	valid-rmse:0.46177
[121]	train-rmse:0.45467	valid-rmse:0.46173
[122]	train-rmse:0.45466	valid-rmse:0.46172
[123]	train-rmse:0.45462	valid-rmse:0.46171
[124]	train-rmse:0.45445	valid-rmse:0.46163
[125]	train-rmse:0.45444	valid-rmse:0.46163
[126]	train-rmse:0.45442	valid-rmse:0.46163
[127]	train-rmse:0.45436	valid-rmse:0.46161
[128]	train-rmse:0.45433	valid-rmse:0.46161
[129]	train-rmse:0.45431	valid-rmse:0.46160
[130]	train-rmse:0.45425	valid-rmse:0.46159
[131]	train-rmse:0.45419	valid-rmse:0.46159
[132]	train-rmse:0.45417	valid-rmse:0.46159
[133]	train-rmse:0.45416	valid-rmse:0.46158
[134]	train-rmse:0.45415	valid-rmse:0.46158
[135]	train-rmse:0.45415	valid-rmse:0.46158
[136]	train-rmse:0.45414	valid-rmse:0.46158
[137]	train-rmse:0.45414	valid-rmse:0.46158
[138]	train-rmse:0.45409	valid-rmse:0.46158
[139]	train-rmse:0.45408	valid-rmse:0.46157
[140]	train-rmse:0.45405	valid-rmse:0.46156
[141]	train-rmse:0.45390	valid-rmse:0.46149
[142]	train-rmse:0.45389	valid-rmse:0.46149
[143]	train-rmse:0.45370	valid-rmse:0.46135
[144]	train-rmse:0.45356	valid-rmse:0.46125
[145]	train-rmse:0.45346	valid-rmse:0.46119
[146]	train-rmse:0.45334	valid-rmse:0.46112
[147]	train-rmse:0.45314	valid-rmse:0.46102
[148]	train-rmse:0.45314	valid-rmse:0.46102
[149]	train-rmse:0.45305	valid-rmse:0.46102
[150]	train-rmse:0.45293	valid-rmse:0.46098
[151]	train-rmse:0.45278	valid-rmse:0.46092
[152]	train-rmse:0.45270	valid-rmse:0.46091
[153]	train-rmse:0.45267	valid-rmse:0.46091
[154]	train-rmse:0.45265	valid-rmse:0.46091
[155]	train-rmse:0.45260	valid-rmse:0.46090
[156]	train-rmse:0.45253	valid-rmse:0.46087
[157]	train-rmse:0.45238	valid-rmse:0.46081
[158]	train-rmse:0.45232	valid-rmse:0.46078
[159]	train-rmse:0.45215	valid-rmse:0.46070
[160]	train-rmse:0.45213	valid-rmse:0.46069
[161]	train-rmse:0.45193	valid-rmse:0.46056
[162]	train-rmse:0.45174	valid-rmse:0.46051
[163]	train-rmse:0.45164	valid-rmse:0.46047
[164]	train-rmse:0.45155	valid-rmse:0.46046
[165]	train-rmse:0.45152	valid-rmse:0.46045
[166]	train-rmse:0.45150	valid-rmse:0.46044
[167]	train-rmse:0.45150	valid-rmse:0.46044
[168]	train-rmse:0.45148	valid-rmse:0.46044
[169]	train-rmse:0.45145	valid-rmse:0.46043
[170]	train-rmse:0.45135	valid-rmse:0.46038
[171]	train-rmse:0.45124	valid-rmse:0.46034
[172]	train-rmse:0.45113	valid-rmse:0.46024
[173]	train-rmse:0.45102	valid-rmse:0.46017
[174]	train-rmse:0.45098	valid-rmse:0.46016
[175]	train-rmse:0.45086	valid-rmse:0.46009
[176]	train-rmse:0.45084	valid-rmse:0.46009
[177]	train-rmse:0.45073	valid-rmse:0.46006
[178]	train-rmse:0.45069	valid-rmse:0.46006
[179]	train-rmse:0.45065	valid-rmse:0.46004
[180]	train-rmse:0.45062	valid-rmse:0.46004
[181]	train-rmse:0.45061	valid-rmse:0.46004
[182]	train-rmse:0.45058	valid-rmse:0.46004
[183]	train-rmse:0.45057	valid-rmse:0.46003
[184]	train-rmse:0.45046	valid-rmse:0.45996
[185]	train-rmse:0.45037	valid-rmse:0.45991
[186]	train-rmse:0.45031	valid-rmse:0.45991
[187]	train-rmse:0.45028	valid-rmse:0.45990
[188]	train-rmse:0.45027	valid-rmse:0.45990
[189]	train-rmse:0.45027	valid-rmse:0.45989
[190]	train-rmse:0.45026	valid-rmse:0.45989
[191]	train-rmse:0.45019	valid-rmse:0.45989
[192]	train-rmse:0.45012	valid-rmse:0.45988
[193]	train-rmse:0.45011	valid-rmse:0.45987
[194]	train-rmse:0.45009	valid-rmse:0.45988
[195]	train-rmse:0.45009	valid-rmse:0.45988
[196]	train-rmse:0.45009	valid-rmse:0.45988
[197]	train-rmse:0.44999	valid-rmse:0.45985
[198]	train-rmse:0.44998	valid-rmse:0.45983
[199]	train-rmse:0.44996	valid-rmse:0.45983
[200]	train-rmse:0.44985	valid-rmse:0.45975
[201]	train-rmse:0.44977	valid-rmse:0.45972
[202]	train-rmse:0.44971	valid-rmse:0.45969
[203]	train-rmse:0.44969	valid-rmse:0.45969
[204]	train-rmse:0.44953	valid-rmse:0.45962
[205]	train-rmse:0.44941	valid-rmse:0.45957
[206]	train-rmse:0.44927	valid-rmse:0.45953
[207]	train-rmse:0.44927	valid-rmse:0.45953
[208]	train-rmse:0.44926	valid-rmse:0.45953
[209]	train-rmse:0.44919	valid-rmse:0.45952
[210]	train-rmse:0.44914	valid-rmse:0.45951
[211]	train-rmse:0.44912	valid-rmse:0.45950
[212]	train-rmse:0.44897	valid-rmse:0.45939
[213]	train-rmse:0.44892	valid-rmse:0.45938
[214]	train-rmse:0.44891	valid-rmse:0.45938
[215]	train-rmse:0.44891	valid-rmse:0.45938
[216]	train-rmse:0.44889	valid-rmse:0.45937
[217]	train-rmse:0.44884	valid-rmse:0.45937
[218]	train-rmse:0.44883	valid-rmse:0.45938
[219]	train-rmse:0.44882	valid-rmse:0.45937
[220]	train-rmse:0.44877	valid-rmse:0.45936
[221]	train-rmse:0.44872	valid-rmse:0.45934
[222]	train-rmse:0.44872	valid-rmse:0.45934
[223]	train-rmse:0.44871	valid-rmse:0.45934
[224]	train-rmse:0.44866	valid-rmse:0.45933
[225]	train-rmse:0.44866	valid-rmse:0.45933
[226]	train-rmse:0.44866	valid-rmse:0.45933
[227]	train-rmse:0.44866	valid-rmse:0.45932
[228]	train-rmse:0.44865	valid-rmse:0.45933
[229]	train-rmse:0.44865	valid-rmse:0.45932
[230]	train-rmse:0.44862	valid-rmse:0.45932
[231]	train-rmse:0.44861	valid-rmse:0.45932
[232]	train-rmse:0.44861	valid-rmse:0.45931
[233]	train-rmse:0.44860	valid-rmse:0.45931
[234]	train-rmse:0.44858	valid-rmse:0.45930
[235]	train-rmse:0.44853	valid-rmse:0.45929
[236]	train-rmse:0.44850	valid-rmse:0.45929
[237]	train-rmse:0.44850	valid-rmse:0.45929
[238]	train-rmse:0.44850	valid-rmse:0.45930
[239]	train-rmse:0.44848	valid-rmse:0.45929
[240]	train-rmse:0.44846	valid-rmse:0.45929
[241]	train-rmse:0.44842	valid-rmse:0.45929
[242]	train-rmse:0.44833	valid-rmse:0.45926
[243]	train-rmse:0.44831	valid-rmse:0.45927
[244]	train-rmse:0.44830	valid-rmse:0.45926
[245]	train-rmse:0.44830	valid-rmse:0.45926
[246]	train-rmse:0.44818	valid-rmse:0.45918
[247]	train-rmse:0.44814	valid-rmse:0.45918
[248]	train-rmse:0.44814	valid-rmse:0.45917
[249]	train-rmse:0.44813	valid-rmse:0.45917
[250]	train-rmse:0.44812	valid-rmse:0.45917
[251]	train-rmse:0.44812	valid-rmse:0.45917
[252]	train-rmse:0.44808	valid-rmse:0.45918
[253]	train-rmse:0.44807	valid-rmse:0.45918
[254]	train-rmse:0.44805	valid-rmse:0.45918
[255]	train-rmse:0.44803	valid-rmse:0.45917
[256]	train-rmse:0.44803	valid-rmse:0.45917
[257]	train-rmse:0.44797	valid-rmse:0.45916
[258]	train-rmse:0.44792	valid-rmse:0.45917
[259]	train-rmse:0.44791	valid-rmse:0.45917
[260]	train-rmse:0.44790	valid-rmse:0.45916
[261]	train-rmse:0.44789	valid-rmse:0.45916
[262]	train-rmse:0.44787	valid-rmse:0.45914
[263]	train-rmse:0.44785	valid-rmse:0.45913
[264]	train-rmse:0.44785	valid-rmse:0.45913
[265]	train-rmse:0.44783	valid-rmse:0.45912
[266]	train-rmse:0.44773	valid-rmse:0.45909
[267]	train-rmse:0.44764	valid-rmse:0.45906
[268]	train-rmse:0.44754	valid-rmse:0.45901
[269]	train-rmse:0.44743	valid-rmse:0.45899
[270]	train-rmse:0.44737	valid-rmse:0.45899
[271]	train-rmse:0.44728	valid-rmse:0.45896
[272]	train-rmse:0.44727	valid-rmse:0.45896
[273]	train-rmse:0.44725	valid-rmse:0.45896
[274]	train-rmse:0.44723	valid-rmse:0.45896
[275]	train-rmse:0.44723	valid-rmse:0.45896
[276]	train-rmse:0.44720	valid-rmse:0.45896
[277]	train-rmse:0.44720	valid-rmse:0.45895
[278]	train-rmse:0.44713	valid-rmse:0.45894
[279]	train-rmse:0.44713	valid-rmse:0.45893
[280]	train-rmse:0.44701	valid-rmse:0.45888
[281]	train-rmse:0.44696	valid-rmse:0.45885
[282]	train-rmse:0.44689	valid-rmse:0.45884
[283]	train-rmse:0.44681	valid-rmse:0.45884
[284]	train-rmse:0.44675	valid-rmse:0.45883
[285]	train-rmse:0.44671	valid-rmse:0.45882
[286]	train-rmse:0.44656	valid-rmse:0.45877
[287]	train-rmse:0.44652	valid-rmse:0.45877
[288]	train-rmse:0.44644	valid-rmse:0.45876
[289]	train-rmse:0.44643	valid-rmse:0.45876
[290]	train-rmse:0.44628	valid-rmse:0.45872
[291]	train-rmse:0.44618	valid-rmse:0.45867
[292]	train-rmse:0.44609	valid-rmse:0.45865
[293]	train-rmse:0.44603	valid-rmse:0.45863
[294]	train-rmse:0.44588	valid-rmse:0.45853
[295]	train-rmse:0.44582	valid-rmse:0.45849
[296]	train-rmse:0.44575	valid-rmse:0.45846
[297]	train-rmse:0.44572	valid-rmse:0.45846
[298]	train-rmse:0.44563	valid-rmse:0.45842
[299]	train-rmse:0.44557	valid-rmse:0.45840
[300]	train-rmse:0.44551	valid-rmse:0.45839
[301]	train-rmse:0.44551	valid-rmse:0.45839
[302]	train-rmse:0.44549	valid-rmse:0.45838
[303]	train-rmse:0.44543	valid-rmse:0.45837
[304]	train-rmse:0.44539	valid-rmse:0.45835
[305]	train-rmse:0.44536	valid-rmse:0.45835
[306]	train-rmse:0.44530	valid-rmse:0.45833
[307]	train-rmse:0.44525	valid-rmse:0.45831
[308]	train-rmse:0.44515	valid-rmse:0.45827
[309]	train-rmse:0.44512	valid-rmse:0.45826
[310]	train-rmse:0.44512	valid-rmse:0.45826
[311]	train-rmse:0.44493	valid-rmse:0.45811
[312]	train-rmse:0.44485	valid-rmse:0.45808
[313]	train-rmse:0.44474	valid-rmse:0.45805
[314]	train-rmse:0.44463	valid-rmse:0.45802
[315]	train-rmse:0.44447	valid-rmse:0.45791
[316]	train-rmse:0.44441	valid-rmse:0.45789
[317]	train-rmse:0.44436	valid-rmse:0.45789
[318]	train-rmse:0.44434	valid-rmse:0.45789
[319]	train-rmse:0.44434	valid-rmse:0.45789
[320]	train-rmse:0.44433	valid-rmse:0.45789
[321]	train-rmse:0.44432	valid-rmse:0.45790
[322]	train-rmse:0.44432	valid-rmse:0.45790
[323]	train-rmse:0.44425	valid-rmse:0.45789
[324]	train-rmse:0.44422	valid-rmse:0.45789
[325]	train-rmse:0.44420	valid-rmse:0.45789
[326]	train-rmse:0.44419	valid-rmse:0.45789
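The run stops at round 326 because the validation RMSE had not improved for 10 consecutive rounds (early_stopping_rounds=10). Before moving to the test set, a quick sanity check of the early-stopped model on the validation split might look like this; a minimal sketch using the variables defined above, and accuracy at a 0.5 threshold is an added illustration, not part of the original notebook:

# check the early-stopped model on the held-out validation split
eval_preds = bst.predict(eval_data)   # predicted probabilities of is_duplicate = 1
eval_rmse = np.sqrt(np.mean((eval_preds - eval_label) ** 2))
eval_acc = np.mean((eval_preds > 0.5).astype(int) == eval_label)
print(f'validation RMSE: {eval_rmse:.4f}, accuracy at 0.5 threshold: {eval_acc:.4f}')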

Predicting on the test data

TEST_Q1_DATA = 'q1_test.npy'
TEST_Q2_DATA = 'q2_test.npy'
TEST_ID_DATA = 'test_ids.npy'


test_q1_data = np.load(open(path+ TEST_Q1_DATA, 'rb'))
test_q2_data = np.load(open(path+ TEST_Q2_DATA, 'rb'))
test_id_data = np.load(open(path+ TEST_ID_DATA, 'rb'))

test_input = np.stack((test_q1_data, test_q2_data), axis=1)  # each row: [[q1 vector], [q2 vector]]
test_data = xgb.DMatrix(test_input.sum(axis=1))
test_predict = bst.predict(test_data)

# To create the output directory if it does not exist:
# if not os.path.exists(path):
#     os.makedirs(path)

# Save the results (submission file)

output = pd.DataFrame({'test_id': test_id_data, 'is_duplicate': test_predict})
output.to_csv(path + 'simple_xgb.csv', index=False)

output.head(3)
test_id is_duplicate
0 0 0.211880
1 1 0.109923
2 2 0.399594
output['is_duplicate'].max()
0.9999127388000488
output['is_duplicate'].value_counts()
0.603274    211672
0.567154     81359
0.605096     46305
0.609642     42195
0.404560     37942
             ...  
0.356317         1
0.519661         1
0.556234         1
0.219877         1
0.235535         1
Name: is_duplicate, Length: 491424, dtype: int64
