RNN, Truncated BPTT, Time RNN

2021. 7. 27. 17:28IT Study/Ai

이번에는 자연어 처리에서 자주 사용되는 RNN들을 깊게 탐구했다. 

RNN이란 쉽게말해 순환되는 신경망, 시계열 데이처를 처리하는 신경망등이라고 할 수 있겠다. RNN은 같은 모델을 순환반복하면서 시계열 데이터를 처리하는 모델인 것이다.

이 RNN의 역전파를 끊어서 학습한 것이 Truncated BPTT이고 시계열 데이터를 학습하기 편하게하며 동시에 Truncate BPTT도 같이 이용할 수있도록 RNN을 여러개를 묶어 모델화한 것이 TimeRNN이다.


RNN

 

위에서도 말했듯 RNN은 시계열 데이터를 처리하는 신경망이라고 할 수 있다. 

 

시계열 데이터란 순서,순차가 있는 데이터라고 봐도 무방할 것 같다. 예를 들어 문장 "i eat apple"은 3개의 단어가 순차적으로 연결되어야 그 의미가 성립하는 시계열 데이터라고 할 수있다. 

 

그러면 거두절미하고 바로 RNN의 구조를 확인해 보자.

 

RNN의 입력값은 단어의 분산처리벡터이며 출력값은 총 두갈래로 나뉜다.(분기노드)

두갈래로나뉜 출력값은 같은 값이다. 여기서 2갈래로 나뉘는 걸 보고 역전파때 2개를 더해주는 걸 떠올리면 지금까지 인공지능공부를 잘한걸로 볼 수 있겠다..

 

그럼 RNN의 박스안 수식을 보자

h(t) = tanh( h(t-1)W(h) + x(t)W(x) + b)

t는 시계열데이터의 시각(순서)을 나타낸다. 

이 식에서 나타나는 가중치는 2개이다. 첫번 째는 입력값 합성곱 가중치인 W(x) 두번 째는 바로전 시계열 데이터 출력값인 h(t-1) 합성곱 가중치 W(h)이다. W(x)와W(h)는 완전히 다른 가중치니 헷갈리지 말자. 이 두값과 바이어스 b까지 더한값이 출력값이 된다.

 

RNN역전파

 

RNN의 역전파는 2가지만 생각하면 지금까지의 일반 신경망과 별 다를 것이 없어진다. 

하나는 분기노드의 기울기 덧셈이고 또 하나는 tanh의 역전파이다.

 

바로 역전파의 전체적인 모습을 봐보자.

 

먼저 분기노드에서온 dh(next)를 구할때 각각에서온 기울기를 더해준다.

tanh까지 역전파하게 되면 d(tanh) = dh(next) * (1 - dh(next) ** 2)  

이렇게 되면 덧셈노드 역전파 성질에 의해 db는 d(tanh)가 되고 MatMul계층으로 이동하게 된다.

dW(h) = d(tanh) * h(prev).T

dh(prev) = d(tanh) * W(h).T

dW(x) = d(tanh) * x.T

dx = d(tanh) * Wx.T 까지 도출이 완료된다. -> MatMul의 역전파는 간단히 표현하면 앞 역전파값의 반대노드값 역전파를 곱해준 값이다.

 

 

Truncated BPTT

이전 RNN을 완벽하게 이해 했으면 이 뒤부턴 조금의 응용만 있을 뿐이다.

 

Trnncated BPTT는 RNN의 몇 가지 문제점을 보완하기 위해 고안된 방법이다.

 

RNN은 시계열데이터를 처리하며 이전값을 계속 유지하고 역전파를 할 때 순전파한 모든 계층을 역으로 순회하게 된다.

이 구조는 기울기 소실이라는 문제점과 최적화의 문제가 존재한다.

 

먼저 기울기 소실의 문제는 아마 RNN의 구조를 충분히 이해 했으면 알거라 생각된다. 만약 시계열데이터가 하나의 글이고 글에 포함된 단어의 수가 10000개가 넘어간다 생각해보자. 그렇게되면 역전파시 글의 첫번째 단어의 가까운 단어들은 오차의 대한 기울기가 0으로 수렴하게 될 것이다.

 

최적화의 문제는 RNN을 사용하기위해 TimeRNN을 직접 구현하다보면 순전파 때 사용된 RNN계층들을 차곡차곡 모아두는데 윗 경우를 생각하면 10000개의 RNN계층이 저장되는 매우 비효율적인 메모리 사용이 나오는 것이다.

 

그렇다면 이 문제를 해결하기 위한 방법이 무엇일까?

정답은 RNN의 역전파를 일정단위씩 끊어서 계산하는 방법이다.

그렇게 하기 위해서는 일정단위마다 오차를 다시 계산 해줘야한다. 하지만 이정도는 충분히 감수할만한 범위 내이다. 

물론 순전파는 끊지않고 이어줘야한다.

 

한번 전체적은 Truncated BPTT의 모습을 보자.

위 모델은 10개씩 끊어서 역전파를 해준 모습이다. 

 

Time RNN

 

Time RNN은 Truncated BPTT의 적용을 모토로 RNN의 학습을 용이하게 만든 모델이다. 한 뭉치의 시계열데이터의 크기만큼 RNN계층을 묶어준 것이 TimeRNN의 기본적인 모습이다.

 

한번 전체적인 모습을 보자.

여기서 RNN은 하나의 부품들이다. 이 부품들이 여러개 모이고 동작하는 것이 TimeRNN의 모습인 것이다.

 

TimeRNN은 그모습에서 볼 수 있듯이 입력값과 출력값이 여러개의 데이터로 이루어 진 것을 볼 수 있다. 당연히 여러개의 RNN이 있으면 그 것에 들어갈 여러개의 인풋값이 있을 것이고 그로인해 출력값도 맞춰질 것이다.

(여기서 xs형상은 2차원이고 배치처리까지 고려하면 3차원이된다.)

 

TimeRNN은 순전파시 클래스 내부에 사용된 모델을 차곡차곡 저장해 두게 된다. 그 이유는 역전파 때 역순회를 용이하게 하기위함이다.

실제로 구현부 파트는 순전파때는 RNN모델을 받아와서 forward함수를 실행해주고 자신을 append하는 것이 기본적인 모습이고 역전파 때는 쌓인 모델을 역순회하며 backward함수를 실행해주며 가중치를 갱신하고 dxs를 리턴하는것이 끝이다.

윗 두개 RNN과 Truncated BPTT를 잘 이해했다면 되게 직관적인 로직일 것이다.

 

하나 유의할 점이 있다면 hs와 xs는 여러개의 데이터가 묶인 형태이다.(당연하지만) 그래서 역전파 때 역순회시 하나하나 대응이 잘되게 하면 된다. 특히 미니배치가 포함되면 배열의 인덱스에 주의해야한다.(3차원이 되어 복잡해짐)

 

출처 : 밑바닥부터 시작하는 딥러닝2

'IT Study > Ai' 카테고리의 다른 글

Seq2Seq 개선 (Reverse, Peeky, Attention-어텐션)  (0) 2021.08.14
seq2seq(Sequence to Sequence)  (0) 2021.08.04
LSTM(Long Short-Term Memory)  (0) 2021.07.30
RNNLM(RNN Language Model)  (0) 2021.07.28
자연어처리 Word2Vec 개념  (0) 2021.07.18