AI/LLM

Titans: Learning to Memorize at Test Time

Tech리 2025. 2. 12. 20:32

현재 계속 세상을 놀라게 하고 있는 GPT, Claude 등의 LLM 들은 많은 장점을 가지고 있지만, 그렇다고 완벽하다고 보기는 아직 어렵습니다. LLM의 단점 중 하나로 입력되는 Context의 길이가 증가할수록 이를 처리하기 위한 계산 및 메모리 비용은 기하급수적으로 증가한다는 문제가 있습니다.

 

그렇기에, Mamba 와 같은 모델이 주목 받기도 했는데요.(Mamba에 대해서는 따로 다루도록 하겠습니다.) 일부에서 Mamba와 같은 모델들이 문맥을 압축하는 과정에 중요한 세부 정보를 놓치는 문제가 있다고 지적하기도 합니다.

 

그래서 최근 Google 에서는 '트랜스포머' 기반 LLM에 ‘신경 기억(neural memory)' 레이어를 추가, 모델이 단기와 장기 기억 작업을 모두 효율적으로 처리할 수 있도록 설계된 새로운 신경망 아키텍처 ‘타이탄(Titans)’에 관한 논문을 발표했습니다. 

 

해당 논문에서는 현재의 LLM이 가진 문제를 다음과 같이 서술하고 있습니다. 

  • 순환 모델(Recurrent Models): 데이터를 고정 크기의 메모리(히든 상태)로 압축하여 저장하려고 합니다. 그러나 이 방식은 모든 데이터를 효과적으로 기억하기 어렵습니다.
  • 주의 메커니즘(Attention): 전체 문맥(context) 창을 참조하며 모든 토큰 간의 직접적인 의존성을 잘 모델링합니다. 하지만 이 과정은 계산 비용이 매우 높아(이차적 복잡도) 문맥 창 크기를 고정해야 하는 한계가 있습니다.

특히, 언어 모델링, 비디오 이해, 장기 시계열 예측 등과 같은 복잡한 실제 작업에서는 문맥 윈도우가 매우 커질 수 있어 트랜스포머의 적용에 한계가 있습니다. 이를 해결하기 위해서 softmax 대신 커널 함수를 어텐션에 적용한 다양한 Linear Transformer의 변형들을 제안하며 메모리 사용량을 크게 줄였습니다. 그러나 효율성과 긴 문맥 확장 능력에도 불구하고, 선형 트랜스포머는 커널 기법으로 인해 데이터를 행렬 형태의 상태로 압축하는 선형 순환 네트워크가 되어 전통적인 트랜스포머보다 성능 면에서 뒤처집니다 즉, 선형 모델은 확장성과 효율성(선형 대 제곱 복잡도)을 높여 긴 문맥 처리에 유리하지만, 매우 긴 문맥을 작은 벡터나 행렬 상태에 압축하는 데 한계가 있습니다.

 

그래서 이를 개선하고자 Titan 논문에서는 다음 다섯 가지 질문을 해결하고자 했습니다. 

 

  • 효과적인 기억 구조는 무엇인가?
  • 적절한 기억 업데이트 메커니즘은 무엇인가?
  • 효율적인 기억 검색 방법은 무엇인가?
  • 서로 연결된 다양한 기억 모듈을 효과적으로 통합하려면 어떻게 해야 하는가?
  • 긴 과거를 저장하기 위해 깊은(딥) 기억 모듈이 필요한가?

 

이를 위해 본 연구는 테스트 시 데이터를 자신의 파라미터에 저장하는 장기 신경 기억 모듈을 설계하고, 이를 Titans라는 새로운 딥러닝 아키텍처에 통합합니다. Titans는 단기 기억, 장기 기억, 그리고 작업 지식을 인코딩하는 지속 기억으로 구성되며, 다양한 작업(언어 모델링, 상식 추론, 시계열 예측, DNA 모델링 등)에서 기존 모델들보다 우수한 성능과 뛰어난 확장성을 입증합니다.

 

1) 장기 기억 메모리

모델이 테스트 시점에서 직접 데이터를 기억하고 잊으며 필요할 때 정보를 검색할 수 있도록 하는 장기 신경 기억(Long term neural memory) 모듈을 제안하는데 이때 단순히 학습데이터 자체를 암기하는 것이 아니라 테스트 상황에서 새로운 데이터에 대해 어떻게 기억하고 필요 없는 정보를 잊어버릴지 스스로 조절하는 메타-모델을 구축하는 것이 핵심입니다. 그래서 아래와 같은 아이디어 등을 사용했습니다. 

 

(1) 서프라이즈 메트릭 

