본문 바로가기
기계학습/자연어 처리 머신러닝

LSTM 실습 - 로이터 뉴스 분류기

by tryotto 2020. 2. 12.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from tensorflow.keras.datasets import reuters
import numpy as np
import seaborn as sns
 
(X_train, Y_train),(X_test, Y_test) = reuters.load_data(num_words=1000, test_split=0.2)
 
 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM, Embedding
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
 
 
 
# 데이터 전처리 - padding 연산
vocab_size = 1000
max_len = 100
X_train = pad_sequences(X_train, maxlen=max_len)
X_test = pad_sequences(X_test, maxlen=max_len)
 
 
 
# 데이터 전처리 - 라벨을 원핫 인코딩 처리
Y_train = to_categorical(Y_train)
Y_test = to_categorical(Y_test)
 
 
 
# 모델 설계
model = Sequential()
model.add(Embedding(vocab_size, 120))
model.add(LSTM(120))
model.add(Dense(46, activation='softmax'))
 
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
history = model.fit(X_train, Y_train, batch_size=100, epochs=5, validation_data=(X_test, Y_test))
cs



RNN과 마찬가지로, 큰 어려움 없이 학습을 시킬 수 있었다




1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# 직접 테스트 해보기
import random 
 
test_len = len(X_test)
idx = random.randint(0, test_len-1)
 
little_test_X = X_test[idx:idx+1]
little_test_Y = Y_test[idx:idx+1]
 
real_label = 0
tmp_idx = -1
for v in little_test_Y[0]:
  tmp_idx +=1
  if v == 1:
    real_label = tmp_idx
    break
 
# 예측 라벨 얻기
predict_label = model.predict(little_test_X)
 
max_val = -1
max_idx = 0
tmp_idx = -1
for val in predict_label[0]:  
  tmp_idx += 1
  if max_val < val:
    max_val = val
    max_idx = tmp_idx
 
 
# 원문으로 변환하기
word_to_index = reuters.get_word_index()
index_to_word={}
 
for word, index in word_to_index.items():
  index_to_word[index] = word
 
rst_string ="원문 : "
for i in range(max_len):
  tmp_idx = little_test_X[0][i]
  rst_string += (index_to_word[tmp_idx]+" ")
 
print(rst_string)
 
# 라벨 값 출력
print("predicted label : ", max_idx)
print("real label : ", real_label)
 
cs


직접 테스트 할 수 있는 코드도 추가시켜보았다
라벨 값도 원문으로 바꾸고싶은데, 애초에 처음부터 인코딩이 정수로 되어있어서 힘들어 보인다