這篇文章將會解釋Transformer做文本生成的運作機制,透過一步步拆解運算過程來瞭解背後運作的步驟,希望讓大家更加清楚文本生成背後的事情。最後附上一個兩個Code展示GPT2做文本生成的實例,來看看怎麼樣用TPU訓練自己的GPT2。

文本生成的過程

文本生成的過程是輸入前文輸出後文,然後一直迭代到結束。用例子來看,假設輸入明天天氣如何? 要輸出 晴天,就可以將過程寫成條件概率: $P(晴天|明天天氣如何)$,可以進一步分解成 $P(晴|明天天氣如何) \times P(天|明天天氣如何晴)$。模型所要訓練的目標,是最大化 $P(晴|明天天氣如何)$ 和 $P(天|明天天氣如何晴)$ 的機率,也就是讓他們越接近1越好。

所以說,模型需要運算出這兩項的機率,$P(晴|明天天氣如何)$ 和 $P(天|明天天氣如何晴)$。但麻煩的地方在於,這個運算不能同時進行,因為是要在輸入以後才能預測出來。後文的預測相依於前文的輸入,使得文本生成的過程不能平行。
但是,在訓練的時候,由於我們已經知道要輸出的全文,如例子的話就是晴天,也就不需要等前文預測結果出來才能做後文的預測,因為我們已經知道全文會是甚麼了。也就是說,我們可以同時求出$P(晴|明天天氣如何), P(天|明天天氣如何晴)$,而這種訓練方式會叫做 teacher forcing。

Transformer的平行化預測

Transformer的設計就是為了平行化以上的訓練。誠如我們能將輸入文本拆成 $P(晴|明天天氣如何), P(天|明天天氣如何晴)$,這樣一個兩個字的預測就拆成兩筆資料,當預測文本越長,一句話就被多長就被拆出多少筆資料,這樣訓練的資料量是龐大且可怕的。因此Transformer還有一件要做到的事情,用一筆輸入平行化預測文本。它的作法就是讓訓練目標變成用前一個字預測下一個字。
Transformer的訓練和預測會變成以下的形式,輸入的文本會是明天天氣如何<S>晴天,輸出呢,則是xxxxxx晴天<E>x是我們不管他預測甚麼都可以,而第一個target就是當我們遇到明天天氣如何<S>的時候,要從<S>預測到,然後以此類推,讓預測,讓
預測<E><S><E>則是代表預測的開始和結束。這樣我們就只有一筆資料就可以訓練出生成模型了。

Input    明天天氣如何 <S> 晴 天
Output              晴  天 <E>

Transformer decoder

以上的架構還有一個很嚴重的缺陷,模型一開始訓練loss就瘋狂下降,但預測的時候會甚麼都沒有。原因是因為前文能看到後文,輸入的<S>在預測的時候其實輸入的就在旁邊,多訓練幾次之後模型也會意識到直接抄旁邊的字來預測就好,這樣的模型學不會任何東西。這樣還需要一些設計,讓Transformer模型的輸入可以避免看到後文。

如果感覺印象不是清晰,關於Transformer的基本介紹可以看上一篇: Transformer

Transformer會將輸入的文本變成embedding向量,然後再將embedding投射到同等大小的三個向量空間裡面,分別的Q,K和V。這個部分的設計是為了讓Transformer模型裡面輸入能夠互相”看”到彼此,去決定哪一個部分對自己重要。

然後,為了知道那些字對自己而言是重要的,下一步就會將Q乘K,得到對於Q的token而言每一個k的重要性。Transformer裡面每一個字都可以輪流當作Q,所以最後Q*K的結果得出的矩陣會是 輸入文本的長度 * 輸入文本的長度,在輸入<S> 晴 天的例子來看,矩陣會是這樣的:(這裡為了簡化將所有的weight都設1)

然後,下一步是將算好的比例跟V相乘,得到對應字按照比例分配後的weight,v的維度會是 輸入文本的長度 * 特徵大小,特徵大小可能會是768,也就是這個字會用768維的weight來表示。我們放到其中的一個維度x來看,則會是這樣的:

也就是說,S的weight取決於QK所分配的比例,讓token之間可以互相看到的地方也就是在QK,所以我們可以在QK的地方讓前文不能看到後文,所做的方式其實就是在QK的結果之上加入一個causal mask,目的是讓不希望看到的字,weight都是0。

加入Causal Mask以後,做QK*V的時候,就會發現由於weight設0,<S>最終的weight就不會受到後文的影響。

以上就是Transformer decoder的設計,相比encoder會多了一個causal mask,使得前面的字不會看到後面,可以平行化訓練文本生成,其中很有名的模型就是GPT。

Transformer 的 encoder decoder