인간의 경우, 예상과 다른(놀라운) 사건이 더 강하게 기억되는 것에서 착안하여, 입력에 대한 기울기(gradient)를 서프라이즈로 정의하는데 기울기가 클수록 이전 데이터와 크게 다르므로 더 기억할 가치가 있다고 판단합니다. 이 때 기본 업데이트 방식은 아래와 같이 정의되는데, (l은 메모리 모듈이 목표로 하는 손실 함수)

기본 업데이트 방식은 한 번의 놀람 이후 기울기가 작아져서 이후 정보가 문제를 해결하기 위해 최근의 놀람 정보를 유지하는 역할인 과거 서프라이즈 (Past Surprise) 및 현재 입력에 대한 놀람 정도를 반영한 순간 서프라이즈 (Momentary Surprise)을 반영하고 이를 모멘텀 개념과 결합하여 아래와 같이 업데이트 합니다. 

 

  • ηt\eta_t: 데이터에 따라 이전 서프라이즈의 영향을 얼마나 유지할지를 조절하는 감쇠(decay) 계수
  • θt\theta_t: 현재 입력의 서프라이즈를 얼마나 반영할지 결정하는 계수

 

(2) 목표 함수 및 키-값 연관 학습

  • 입력 xtx_t에 대해 두 개의 선형 층을 사용하여 키 kt=xtWKk_t = x_t W_K와 값 vt=xtWVv_t = x_t W_V를 생성합니다.
  • 손실 함수는 메모리 Mt−1M_{t-1}가 키 ktk_t에 대해 예측한 값과 실제 값 vtv_t 간의 차이를 L2 노름으로 측정합니다. ℓ(Mt−1;xt)=∥Mt−1(kt)−vt∥22\ell(M_{t-1}; x_t) = \| M_{t-1}(k_t) - v_t \|_2^2
  • 이 inner-loop 최적화를 통해, 메모리 모듈은 테스트 시점에서 키와 값 사이의 연관성을 학습하게 됩니다.

(3) 망각(Forgetting) 메커니즘

  • 매우 긴 시퀀스(수백만 토큰 등)에서 모든 정보를 유지하는 것은 메모리 한계와 성능 저하를 야기할 수 있으므로, 불필요한 정보를 선택적으로 잊어버리는 기작이 필요합니다.
  • 이를 위해, 업데이트 식에 forgetting gate 역할을 하는 계수 αt\alpha_t를 도입하여, Mt=(1−αt)Mt−1+StM_t = (1 - \alpha_t) M_{t-1} + S_t
    • αt\alpha_t 값에 따라 이전 메모리의 정보 일부를 제거하거나, 상황에 따라 전체를 잊어버릴 수도 있습니다.

(4) 메모리 아키텍처 및 정보 검색

  • 메모리 구조:
    • 본 논문에서는 간단한 MLP(다층 퍼셉트론) 구조(최소 LM≥1L_M \ge 1층)를 사용하지만, 깊은(딥) 메모리 모듈(최소 2층 이상)이 선형 모델보다 표현력이 뛰어나다는 이론적 근거가 있습니다.
  • 정보 검색:
    • 테스트 시, 메모리에서 정보를 검색할 때는 추가적인 가중치 업데이트 없이 단순한 순전파 과정을 통해, 입력 xtx_t를 선형 층 WQW_Q로 쿼리 qtq_t로 변환한 후, yt=M∗(qt)y_t = M^*(q_t) 를 통해 관련 정보를 추출합니다.

 

2) Parallelize the Long-term Memory Training

 

위에서 제안한 장기 메모리 모듈의 경우 gradient descent(모멘텀과 weight decay 포함)를 통해 연관 기억 손실(associative memory loss)을 최적화하는 메타 모델과 동일합니다. 이 때 이론상 전체 시퀀스 길이 NN에 대해 O(N)O(N) FLOPs가 필요하지만, 실제 하드웨어(TPU, GPU 등)의 이점을 살리기 위해 병렬화와 텐서화가 요구됩니다.

 

 

(1) 미니배치 기반 업데이트 재구성:

  • 시퀀스를 크기 b의 청크(chunk)로 나누어 미니배치 gradient descent로 업데이트를 수행합니다.
  • 업데이트 식은 Mt=(1−αt)Mt−1−θt∇ℓ(Mt−1;xt)M_t = (1 - \alpha_t) M_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t) 를 재구성하여, 누적 곱(βt\beta_t)과 합(sum)을 이용한 표현으로 바꿀 수 있습니다.
  • 손실 함수의 기울기(gradient)는 여러 토큰에 대해 동시에 계산할 수 있으며, 이를 통해 업데이트를 행렬 곱셈(matmul)과 덧셈만으로 구현할 수 있습니다.

