이번 포스팅에서는 딥러닝 모델의 연속 학습(Continual Learning)에 대하여 다뤄보고자 한다.
연속 학습(Continual Learning)이란?
인공지능의 연속학습(Continual Learning)이란 딥러닝 모델이 새로운 데이터를 기반으로 지속적으로 학습하는 방법을 말한다. 일반적인 딥러닝 모델은 큰 규모의 데이터셋으로 학습하고, 해당 데이터셋을 기반으로 일반화된 패턴을 학습한다. 그러나 실제 환경에서는 새로운 데이터가 지속적으로 발생하고, 이는 기존의 데이터와는 다를 가능성이 높다. 연속학습은 이러한 기존 딥러닝의 문제점을 보완하고자 새로운 데이터에 대해 지속적으로 학습하고, 학습한 지식을 점진적으로 확장해나간다.
이 사진에서 보다시피 데이터는 연구의 방향이나 시장의 수요에 따라서 클래스가 분화된다. 클래스가 점차 분화되고 새로운 클래스가 등장함에 따라 새로운 데이터에 대해 지속적으로 학습할 수 있는 인공지능 모델이 필요해지게 된 것이다. 새로운 데이터가 나올 때마다 처음부터 다시 훈련시키는 것 보다는, 이미 학습된 모델에 새로운 데이터만 추가하는게 더 효율적이고 경제적이라고 할 수 있기 때문이다.
연속 학습의 가장 큰 문제점 : Catastrophic Forgetting
이 연속학습의 가장 큰 문제점은 바로 이 치명적 망각(Catastrophic Forgetting)이다. 이를 이해하기 위해서 다음 그림을 살펴보자.
위 사진은 바로 대표적인 데이터셋 중 하나인 MNIST 데이터셋을 5개의 task로 나눈 것이다. 첫 번째 task1에서는 0과 1을 구분하도록 모델을 학습시킨다. 그 다음, 이 학습된 모델에 2와 3을 구분하도록 또 다시 학습시킨다. 이런 방식으로 앞에서 학습된 모델에 연속적으로 새로운 task를 훈련시키면 가장 큰 문제점은 앞에서 훈련되었던 내용을 잊어버린다는 것이다. 8과 9를 구분하는 마지막 task5를 훈련시키고 나면, 가장 처음에 훈련되었던 0과 1을 분류하는 task1은 정확도가 상당히 낮아진다. 즉, 점진적으로 모델을 학습해감에 따라 앞에서 학습되었던 내용을 점차 잊어버리는 것이다. 이를 연속학습에서는 치명적 망각(Catastrophic forgetting)이라고 부르며, 이를 해결하는 것이 연속학습의 핵심이라고 할 수 있다.
연속 학습의 대표적인 방법
연속 학습 방법은 크게 Replay method, Regularization-based method, Parameter isolation method로 분류할 수 있다.
1. Replay method
Replay 방법은 이전 작업의 데이터를 저장하고, 새로운 task에 대한 학습에서 이전 데이터를 다시 사용하는 방식이다. 이전 데이터를 다시 사용하는 방식은 버퍼에 일부 데이터를 저장(real data replay)하거나 생성모델을 통해 데이터를 생성(pseudo replay)하여 사용하는 방식이 있다. 이 방식은 연속학습의 문제점이었던 Catastrophic Forgetting을 어느 정도는 해결할 수 있으나 이전의 데이터를 메모리에 넣어 주어야 하기 때문에 많은 양의 메모리 공간이 요구된다는 점, 실제로 순차적 학습을 할 때 다시 과거 데이터에 접근할 수 있다는 보장이 없다는 점에서 여전히 문제점이 존재한다.
Generative Replay(2017)
대표적인 Replay 방식으로는 Generative Replay가 있다. Generative Replay는 새로운 task를 학습할 때 생성모델을 통해 데이터를 생성하여 이전 데이터를 다시 사용하는 방식을 택한다.
Generative Replay는 Generator와 Solver의 구조를 가지고 있다. Generator는 이전에 학습했던 데이터를 재생산하고 Solver는 이 재생산된 데이터와 진짜 데이터를 이용해 분류하는 역할을 하도록 학습된다. 이런 구조 전체를 Scholar라고 하고 각 task마다 Scholar의 구조를 순차적으로 학습한다.
2. Regularization-based method
다음으로는 Regularization-based 방식이 있다. 이 방식은 모델의 파라미터를 조절하여 이전 작업에 대한 지식을 보존하는 방식이다. 즉, 모델의 성능에 영향을 주는 활성화 함수(activation function), 옵티마이저 (Optimizer), 학습률 (Learning Rate) 등의 파라미터의 impact를 계산하여 이전 작업의 중요한 impact를 주는 파라미터는 이후의 학습에서도 보호하겠다는 것이다.
LwF(Learning Without Forgetting, 2017)
Regularization-based method의 대표적인 모델로는 LwF(Learning Without Forgetting, 2017)가 있다. LwF는 각 stage의 학습을 시작하기 전에 현재 stage의 모든 데이터에 대해 이전 stage에서 학습이 완료된 모델의 feed-forward logit(LwF-logit)을 미리 계산하고, 각 데이터의 label과 LwF-logit을 이번 stage 학습에 활용한다. Label은 새로운 학습을 위해 사용되고, LwF-logit은 과거의 데이터를 보존하는데 사용된다.
위 그림은 다른 multi-task learning method(b, c, d)와 LwF(e)를 비교해놓은 그림이다. 이 그림에서는 각 부분의 색깔을 잘 확인할 필요가 있다. 주황색은 random initialize + train, 하늘색은 fine-tune, 흰색은 unchanged를 의미한다. 따라서 무엇이 학습되었는지, 학습이 되지 않았는지를 구분할 수 있다. 하나씩 살펴보자.
- (b) Fine-tuning : 네트워크의 backbone과 새로운 task에 대하여 학습시키는 방식이다. 이 방식의 문제점은 이전에 학습된 파라미터(그림의 하얀색 부분)의 guidance 없이 공유된 파라미터들을 업데이트 시키기 때문에 이전에 학습된 task의 성능을 저하시킨다는 것이다.
- (c) Feature Extraction : Fine-tuning과는 달리, 사전 학습된 네트워크의 backbone을 다시 학습하는 대신에, 새로운 task에 더 많은 layer을 추가하고 이 branch만 학습한다. backbone과 공유된 파라미터를 업데이트 시키지 않기 때문에, 새로운 task의 특징을 잘 표현하기에 성능이 떨어지는 경우가 많다.
- (d) Joint Training : 이 방식은 새로운 task를 위한 새로운 브랜치를 추가하고 전체 네트워크를 다시 학습한다. 다시 말해서, 모든 task의 데이터를 동시에 함께 사용하여 학습한다는 것이다. 정확성 측면에서는 가장 효율적인 방법으로 보일 수 있다. 하지만 더 많은 task가 추가되면 점점 더 학습하는 데 번거로울 수 있으며, 이전 작업의 학습 데이터를 사용할 수 없는 상황에서는 적절하지 않을 수 있다는 문제점이 있다.
- (e) LwF : LwF 방식은 새로운 task를 위하여 branch를 추가하지만, 새로운 task에 대한 학습을 진행할 때에는 이전 task의 데이터를 사용하지 않는다. 대신 Knowledge Distillation(지식 증류) 기법을 사용하는데, 지식 증류는 큰 규모의 미리 trained된 모델(선생님)로부터 작은 모델(학생)로 지식을 전달하는 기법이다. LwF에서는 이전 task에서 학습된 모델을 "선생님"으로 정의하고, 새로운 task를 위해 학습 중인 모델을 "학생"으로 정의한다.
3. Parameter isolation method
이 방식은 용어 그대로 task 간의 파라미터를 분리하여 각 task의 파라미터를 독립적으로 학습하는 방식이다. 이를 통하여 task 간의 간섭을 줄이고 각 task에 대한 성능을 개별적으로 관리할 수 있다.
PNN(Progressive Neural Network, 2016)
첫 번째 column은 task1, 두 번째 column은 task 2에 학습되었다. 세 번째 column은 마지막 task를 위하여 이전에 학습된 모든 특성에 접근할 수 있도록 추가된 column이다. 이 모델은 task가 늘어날 때마다 column을 추가한다. 이전의 column들은 새로운 데이터에 학습되지 않고 고정된다. 이렇게 되면 새로운 task가 추가되더라도 이전의 파라미터는 변경되지 않으며 독립적으로 학습된다는 것이 특징이다.