這裡可以補充一些Transformer的seq2seq模型,Transformer的decoder會有一個缺陷,就是前面的token不會得到後文的信息,因此在做summarization之類的任務時,輸入會是一篇文章,生成的目標是文章的摘要,這就需要對文章的整體有更好的理解。但是decoder會使得輸入的前文看不到後文,這會一定程度地影響到模型對於輸入的理解。為了讓輸入有更加完整的信息,更好的設計應該是讓輸入的文本能互相看到前後文,在預測的時候才遵循前文不能看到後文的規則。也就是因此,設計了encoder decoder架構。

Encoder和Decoder的互相溝通的地方會是cross attention這一塊,Transformer會將encoder的K和V傳到decoder,然後接續decoder的QKV算出最後的weight。

encoder decoder架構下,其中一個有代表的模型就是BART

UniLM

encoder decoder的架構,encoder是不會得到decoder的信息,只有decoder會能拿到encoder的信息。因此,還有一種設計其實更加簡單,讓encoder和decoder都能取得各種的信息,也能做到文本生成。其中UniLM就用到這個方法來訓練。
其實說起來也不是甚麼複雜的事情,我們在attention的casual mask上面下手腳就可以。在encoder部分的文本,我們希望他們是互相看到的,而在decoder部分則前文看不到後文,因此將causl mask這樣加上去就可以啦:

GPT實例

以上就是Transformer的各種文本生成的介紹,現在我們就嘗試一下他們的效果如何:

import & load model

from Transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSeq2SeqLM
import torch

tokenizer = AutoTokenizer.from_pretrained("ckiplab/gpt2-base-chinese")  
model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese")

建立資料

input_text="""
這不是BUG,是feature
"""

input_ids = tokenizer.encode(input_text, return_tensors='pt')

input_ids[:,:-1]
# input  - tensor([[  101,  3300,  8761,   679,  8024,  3221, 12605, 10461]])

input_ids[:,1:]
# target - tensor([[ 3300,  8761,   679,  8024,  3221, 12605, 10461,   102]])

訓練

runtime = 10
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fct = torch.nn.CrossEntropyLoss()  # -1 index = padding token
model.train()
for _ in range(runtime):
    logit = model(d_input_ids[:,:-1]).logits
    loss = loss_fct(logit.view(-1, model.config.vocab_size), input_ids[:,1:].view(-1))
    loss.backward()
    print(loss)
    optim.step()

預測

predict_input = tokenizer.encode("這不是bug,", return_tensors='pt')[:,:-1]
model.eval()
output = model.generate(predict_input)
tokenizer.decode(output[0])

結果
'[CLS] 這 不 是 bug , 是 feature [SEP]'

Colab

GPT+對話生成+TPU

既然我們大致瞭解GPT怎麼做生成,來個大一點的資料集訓練看看效果如何~
GPT+對話生成+TPU Colab

模型用來自ckip lab預訓練的GPT模型: ckiplab/gpt2-base-chinese

訓練資料來自zake大大的PTT中文語料
抓其中前100000筆資料訓練,用colab的TPU大約需要1個多小時訓練一個epoch,目前訓練了兩個epoch:

epoch: 0, iter: 0, loss: 6.885707 
epoch: 0, iter: 1000, loss: 4.627165 
epoch: 0, iter: 2000, loss: 5.038326 
epoch: 0, iter: 3000, loss: 4.707286 
epoch: 0, iter: 4000, loss: 4.463942 
epoch: 0, iter: 5000, loss: 4.513004 
epoch: 0, iter: 6000, loss: 4.098025 
epoch: 0, iter: 7000, loss: 4.284454 
epoch: 0, iter: 8000, loss: 4.385873 
epoch: 0, iter: 9000, loss: 4.173577 
epoch: 0, train loss: 4.376493, eval loss: 4.258459
--- 4278.577525138855 seconds ---
epoch: 1, iter: 0, loss: 4.186140 
epoch: 1, iter: 1000, loss: 4.342926 
epoch: 1, iter: 2000, loss: 3.838801 
epoch: 1, iter: 3000, loss: 4.286644 
epoch: 1, iter: 4000, loss: 4.098445 
epoch: 1, iter: 5000, loss: 3.389583 
epoch: 1, iter: 6000, loss: 3.970446 
epoch: 1, iter: 7000, loss: 3.928881 
epoch: 1, iter: 8000, loss: 4.200979 
epoch: 1, iter: 9000, loss: 4.149964 
epoch: 1, train loss: 3.996954, eval loss: 4.151195
--- 4252.381770372391 seconds ---

可以看到loss還有進一步下降的空間,由於時間關係,交給各位去嘗試了。

預測例子:
股 票 虧 爛 怎 麼 辦 ? 明 年 春 節 到 了 才 發 啊
情 人 節 該 帶 女 朋 友 去 哪 慶 祝 ? 肥 宅 沒 人 會 慶 沒 人 會 動

Reference

https://huggingface.co/blog/encoder-decoder
http://jalammar.github.io/illustrated-transformer/