(2) 모멘텀 항의 병렬 계산:

  • 모멘텀 항 St=ηtSt−1−θtutS_t = \eta_t S_{t-1} - \theta_t u_t (여기서 ut=∇ℓ(Mt′;xt)u_t = \nabla \ell(M_{t'}; x_t)) 역시 선형 순환(recurrence) 형태로 나타나며,
  • 병렬 연관 스캔(parallel associative scan) 기법을 사용하면, 각 청크 내에서 모든 StS_t를 동시에 계산할 수 있습니다.

(3) 청크 단위의 파라미터 설정:

  • 원래 각 토큰마다 입력 의존적인 αt\alpha_t, θt\theta_t, ηt\eta_t 값을 사용하지만, 이를 각 청크 단위의 상수로 설정할 수도 있습니다.
  • 이렇게 하면 표현력은 다소 줄어들지만, 저장해야 하는 파라미터가 감소되고 전체 업데이트가 선형 시불변 시스템(LTI) 형태가 되어 전역 합성곱(global convolution)으로도 계산할 수 있어 학습 속도를 더욱 향상시킬 수 있습니다.
  • 실험에서는 토큰 단위의 데이터 의존성을 사용하지만, 청크 단위의 단순화는 향후 대형 모델 학습에 유용한 연구 방향이 될 수 있습니다.

(4) MAC (Memory as a Context) 아키텍처

 

위의 내용을 종합해서 아래와 같은 MAC 아키텍처를 구성하는데 MAC 아키텍처는 하드웨어 가속기(TPU, GPU 등)의 병렬 처리 능력을 최대한 활용하면서도, 긴 시퀀스에 대한 메모리 업데이트를 효과적으로 수행할 수 있게 합니다.
또한, MAC 아키텍처는 단순히 위의 병렬화 기법에 국한되지 않고, 코어(Core) 브랜치지속적(Persistent) 메모리 브랜치를 포함하여 인컨텍스트 학습과 작업별 지식 저장 기능을 제공하는 포괄적인 시스템으로 구성되어 있습니다.

 

 

 

3) Persistent Memory

문맥 기반(long-term) 메모리는 입력에 의존하는 반면, persistent memory는 입력과 무관하게 학습 가능한 파라미터 집합으로, 작업 관련 지식(메타 정보)을 저장하는 역할을 합니다. 이를 통해 모델이 단순히 현재 문맥뿐 아니라, 과거에 축적된 작업 수행 방법이나 추상적 지식을 활용할 수 있게 됩니다.

 

(1) 구현 방법

  • 학습 가능한 파라미터 P=[p1,p2,…,pNp]P = [p_1, p_2, \ldots, p_{N_p}]를 정의한 후, 이를 시퀀스의 시작 부분에 **단순 연결(Concatenation)**하여 입력 데이터를 확장합니다.
  • 예를 들어, 원래 입력 xx가 있다면, 새로운 입력은 xnew=[p1,p2,…,pNp]∥xx_{\text{new}} = [p_1, p_2, \ldots, p_{N_p}] \parallel x 와 같이 구성됩니다.

How to Incorporate Memory?

이부분에서는 설계된 신경 기억 모듈(장기 메모리)을 딥러닝 아키텍처에 효과적이고 효율적으로 통합하는 방법을 제시합니다. 기존 Transformer의 키-값 어텐션은 현재 문맥(단기 메모리)만 반영하는 반면, 제안하는 신경 기억 모듈은 데이터로부터 지속적으로 학습하여 장기 메모리 역할을 수행합니다. 

 

(1) Memory as a Context

  • 구조 및 동작 방식:
    • 긴 시퀀스를 일정 크기의 세그먼트로 분할합니다.
    • 현재 세그먼트 S(t)S^{(t)}를 현재 문맥으로 보고, 이전 장기 메모리 상태 Mt−1M_{t-1}를 이용해 관련 정보를 검색합니다.
      • 검색된 정보 ht=Mt−1∗(qt)h_t = M^*_{t-1}(q_t)에서 쿼리 qt=S(t)WQq_t = S^{(t)} W_Q 사용.
    • 검색된 역사적 정보와 학습된 지속적 메모리 파라미터들을 현재 세그먼트에 연결(Concatenation)하여 어텐션 모듈에 전달합니다.
      • S~(t)=[p1,p2,…,pNp]∥ht∥S(t)\tilde{S}^{(t)} = [p_1, p_2, \ldots, p_{N_p}] \parallel h_t \parallel S^{(t)}
    • 어텐션 결과 yty_t를 사용해 다음 세그먼트를 위한 메모리 MtM_t를 업데이트하고 최종 출력을 생성합니다.
  • 장점:
    1. 현재 데이터와 과거 정보를 함께 고려하여, 필요한 경우에만 장기 메모리 정보를 활용할 수 있음.
    2. 어텐션 모듈이 중요한 정보만 선택하여 저장함으로써 메모리 용량을 효율적으로 관리함.
    3. 테스트 시, 지속적 메모리 파라미터는 고정되고, 어텐션은 인컨텍스트 학습을 수행하며, 장기 메모리는 계속 학습(기억)됨.

(2) Gated Memory

  • 구조 및 동작 방식:
    • 두 개의 분기(branch)를 사용합니다.
      • 한 분기에서는 입력 데이터를 그대로 사용해 장기 메모리를 업데이트하고,
      • 다른 분기에서는 sliding window attention (SWA)과 prefix를 사용합니다.
    • 두 분기의 결과는 비선형 게이팅(예: 정규화 후 비선형 활성화)을 통해 결합됩니다.
  • 특징:
    • Sliding window attention은 정밀한 단기(짧은 문맥) 메모리 역할을 수행하고, 신경 기억 모듈은 점차 희미해지는(fading) 장기 메모리 역할을 합니다.
    • 서로 다른 구조의 헤드를 사용하는 다중 헤드 아키텍처 형태로 해석할 수 있습니다.


(3) Memory as a Layer

  • 구조 및 동작 방식:
    • 입력 데이터에 지속적 메모리 파라미터를 앞에 붙인 후, 이를 신경 기억 모듈(Memory as a Layer, MAL)에 전달합니다.
      • x~=[p1,p2,…,pNp]∥x\tilde{x} = [p_1, p_2, \ldots, p_{N_p}] \parallel x
    • MAL의 출력을 다시 sliding window attention에 전달하여 최종 출력을 생성합니다.
  • 단점:
    • 각 층의 성능에 의존하기 때문에, 어텐션과 신경 기억 모듈의 상호 보완적 처리 능력을 온전히 활용하지 못할 수 있음.
  • 추가 변형:
    • Memory Without Attention:
      • 어텐션 없이 신경 기억 모듈 자체만으로 시퀀스 모델링을 수행하는 방식도 제안됩니다.
      • 이는 기억 체계의 각 구성 요소가 독립적으로 작동할 수 있음을 보여줍니다.


(4) Architectural Details 및 추가 사항

  • 구현 세부사항:
    • 모든 블록에 Residual Connection 사용.
    • SiLU 활성화 함수와 ℓ2-노름을 활용해 쿼리와 키를 정규화함.
    • Query, Key, Value 투영 후 1D Depthwise-Separable Convolution을 도입해 성능 개선 및 계산 효율을 도모함.
    • 최종 출력 전 Linear Layer를 이용한 게이팅 및 정규화 적용.
  • Theorem 4.1:
    • 기존 Transformer, 대각선 선형 순환모델, DeltaNet 등이 TC0 (상수 깊이 회로)로 제한되는 반면, Titans는 TC0를 넘어선 문제 해결이 가능하여 이론적으로 더 표현력이 높음을 증명합니다.

장기 신경 기억 모듈을 딥러닝 아키텍처에 통합하는 세 가지 주요 설계(문맥으로서의 메모리, 게이트드 메모리, 계층으로서의 메모리)와 추가 변형(어텐션 없이 동작하는 메모리) 및 구현 세부사항을 소개합니다. 이를 통해 모델이 단기 및 장기 메모리의 장점을 모두 활용하여, 긴 문맥 처리 및 상태 추적 등에서 기존 모델보다 우수한 성능을 발휘할 수 있도록 합니다.

 

Titan 아키텍처는 긴 문맥 처리의 한계를 극복하고, Transformer와 기존 순환 모델을 능가하는 혁신적인 성능을 보여줍니다. MAC, MAG, MAL 등 다양한 변형을 통해, Titan은 모멘텀, weight decay, 깊은 메모리 업데이트 기법을 효과적으로 적용하여 언어 모델링, 시계열 예측, DNA 모델링 등 여러 분야에서 우수한 결과를 기록했습니다. 특히 2M 토큰 이상의 초장기 문맥에서도 안정적인 정확도와 효율적인 메모리 관리를 달성하며, 대규모 데이터 처리의 새로운 가능성을 제시합니다. 이러한 연구 결과는 앞으로 메모리 기반 딥러닝 모델이 인공지능의 다양한 실용적 문제를 해결하는 데 큰 역할을 할 것임을 시사합니다. 여러분도 Titan의 발전과 향후 연구 동향에 주목해 주시기 바라며, 이 글이 이해하는 데 도움이 되었기를 바랍니다.

 
 

 

'AI > LLM' 카테고리의 다른 글

DeepSeek R1  (0) 2025.02.12