

๐Ÿ“š Paper: "Attention Is All You Need"

๐Ÿ’ป GitHub: https://github.com/jadore801120/attention-is-all-you-need-pytorch

 


"Attention Is All You Need", which proposed the Transformer architecture, is one of the most interesting papers published in 2017.

It has held world-leading records in various machine translation competitions,

and the latest language AI models such as BERT and GPT are all built on the Transformer architecture.


 

How did the Transformer come about?

  • RNN์˜ long-term dependency๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ๋…ธ๋ ฅ์œผ๋กœ LSTM, GRU ๋„์ž…
  • Seq2Seq์—์„œ ๋จผ ๊ฑฐ๋ฆฌ์— ์žˆ๋Š” ํ† ํฐ ์‚ฌ์ด์˜ ๊ด€๊ณ„๋ฅผ ๋ชจ๋ธ๋ง(hidden state)ํ•˜๊ธฐ ์œ„ํ•ด Attention๊ณผ ๊ฐ™์€ ์•Œ๊ณ ๋ฆฌ์ฆ˜ ๋“ฑ์žฅ
  • Attention์„ ํ†ตํ•ด ๋จผ ๊ฑฐ๋ฆฌ์— ์žˆ๋Š” ๋ฌธ๋งฅ ์ •๋ณด๋„ ๊ฐ€์ง€๊ณ  ์™€์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋˜์—ˆ์ง€๋งŒ
    ์‹œํ€€์Šค ํ† ํฐ์„ ํƒ€์ž„์Šคํ…๋ณ„๋กœ ํ•˜๋‚˜์”ฉ ์ฒ˜๋ฆฌ(์ฆ‰, Sequentialํ•œ ์ž…๋ ฅ๊ฐ’์ด ์ฃผ์–ด์ ธ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ๋ถˆ๊ฐ€๋Šฅ)ํ•ด์•ผ ํ•˜๋Š” RNN์˜ ํŠน์„ฑ ์ƒ, ๋Š๋ฆฌ๋‹ค๋Š” ๋‹จ์  ์—ฌ์ „ํžˆ ํ•ด๊ฒฐ๋˜์ง€ ์•Š์Œ
    • ์‹œํ€€์Šค ํ† ํฐ์„ ํƒ€์ž…์Šคํ…๋ณ„๋กœ ํ•˜๋‚˜์”ฉ ์ฒ˜๋ฆฌ =  ๋ชจ๋“  ํƒ€์ž„ ์Šคํ…์—์„œ hidden์ด ๊ณ„์‚ฐ๋œ ํ›„์—์•ผ attention์ด ์ง„ํ–‰ ๊ฐ€๋Šฅ
      Sequentialํ•œ ์ž…๋ ฅ๊ฐ’์ด ์ฃผ์–ด์ ธ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ๋ถˆ๊ฐ€๋Šฅ 

 

RNN์˜ ๋‹จ์ ์ธ ๋Š๋ฆฌ๋‹ค๋Š” ๊ฒƒ๊ณผ long-term dependency ๋ชจ๋ธ๋ง ์–ด๋ ต๋‹ค๋Š” ์ ์„ ๋™์‹œ์— ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์œผ๋กœ ๊ณ ์•ˆ๋œ ๊ฒƒ์ด ๋ฐ”๋กœ Transformer์ž…๋‹ˆ๋‹ค. Seq2Seq๋ชจ๋ธ์— RNN์„ ์ œ๊ฑฐํ•˜๊ณ   "Attention Mechanism"์œผ๋กœ ์—ฐ๊ฒฐํ•˜๊ฒŒ ๋œ๋‹ค.
๋‹ค์‹œ ๋งํ•ด, 

(1) RNN์—†์ด ์ž…๋ ฅ๋œ ์‹œํ€€์Šค์— ์žˆ๋Š” ์ •๋ณด๋ฅผ ์ž˜ ๋ชจ๋ธ๋งํ•˜๊ณ 
(2) ์ฃผ์–ด์ง„ ๋ฌธ๋งฅ ๋‚ด์˜ ๋ชจ๋“  ์ •๋ณด๋ฅผ ๊ณ ๋ คํ•ด ์ž์—ฐ์–ด ํ† ํฐ๋“ค์˜ ์ •๋ณด๋ฅผ ๋ชจ๋ธ๋งํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ,

(3) ๋ณ‘๋ ฌ์ฒ˜๋ฆฌ๊ฐ€ ๊ฐ€๋Šฅํ•ด ์†๋„๊ฐ€ ๋น ๋ฅธ ๋ชจ๋ธ์ด๋‹ค.

 

 

Transformer์˜ ํ•ต์‹ฌ

  • ์ž…๋ ฅ๋ฐ์ดํ„ฐ๋ผ๋ฆฌ์˜ Self attention์„ ์‚ฌ์šฉํ•ด Recurrent Unit ์—†์ด๋„ ๋ฌธ์žฅ์„ ๋ชจ๋ธ๋งํ•˜์—ฌ
    ๋ฌธ์žฅ ๋‚ด ๋‹จ์–ด๋“ค์ด ์„œ๋กœ ์ •๋ณด๋ฅผ ํŒŒ์•…ํ•˜๋ฉฐ, ํ•ด๋‹น ๋‹จ์–ด์™€ ์ฃผ๋ณ€ ๋‹จ์–ด๊ฐ„์˜ ๊ด€๊ณ„, ๋ฌธ๋งฅ์„ ๋” ์ž˜ ํŒŒ์•…ํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋จ
  • Self Attention : Attention์„ ํ†ตํ•ด ๊ฐ ํ† ํฐ์ด ์ง‘์ค‘ํ•ด์•ผ ํ•  ๋ฌธ๋งฅ ์ •๋ณด๋ฅผ ์ ์ˆ˜ ๋งค๊ธฐ๊ณ , ๊ฐ€์ค‘ํ•ฉ
Previous RNN family vs. Transformer
  • RNN family: combines the input arriving at each timestep with the previous information to produce a hidden representation (all previous steps must be processed first). Because each token enters the RNN cell sequentially, the order (position) information of the input sequence is preserved.
  • Transformer: produces hidden representations by considering all tokens in the context at once through self-attention. Because the sequence of tokens is fed in all at once, unlike seq2seq, order (position) information is not preserved on its own.

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 0.  Summary

  • RNN์ด๋‚˜ CNN์—†์ด ์˜ค๋กœ์ง€ Attention๊ณผ Dense layer(feed forward network)๋งŒ์œผ๋กœ ์ธํ’‹์„ ์—ฐ๊ฒฐํ•œ ๊ตฌ์กฐ
  • Transformer๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ ๋ฒˆ์—ญ ํƒœ์Šคํฌ๋ฅผ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•œ ๋ชจ๋ธ์ด๋ฏ€๋กœ Encoder - Decoder ๊ตฌ์กฐ๋ฅผ ์ง€๋‹Œ Seq2Seq model
  • ๊ฐ Encoder์™€ Decoder๋Š” L๊ฐœ์˜ ๋™์ผํ•œ Block์ด stack๋œ ํ˜•ํƒœ๋ฅผ ์ง€๋‹˜
  • Encoder Block ๊ตฌ์„ฑ
    • Multi-head self-attention
    • Position-wise feed-forward network
    • Residual connection
    • Layer Normalization
  • Encoder ๊ตฌ์กฐ Summary 
    • Encoder์˜ Input๊ณผ Output ํฌ๊ธฐ ๋™์ผ
    • [Input Embedding] ์ธํ’‹ ์‹œํ€€์Šค๋ฅผ ์ž„๋ฒ ๋”ฉ
      [Positional Encoding]
      ์œ„์น˜ ์ •๋ณด๋ฅผ ๋”ํ•ด์ฃผ๊ณ 
      [Multi-head attention]
      scale dot-product attention์„ ํ†ตํ•ด ์ธํ’‹์— ์žˆ๋Š” ๋ชจ๋“  ๋ฌธ๋งฅ ์ •๋ณด๋ฅผ ์‚ฌ์šฉํ•ด ํ† ํฐ์˜ representation์ƒ์„ฑ
      [Residual Connection] [Layer Normalization]
      [Feed Forward]
      Fully-Connected Network๋ฅผ ์ด์šฉํ•ด ์ •๋ณด๋ฅผ ๋ณ€ํ™˜
    • ์ธ์ฝ”๋”์—์„œ๋Š” ์ธํ’‹ ๋ฌธ์žฅ ์•ˆ์— ์žˆ๋Š” ํ† ํฐ๋“ค๊ฐ„์˜ ๊ด€๊ณ„๋ฅผ ๊ณ ๋ คํ•˜๋Š” self-attention์‚ฌ์šฉ
    • ์ธ์ฝ”๋”ฉ ๋‹จ๊ณ„์—์„œ๋Š” [Self-attention + Feed Forward] ๋ ˆ์ด์–ด๋ฅผ ์—ฌ๋Ÿฌ ์ธต ์Œ“์•„ ์ธํ’‹ ํ† ํฐ๋“ค์— ๋Œ€ํ•œ representation ์ƒ์„ฑ
    • ์ธ์ฝ”๋” layer #l์ด Output = ์ธ์ฝ”๋” layer #l+1์˜ Input
    • ๋งˆ์ง€๋ง‰ layer์˜ Output์€ Decoder์—์„œ Attention์— ์‚ฌ์šฉ
  • Decoder ๊ตฌ์กฐ Summary
    • ๊ธฐ๋ณธ์ ์ธ Encoder Block ๊ตฌ์„ฑ๊ณผ ๋™์ผ
    • Multi-head self-attention๊ณผ Position-wise FFN ์‚ฌ์ด์— Cross-attention(encoder-decoder attention) ๋ชจ๋“ˆ ์ถ”๊ฐ€
    • attention ํ†ตํ•ด์„œ ํ˜„์žฌ output์— ๋Œ€ํ•ด ๋ชจ๋“  ์ธ์ฝ”๋”์˜ hidden state๋ฅผ ๊ณ ๋ คํ•  ์ˆ˜ ์žˆ์Œ
    • ๋””์ฝ”๋”ฉ ์‹œ์— ๋ฏธ๋ž˜ ์‹œ์ ์˜ ๋‹จ์–ด ์ •๋ณด๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด Masked self-attention์„ ์‚ฌ์šฉ
    • ๋””์ฝ”ํ„ฐ์—์„œ๋Š” ๋””์ฝ”๋” ํžˆ๋“ ๊ณผ ์ธ์ฝ”๋” ํžˆ๋“ ๋“ค๊ฐ„์˜ attention์„ ๊ณ ๋ คํ•ด ํ† ํฐ์„ ์˜ˆ์ธกํ•จ
    • ๋””์ฝ”๋”ฉ ๋‹จ๊ณ„์—์„œ๋Š” [๋””์ฝ”๋” ํžˆ๋“ ์‚ฌ์ด์˜ Self Attention , ์ธ์ฝ”๋” ์•„์›ƒํ’‹๊ณผ์˜ attention]
    • ๋””์ฝ”๋” layer #l์ด Output = ๋””์ฝ”๋” layer #l+1์˜ Input
    • ๋งˆ์ง€๋ง‰ layer์˜ Output์€ ์ตœ์ข… ์˜ˆ์ธก์— ์‚ฌ์šฉ
  • head๋งˆ๋‹ค input ๋ฌธ์žฅ์˜ ์–ด๋””์— ์ง‘์ค‘ํ• ์ง€๊ฐ€ ๋‹ฌ๋ผ ๊ฐ€๊ณต์ด ๊ฐ€๋Šฅํ•˜์—ฌ ํ›จ์”ฌ ์„ฑ๋Šฅ์ด ์ข‹์Œ

 

 

Transformer ์•„ํ‚คํ…์ฒ˜

 

 

๋‹ค์Œ์œผ๋กœ ์•„ํ‚คํ…์ฒ˜์— ๋Œ€ํ•ด ๋” ์„ธ๋ถ€์ ์ธ ์„ค๋ช…์„ ํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 1.  Encoder์˜ Input

  • ๋ฌธ์žฅ์˜ ๊ฐ ์‹œํ€€์Šค(ํ† ํฐ)๋ฅผ ํ•™์Šต์— ์‚ฌ์šฉ๋˜๋Š” ๋ฒกํ„ฐ๋กœ ์ž„๋ฒ ๋”ฉํ•˜๋Š” embedding layer ๊ตฌ์ถ•
  • ์ž„๋ฒ ๋”ฉ ํ–‰๋ ฌ W์—์„œ look-up์„ ์ˆ˜ํ–‰ํ•˜์—ฌ ํ•ด๋‹นํ•˜๋Š” ๋ฒกํ„ฐ๋ฅผ ๋ฐ˜ํ™˜
  • W : ํ›ˆ๋ จ ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ
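A minimal sketch of this look-up, assuming a toy vocabulary of 10,000 tokens and d_model = 512 (both sizes are just for illustration):

import torch
import torch.nn as nn

# The embedding matrix W is a trainable parameter of shape (vocab_size, d_model)
embedding = nn.Embedding(num_embeddings=10000, embedding_dim=512)
token_ids = torch.tensor([[5, 120, 7, 9]])   # one tokenized sentence (batch of 1, 4 tokens)
token_vectors = embedding(token_ids)         # look-up in W
print(token_vectors.shape)                   # torch.Size([1, 4, 512])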

 

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 2.  Encoder์˜ Positional Encoding

  • The Transformer does not feed the input sequentially;
    it receives it as a single matrix at once and computes attention on it,
    ▶ so each token goes through the computation in parallel and independently,
      and the order information contained in the sequence is not reflected on its own
  • [Example] What result do we get if we feed the input [Cheolsu] [likes] [Younghee] into a self-attention layer?
               What if we feed the sentence [Younghee] [likes] [Cheolsu]?
              Given the same token-embedding matrix and weight matrices (Q, K, V), the two sentences are represented as exactly the same sentence


  • ํ† ํฐ์˜ ์ˆœ์„œ์ •๋ณด๋ฅผ ๋ฐ˜์˜ํ•ด์ฃผ๊ธฐ ์œ„ํ•ด Position Encoding ์‚ฌ์šฉ
  • Positional Embedding์„ Input Embedding์— ๋”ํ•ด ์ตœ์ข… input์œผ๋กœ ์‚ฌ์šฉ
    (Input โ–ถ Embeddings + Positional Encoding = Embedding with Time signal โ–ถ Encoder)
  • Positional Encoding ์ฐจ์› = Embedding ์ฐจ์›
  • [์กฐ๊ฑด]
    • ๊ฐ ํ† ํฐ์˜ ์œ„์น˜๋งˆ๋‹ค ์œ ์ผํ•œ ๊ฐ’์„ ์ง€๋…€์•ผ ํ•จ
    • ํ† ํฐ ๊ฐ ์ฐจ์ด๊ฐ€ ์ผ์ •ํ•œ ์˜๋ฏธ๋ฅผ ์ง€๋…€์•ผ ํ•จ (๋ฉ€์ˆ˜๋ก ํฐ ๊ฐ’)
    • ๋” ๊ธด ๊ธธ์ด์˜ ๋ฌธ์žฅ์ด ์ž…๋ ฅ๋˜์–ด๋„ ์ผ๋ฐ˜ํ™”๊ฐ€ ๊ฐ€๋Šฅํ•ด์•ผ ํ•จ
  •  sin & cos ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด ํ† ํฐ์˜ ์ ˆ๋Œ€ ์œ„์น˜(0๋ฒˆ์งธ, 1๋ฒˆ์งธ)์— ๋”ฐ๋ฅธ ์ธ์ฝ”๋”ฉ์„ ๋งŒ๋“ค์–ด๋ƒ„
    (
    sin&cos ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด ์ธ์ฝ”๋”ฉ ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ 'sinusoidal positional encoding'์ด๋ผ๊ณ  ๋ถ€๋ฆ„)

  • ์œ„์น˜์— ๋”ฐ๋ผ ์กฐ๊ธˆ์”ฉ ๋‹ค๋ฅธ ํŒŒํ˜•์ด ํ† ํฐ ์ž„๋ฒ ๋”ฉ์— ๋”ํ•ด์ง€๊ณ , ์ด๋ฅผ ํ†ตํ•ด ํŠธ๋žœ์Šคํฌ๋จธ ๋ธ”๋ก์€
    ์œ„ [์˜ˆ์‹œ]์˜ ๋‘ ๋ฌธ์žฅ์„ ๋‹ค๋ฅด๊ฒŒ ์ธ์ฝ”๋”ฉํ•  ์ˆ˜ ์žˆ์Œ

[Figure] Positional encoding waveforms

  • ์ž„์˜์˜ ์ธ์ฝ”๋”ฉ์„ ๋งŒ๋“ค์–ด๋‚ด๋Š” ๋ฐฉ์‹ ๋Œ€์‹  Positional Encoding์— ํ•ด๋‹นํ•˜๋Š” ๋ถ€๋ถ„๋„ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Œ
    • ์œ„์น˜๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” one-hot-vector๋ฅผ ๋งŒ๋“ค๊ณ , ์ด์— ๋Œ€ํ•ด ํ† ํฐ ์ž„๋ฒ ๋”ฉ๊ณผ ๋น„์Šทํ•˜๊ฒŒ ์œ„์น˜ ์ž„๋ฒ ๋”ฉ์„ ์ง„ํ–‰
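A sketch of the sinusoidal positional encoding described above, written close to the PositionalEncoding module of the referenced repo (n_position = 200 is an assumed maximum length):

import numpy as np
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    ''' Sinusoidal positional encoding. '''

    def __init__(self, d_hid, n_position=200):
        super().__init__()
        # Precomputed table, registered as a buffer (not a trainable parameter)
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    @staticmethod
    def _get_sinusoid_encoding_table(n_position, d_hid):
        # PE(pos, 2i)   = sin(pos / 10000^(2i / d_hid))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_hid))
        position = np.arange(n_position)[:, None]                    # (n_position, 1)
        div = np.power(10000, 2 * (np.arange(d_hid) // 2) / d_hid)   # (d_hid,)
        table = position / div
        table[:, 0::2] = np.sin(table[:, 0::2])   # even dimensions
        table[:, 1::2] = np.cos(table[:, 1::2])   # odd dimensions
        return torch.FloatTensor(table).unsqueeze(0)                 # (1, n_position, d_hid)

    def forward(self, x):
        # x: (batch, seq_len, d_hid) token embeddings; add the positional waveform
        return x + self.pos_table[:, :x.size(1)].clone().detach()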

 

 

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 3.  Encoder์˜ Scaled Dot Product Attention

  • The scaled dot-product attention used inside multi-head self-attention is the core
  • It is the answer to how relationships within a sentence can be learned!

Scaled Dot-Product Attention

  • Compute attention scores with a dot product, then scale those scores
  • Step 1) Transform (by matrix multiplication) the tokens being attended to into Key and Value, and the attending token into Query
    • After the token embedding maps each token of the input sentence to an h-dimensional vector,
      matrix multiplications produce 3 vectors for each token ▶ they act as Query-Key-Value
      (L: sequence length, d: attention depth)
  • Query: the token vector that wants to bring in the context information it cares about (the subject of self-attention); analogous to the current hidden state, or to a search query.
  • Key: the vector used to compute (model) the attention score between the Query and each token in the sentence (the comparison target for the Query); analogous to the values that interact with the hidden state, or to the surrounding tokens (Key1, Key2, ...).
  • Value: the values that are weighted-summed, according to the Query-Key similarity, and reflected back into the Query; a new representation of each token that carries its information (the actual content behind each Key). The attention output is computed by multiplying the Values by the attention scores.

 

 

[ํ‘œํ˜„์‹ ์˜๋ฏธ]
โ–ก (Dk) : Query๊ฐ€ Projectiongํ•˜๋Š” ์ฐจ์›
    (M) : Key์˜ Sequence Length
    (N) : Query์˜ Sequence Length 
โ–ก ๋‚ด์ ์„ ์œ„ํ•ด Query์™€ Key์˜ ์ฐจ์›์ด ๊ฐ™์•„์•ผ ํ•จ
โ–ก M์€ ๊ฐ Key์— ํ•ด๋‹นํ•˜๋Š” Value๊ฐ’์„ ๊ฐ€์ค‘ํ•ฉํ•˜์—ฌ ์‚ฌ์šฉ
โ–ก Projection๋œ ์ฐจ์›์€ ๋‹ฌ๋ผ๋„ ๋˜์ง€๋งŒ Row-wise Softmax๊ฐ’์— Value๊ฐ€ ๋‚ด์ ๋˜์–ด์•ผ ํ•˜๋ฏ€๋กœ, Key์˜ M์— ํ•ด๋‹นํ•˜๋Š” ๋ถ€๋ถ„๊ณผ ๊ฐ™์•„์•ผ ํ•จ

[๋ถ„์ž/๋ถ„๋ชจ ์˜๋ฏธ]
โ–ก [๋ถ„์ž] dot-product : Query์™€ Key์˜ ๋‚ด์ ์œผ๋กœ ์ด๋ฃจ์–ด์ง
โ–ก [๋ถ„๋ชจ] Softmaxํ•จ์ˆ˜์˜ gradient vanishing ๋ฌธ์ œ ์™„ํ™” ์œ„ํ•ด Scaling

[Dk]
โ–ก ์ดˆ๋ฐ˜์— ํ•™์Šต์ด ์ž˜ ๋˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด Dk๋กœ scaling
โ–ก
 multi-head ์ ์šฉํ•˜์ง€ ์•Š์„ ๋•Œ Dk = 512

โ–ก Dk๊ฐ€ ์ž‘์€ ๊ฒฝ์šฐ,  dop-product attention๊ณผ additive attention์€ ์œ ์‚ฌํ•œ ์„ฑ๋Šฅ
โ–ก Dk๊ฐ€ ํฐ ๊ฒฝ์šฐ, ๋‚ด์ ๊ฐ’์ด ๊ต‰์žฅํžˆ ์ปค์ง์— ๋”ฐ๋ผ ํŠน์ • softmax ๊ฐ’(๊ฐ€์ค‘์น˜)์ด 1์— ๊ทผ์‚ฌ๋˜์–ด,
    ์ด๋Š” gradient๊ฐ€ ์†Œ์‹ค๋˜๋Š” ๊ฒฐ๊ณผ๋ฅผ ๋‚ณ์Œ
โ–ก ํ•˜๋‚˜์˜ softmax๊ฐ’์ด 1์— ๊ฐ€๊นŒ์›Œ์ง€๋ฉด gradient๊ฐ€ ๋ชจ๋‘ 0์— ์ˆ˜๋ ดํ•˜๊ฒŒ ๋˜์–ด ํ•™์Šต ์†๋„๊ฐ€ ๊ต‰์žฅํžˆ ๋Š๋ ค์ง€๊ฑฐ๋‚˜ ํ•™์Šต์ด ์•ˆ ๋  ์ˆ˜ ์žˆ์Œ


[SoftMax]

โ–ก ๊ฐ ํ–‰๋ณ„ SoftMax๋ฅผ ํƒœ์›Œ, ๊ฐ ํ† ํฐ์ด ๋‹ค๋ฅธ ํ† ํฐ๋“ค๊ณผ ๊ฐ–๋Š” ๊ด€๊ณ„์ ์ธ ์˜๋ฏธ๋ฅผ ํ‘œํ˜„(๊ฐ€์ค‘์น˜)

Scaled dot-product attention

  • Step 2) Compute attention weights by taking the dot product of the Query with each Key.
              Softmax is applied to the scaled dot products, turning them into something like a 'probability distribution'
    • Dot the Query vector (the search term) with each of the Key vectors to compute the attention scores (weights, similarities)

(Example) From the viewpoint of the [I] token, compute the importance of [I] [am] [a] [men]
- score with [I]: 128, with [am]: 32, with [a]: 32, with [men]: 128
- because of the dot-product operation, the numbers grow very large depending on the dimension of the query and key vectors

  ▶ Scale these scores by the square root of the query vector's dimension; the result is 16, 4, 4, 16
- apply softmax to these scores to build an attention distribution, similar to probabilities

 

  • Step 3) Use the weights to take a weighted sum of the Values and update the Query's representation
    (the attention value (output) is obtained by weighting the Values according to the attention distribution)
    • Use the attention distribution obtained above as weights for the weighted sum of the Value vectors
    • Which information a token is related to (WQ, WK); which information it should bring in (WV)
    • The Query vector is transformed into a vector that reflects its similarity with the Keys

(Example) The vector computed this way is the representation of the query token [I]
- with weights of 0.4, 0.1, 0.1, 0.4 for [I] [am] [a] [men],
  the resulting vector mostly reflects the information of [I] and [men]

- sum of (each Key's weight * Value)
▶ through self-attention, the [I] token reflects the information of the whole input text,
so its representation carries the meaning of [men] as well

 

(Example) query = 'key2' | attention = {'key1' : 'value1', 'key2' : 'value2', 'key3' : 'value3'} ▶ attention[query] = 'value2'

 

  • Attention์˜ Output์€ Encoder์˜ Input๊ณผ ๋™์ผํ•œ ํฌ๊ธฐ์˜ ํ…์„œ
  • padding์— ๋Œ€ํ•œ attention score๋Š” 0(Q*Kt๊ณ„์‚ฐ ์‹œ -1e9๋กœ ์ž…๋ ฅ)์œผ๋กœ, loss๊ณ„์‚ฐ์—์„œ๋„ ์ฐธ์—ฌ ์•ˆ ํ•จ
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature          # scaling factor, sqrt(d_k)
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # q, k, v: (batch, n_head, seq_len, d_k or d_v)
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))  # scaled Q*K^T

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)  # masked positions get ~0 weight after softmax

        attn = self.dropout(F.softmax(attn, dim=-1))  # row-wise softmax
        output = torch.matmul(attn, v)                # weighted sum of the Values

        return output, attn

 

 

Self-attention์˜ ์˜๋ฏธ

  • It is the process of computing attention over the entire input sequence to build each token's representation,
    and the updated representation carries context information
  • For example, applying self-attention to the input "Cha Eun-woo was born in #. He recently appeared in #"
    makes the representation of 'He' contain information about 'Cha Eun-woo'.
  • Scaled dot-product attention can be computed with matrices, so there is no need to wait for previous tokens to be processed as in an RNN
  • Since everything is computed at once, without per-timestep delays,
    all the context information in the input can be reflected without forgetting
  • In short: multiply each token of the input text by weights to build Q, K, V, and use these vectors to update the token's contextual information

 

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 4.  Encoder์˜ Multi-head Attention

  • ๊ฐ ํ† ํฐ์— ๋Œ€ํ•ด ์ธํ’‹ ์ปจํ…์ŠคํŠธ๋กœ๋ถ€ํ„ฐ ๊ด€๋ จ๋œ ๋ฌธ๋งฅ ์ •๋ณด๋ฅผ ํฌํ•จํ•ด ์˜๋ฏธ๋ฅผ ๋ชจ๋ธ๋งํ•˜๊ธฐ ์œ„ํ•ด ์‚ฌ์šฉ
  • ์ „์ฒด ์ž…๋ ฅ ๋ฌธ์žฅ์„ ์ „๋ถ€ ๋™์ผํ•œ ๋น„์œจ๋กœ ์ฐธ๊ณ ํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ, ํ•ด๋‹น ์‹œ์ ์—์„œ ์˜ˆ์ธกํ•ด์•ผ ํ•  ๋‹จ์–ด์™€ ์—ฐ๊ด€์žˆ๋Š” ๋‹จ์–ด ๋ถ€๋ถ„์„ ์ข€ ๋” ์ง‘์ค‘ 
  • Attention head๋ณ„๋กœ ๊ฐ๊ธฐ ๋‹ค๋ฅธ ์ธก๋ฉด์—์„œ ํ† ํฐ๊ฐ„์˜ ๊ด€๊ณ„์™€ ์ •๋ณด๋ฅผ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Œ. ์•ฝ๊ฐ„์˜ ์•™์ƒ๋ธ” ๊ฐœ๋…?
  •  Attention ๊ณ„์‚ฐ ๊ณผ์ •์„ ์—ฌ๋Ÿฌ weight๋ฅผ ์‚ฌ์šฉํ•ด ๋ฐ˜๋ณตํ•˜๊ณ  ๊ทธ ๊ฒฐ๊ณผ๋ฅผ concatํ•˜์—ฌ ์ตœ์ข… attention output ๊ณ„์‚ฐ
    (
    ์ด๋Š” CNN filter์„ ์—ฌ๋Ÿฌ ์žฅ ์‚ฌ์šฉํ•จ์œผ๋กœ์จ ์ด๋ฏธ์ง€์— ์žˆ๋Š” ๋‹ค์–‘ํ•œ ํŠน์„ฑ์„ ํฌ์ฐฉํ•˜๋Š” ๊ฒƒ์ฒ˜๋Ÿผ,ํ† ํฐ ์‚ฌ์ด์˜ ๋‹ค์–‘ํ•œ ๊ด€๊ณ„๋ฅผ ํฌ์ฐฉํ•˜๊ธฐ ์œ„ํ•จ์ž„)
  • Scaled dot-product attention(WQ, WK, WV ๋งคํŠธ๋ฆญ์Šค)์„ ํ•œ ๋ฒˆ์— ๊ณ„์‚ฐํ•˜๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๋ผ
    ์—ฌ๋Ÿฌ ๊ฐœ์˜ head์—์„œ Self-Attention์ด ์ˆ˜ํ–‰๋˜์–ด ๊ณ„์‚ฐํ•จ

Multi-head Attention

 

[Process]

  • For an N x Dm embedding, compute attention with Q, K, V as many times as the number of heads (a hyperparameter)
  • The outputs of the heads are concatenated
  • The concatenated head outputs are projected back to the original Dmodel dimension and passed to the next layer
  • Step 1) Sequence embedding (sequence length x embedding dimension (Dmodel)) ▶
    Step 2) Weight matrices for Q, K, V (one per head) ▶
    Step 3) Compute attention once per head ▶
    Step 4) Concatenate the outputs of the attention heads ▶
    Step 5) Project back to the Dmodel dimension

(Step 2)

  • The whole input sequence embedding is passed through the Q, K, V weight matrices
    to produce Q, K, V matrices of a specific dimension
    (the embedding dimension is not chopped into one slice per head before computing attention)
  • The projection weight W is split per head and multiplied with the sequence embedding
  • A weight of concatenated shape is created, and the slice corresponding to each head is used for that head
  • The weight matrices that project Q, K, V are not defined independently per head from the start;
    a weight matrix with dimension n_head*d_k is created and then split for use

(Step 5)

  • ❔ Why project the concatenated attention outputs of the heads to size Dmodel?
     🗨 The values making up the concatenation reflect each head's characteristics in order, and the W0 projection
         mixes this per-position, per-head information before it is used
  • ❔ Why project back to the same Dmodel dimension?
    🗨 Basically to keep the input and output dimensions of every layer the same (maintain the Dmodel dimension)

(Code)

  • In the Transformer class: d_model, n_head, d_k = d_model/n_head, d_v = d_model/n_head
  • The weight matrix created earlier is reshaped so that the Q, K, V weight regions are separated per head
  • That is, starting from a projected matrix of size sequence_length x (n_head x d_k(v)),
    the dimensions are rearranged so that attention can be computed for n_head heads at once = used in parallel
  • After the attention computation, the head outputs are rearranged again into the concatenated matrix
  • After concatenation, the W0 matrix projects the result back to the d_model dimension
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head, self.d_k, self.d_v = n_head, d_k, d_v
        # One big projection each for Q/K/V, split into n_head heads later
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)   # W0: back to d_model
        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)   # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn
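A quick shape check, assuming the ScaledDotProductAttention and MultiHeadAttention definitions above; the batch size of 2 and sequence length of 10 are arbitrary:

import torch

mha = MultiHeadAttention(n_head=8, d_model=512, d_k=64, d_v=64)
x = torch.randn(2, 10, 512)     # (batch, seq_len, d_model)
out, attn = mha(x, x, x)        # self-attention: q = k = v = x
print(out.shape)                # torch.Size([2, 10, 512]) - same size as the input
print(attn.shape)               # torch.Size([2, 8, 10, 10]) - one 10 x 10 attention map per head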

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 5.  Encoder์˜ (Position-Wise) Feed Forward Network

(Position-wise) Feed-Forward Network

  • The feed-forward part is a dense layer that further processes the hidden-state information coming out of multi-head attention
  • After attention, the result is passed through a feed-forward network (nn.Linear)
  • This is where the fully-connected feed-forward module is applied
  • It is applied per position, i.e., to each individual token, hence "position-wise"
  • Within one encoder layer (across the tokens of one block), the FFN parameters (w, b) are shared
  • Each token is projected once into the hidden layer and then back to the output dimension (same as the input)
  • This processes the independent multi-head attention results coming from the different heads
  • The FFN's input: the output of the previous layer
  • ReLU: max(0, x)
  • Usually d_hid (D_f) > d_model (D_m)
import torch.nn as nn
import torch.nn.functional as F


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid) # position-wise
        self.w_2 = nn.Linear(d_hid, d_in) # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x

        x = self.w_2(F.relu(self.w_1(x)))   # project to d_hid, ReLU, project back to d_in
        x = self.dropout(x)
        x += residual                       # residual connection

        x = self.layer_norm(x)

        return x
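A quick shape check for the module above, using the paper's base configuration (d_model = 512, d_ff = 2048) as the assumed sizes:

import torch

ffn = PositionwiseFeedForward(d_in=512, d_hid=2048)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
print(ffn(x).shape)           # torch.Size([2, 10, 512]) - the output keeps the input dimension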

 

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 6.  Encoder์˜ Residual Connection and Normalization

  • Introduced to mitigate the gradient exploding / vanishing problem and to train a deep network stably
  • Add & Norm creates the residual connection and applies layer normalization so that the model does not lose information
  • Residual connection: adding the value of x from before a layer to the output after passing through that layer
    • Connecting the output of the previous layer directly to the next layer so that gradient vanishing does not occur even when layers are stacked deep
    • ADD = residual connection: obtain Query, Key, Value from the input with W(Q, K, V), then add the input to the matrix produced by multi-head attention

Residual Connection

  • Layer Normalization: normalize the output of each intermediate layer, then pass it on to the following feed-forward layer
    (x: enc_seq_len x d_model)

  • torch.nn.LayerNorm(normalized_shape=d_model, eps=...): normalizes over the last dimension
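A minimal sketch of Add & Norm with torch.nn.LayerNorm; the sublayer output here is just random numbers standing in for a multi-head attention result:

import torch
import torch.nn as nn

layer_norm = nn.LayerNorm(normalized_shape=512, eps=1e-6)  # normalizes over the last (d_model) dimension
x = torch.randn(2, 10, 512)             # (batch, enc_seq_len, d_model): input to the sublayer
sublayer_out = torch.randn(2, 10, 512)  # stand-in for the multi-head attention output
y = layer_norm(x + sublayer_out)        # Add (residual connection) & Norm
print(y.shape)                          # torch.Size([2, 10, 512])
print(y.mean(-1).abs().max().item())    # per-position mean is ~0 after normalization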

 

 

 

 

Transformer์˜ ์•„ํ‚คํ…์ฒ˜ - # 7.  Decoder

  • ์•ž์„œ ์ธ์ฝ”๋”์˜ ์ตœ์ข… Output(Key, Value)์ด Decoder์˜ Multi-head Attention์— ์ด์šฉ๋จ
  • ๊ธฐ๋ณธ์ ์ธ Encoder์™€ ๋น„์Šทํ•œ ๊ตฌ์กฐ
    • ๋ฌธ์žฅ์˜ ๊ฐ ์‹œํ€€์Šค(ํ† ํฐ)์„ one-hot encoding ํ›„ ํ•™์Šต์— ์‚ฌ์šฉ๋˜๋Š” ๋ฒกํ„ฐ๋กœ ์ž„๋ฒ ๋”ฉํ•˜๋Š” embedding layer ๊ตฌ์ถ•
    • ์ž„๋ฒ ๋”ฉ ํ–‰๋ ฌ W์—์„œ look-up์„ ์ˆ˜ํ–‰ํ•˜์—ฌ ํ•ด๋‹นํ•˜๋Š” ๋ฒกํ„ฐ๋ฅผ ๋ฐ˜ํ™˜(W : Trainable parameter)
    • Positional Encoding ๋„ ๋˜‘๊ฐ™์ด
import torch.nn as nn

# PositionalEncoding and DecoderLayer are the modules defined elsewhere in the same repo.
class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(
            self, n_trg_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, n_position=200, dropout=0.1, scale_emb=False):

        super().__init__()

        self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []

        # -- Forward
        dec_output = self.trg_word_emb(trg_seq)
        if self.scale_emb:
            dec_output *= self.d_model ** 0.5
        dec_output = self.dropout(self.position_enc(dec_output))
        dec_output = self.layer_norm(dec_output)

        for dec_layer in self.layer_stack:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output, slf_attn_mask=trg_mask, dec_enc_attn_mask=src_mask)
            dec_slf_attn_list += [dec_slf_attn] if return_attns else []
            dec_enc_attn_list += [dec_enc_attn] if return_attns else []

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,
  • A cross-attention (encoder-decoder attention) module is added between the multi-head self-attention and the position-wise FFN
  • Through this attention, all of the encoder's hidden states can be considered for the current output
  • The decoder predicts tokens by attending between the decoder hidden states and the encoder hidden states
  • The decoding stage uses [self-attention among decoder hidden states, attention over the encoder outputs]
  • [Input] dec_output: Query | enc_output: Key | enc_output: Value ▶
    [Output]
    dec_output | attention (between decoder and encoder hidden states)

 

  • Masked self-attention is used to prevent the decoder from using information about future tokens during decoding
    (for self-attention, the Encoder uses the source mask and the Decoder uses the target mask)
    (when computing attention weights, tokens after the current token are masked so they cannot be attended to)
  • The part above the diagonal is masked
  • Padding-token positions are also masked

  • As in the Encoder, attention is computed for each head and concatenated (dec_seq_len x D_model), followed by
    Add + Layer Normalization
  • The output is a tensor of the same size as the Decoder's input

  • The output of the final layer is used for the final prediction: it is passed through Linear + Softmax
    # From the Transformer class in the repo: the final projection layer (defined in __init__) ...
    self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias=False)

    # ... and the forward pass that ties the encoder, the decoder, and the projection together
    def forward(self, src_seq, trg_seq):

        src_mask = get_pad_mask(src_seq, self.src_pad_idx)
        trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx) & get_subsequent_mask(trg_seq)

        enc_output, *_ = self.encoder(src_seq, src_mask)
        dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)
        seq_logit = self.trg_word_prj(dec_output)
        if self.scale_prj:
            seq_logit *= self.d_model ** -0.5

        return seq_logit.view(-1, seq_logit.size(2))
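The get_pad_mask and get_subsequent_mask helpers used in the forward above are defined in the same repo, roughly as follows:

import torch

def get_pad_mask(seq, pad_idx):
    # (batch, 1, seq_len): True where the token is not padding
    return (seq != pad_idx).unsqueeze(-2)

def get_subsequent_mask(seq):
    ''' Mask out subsequent positions (everything above the diagonal). '''
    sz_b, len_s = seq.size()
    subsequent_mask = (1 - torch.triu(
        torch.ones((1, len_s, len_s), device=seq.device), diagonal=1)).bool()
    return subsequent_mask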

 

 

Transformer์˜ ์‹คํ—˜ ๊ฒฐ๊ณผ

  • (A) A moderately large number of heads gives a performance boost, while too many heads actually hurt performance
  • (B) Reducing the Key size (d_k) can degrade model quality
  • (C) Larger models tend to perform better
  • (D) Dropout and label smoothing, both forms of regularization, are effective at improving performance
    (Label smoothing: labels may contain errors, and it prevents overconfidence to some degree)

 

 

 

Transformer Summary

  • ์ธํ’‹ ๋ฌธ์žฅ์— ์žˆ๋Š” ๋ชจ๋“  ํ† ํฐ์˜ ์ •๋ณด๋ฅผ ํƒ€์ž„์Šคํ… ์ง„ํ–‰์— ๋”ฐ๋ฅธ forgetting์—†์ด ๋Œ์–ด์˜ค๋ฉด์„œ
    Positional Encoding์„ ํ†ตํ•ด ์œ„์น˜ ๊ฐ„์˜ ์ •๋ณด๊นŒ์ง€ ๋ชจ๋ธ๋งํ•  ์ˆ˜ ์žˆ๋Š” Transformer
  • [์žฅ์ ] RNN์„ ํ†ตํ•ด ๊ฐ ์Šคํ…์˜ hidden state์ด ๊ณ„์‚ฐ๋˜๊ธฐ๋ฅผ ๊ธฐ๋‹ค๋ฆฌ์ง€ ์•Š์•„๋„ ๋œ๋‹ค!
     ์ฆ‰, ๋ฌธ์žฅ์— ์žˆ๋Š” ๋ชจ๋“  ๋‹จ์–ด์˜ representation๋“ค์„ ๋ณ‘๋ ฌ์ ์œผ๋กœ ํ•œ๋ฒˆ์— ๋งŒ๋“ค ์ˆ˜ ์žˆ๊ณ ,
    ํ•™์Šต ์‹œ๊ฐ„ ๋‹จ์ถ•์— ๊ธฐ์—ฌ
  • GRU, LSTM ๊ฐ™์€ ์•„ํ‚คํ…์ฒ˜ ์—†์ด๋„ Long-term dependency๋ฅผ ํ•ด๊ฒฐํ•œ ์ƒˆ๋กœ์šด ๋ฐฉ์‹
  • ์ˆœ์ฐจ์  ๊ณ„์‚ฐ์ด ํ•„์š” ์—†๊ธฐ ๋•Œ๋ฌธ์— RNN๋ณด๋‹ค ๋น ๋ฅด๋ฉด์„œ๋„ ๋งฅ๋ฝ ํŒŒ์•…์„ ์ž˜ํ•˜๊ณ , 
    CNN์ฒ˜๋Ÿผ ์ผ๋ถ€์”ฉ๋งŒ์„ ๋ณด๋Š” ๊ฒƒ์ด ์•„๋‹ˆ๊ณ  ์ „ ์˜์—ญ์„ ์•„์šฐ๋ฅธ๋‹ค.
  • [๋‹จ์ ] ์ดํ•ด๋ ฅ์ด ์ข‹์€ ๋Œ€์‹ ์— ๋ชจ๋ธ์˜ ํฌ๊ธฐ๊ฐ€ ์—„์ฒญ ์ปค์ง€๋ฉฐ ๊ณ ์‚ฌ์–‘์˜ ํ•˜๋“œ์›จ์–ด ์ŠคํŽ™์„ ์š”๊ตฌ
    โ–ถ์ด๋Ÿฌํ•œ ํ•œ๊ณ„๋ฅผ ๋ณด์™„ํ•˜๊ธฐ ์œ„ํ•œ ๋‹ค์–‘ํ•œ ๊ฒฝ๋Ÿ‰ํ•œ ๋ฐฉ์•ˆ์ด ์—ฐ๊ตฌ๋˜๊ณ  ์žˆ์Œ