基本思想

NLP中有很多sequence to sequence的问题,例如机器翻译,人机对话等等。对于句子而言,我们已经有RNN能够很好的处理序列之间的关系,但同时,RNN只能被用于输入和输出的维度都固定且已知的情况。但很多情况下,我们没办法确定输出序列的长度和维度。因此,为了处理这种general的序列问题,Seq2Seq框架被提出来了。

流程

seq2seq1

最基本的Seq2Seq框架主要的流程是:

  • 用一个LSTM来处理input的sequence,得到一个特定维度的向量表示,我们可以认为这个向量能很好的捕捉input中的相互关系。
    • 每一个timestep,cell将当前词的embedding向量和上一个hidden state进行concat作为输入,输出当前timestep的hidden state作为下一个cell的输入,依次进行,直到sentence的EOS标志符,得到最终的vector representation作为decoder的最初hidden state输入。
  • 用另一个LSTM,将这个vector representation映射成target sequence。每个timestep输出一个目标单词,直到输出EOS为止。
    • 接受来自上一个timestep的hidden state输入(最开始为vector representation),与上一个timestep的output进行concat作为当前timestep的输入,依次进行,直到最终生成的单词为EOS

缺点

  1. Encoder将输入编码为固定大小状态向量的过程实际上是一个信息有损压缩的过程,如果信息量越大,那么这个转化向量的过程对信息的损失就越大。
  2. 随着sequence length的增加,意味着时间维度上的序列很长,RNN模型也会出现梯度弥散,无法让Decoder关注时间间隔非常长的关联,精度下降。

Attention

  • 在普通的seq2seq模型中,我们输出的条件概率可以表示为:
    • $p(y_t| {y_1,..,y_{t-1}},c)=g(y_{t-1},s_t,c) $
    • 其中$s_t$表示$t$时刻的hidden state,$c$表示我们从Encoder学到的context vector,$g$表示非线性映射
  • 而在attention based seq2seq中,条件概率可以表示为:
    • $p(y_i|y_1,…,y_{i-1},x) = g(y_{i-1},s_i,c_i)$
    • hidden state表示为:$s_i = f(s_{i-1},y_{i-1},c_i)$
    • 也就是说,这里的每个单词$y_i$的条件概率都由不同的$c_i$决定,而不仅仅依赖于同一个$c$
  • 在Decoder中,对每个timestep,Input是上一个timestep的输出与特定的 $c_i$ (context vector)进行Attention后的结果,而hidden state和普通的seq2seq一样,为上一个timestep输出的hidden state
    • 实现Attention的方式有很多,例如直接点积,先concat后再进行线性变换等等。
  • 那么,现在关键的问题是,每一个$c_i$到底是如何计算的?
    • 论文中将Encoder的BiRNN产生的同一个timestep中两个hidden state进行concat组成一个annotations:$(h_1,…,h_{T_x})$,可以认为,每一个$h_i$都包含了在一个sequence中主要focus于第$i$个单词周围的相互关系
    • 因此,我们使用这种学习到的关系在不同的位置赋予不同的权重,来组成我们的context vector $c_i$:
      • $c_i = \sum\limits_{j=1}^{T_x} \alpha_{ij}h_j$
    • 每一个$\alpha_{ij}$是通过annotations $h_i$ 与上一个hidden state 计算出来,然后进入softmax函数得到当前位置的权重:
      • $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}$
    • 这里的每一个$e_{ij}$衡量了在input的$j$位置和output的$i$位置的匹配程度,也就是论文中提到的alignment model,是由RNN前一个timestep的hidden state $s_{i-1}$和input的$h_j$计算出来的:
      • $e_{ij} = a(s_{i-1},h_j)$
      • 这里的$a$是一个简单的前向网络

流程

Encoder:

  • 与基本的Seq2Seq类似,用RNN/LSTM/GRU 来捕获序列间的相互关系,在论文中使用了BiRNN,因此同一个位置有两个hidden state,将其concat作为annotation:$h_j = [\overrightarrow{h_j^T};\overleftarrow{h_j^T}]^T$。

Decoder

  • 对每个timestep,Input是上一个timestep的输入与特定的 $c_i$ (context vector)进行Attention后的结果,而hidden state和普通的seq2seq一样,为上一个timestep输出的hidden state
    • Attention的方式有很多,例如直接点积,先concat后再进行线性变化等等。
  • 每一个 $c_i = \sum\limits_{j=1}^{T_x} \alpha_{ij}h_j$ ,是用Encoder的所有的annotaion进行加权求和得到的。根据权重$\alpha_{ij}$不同,也就体现了在不同位置,网络的注意力focus不同。
  • 而每一个权重$\alpha_{ij}$,是将上一个timestep的hidden state $s_{i-1}$和 $j$ 位置的annotation $h_j$ 放入一个简单的前馈神经网络中学习得到$e_{ij}$,再通过softmax转换为概率得到的:
    • $\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})}$

本质

  • 目标句子生成的每个单词对应输入句子单词的概率分布可以理解为**输入句子单词和这个目标生成单词的对齐概率,**这在机器翻译语境下是非常直观的:传统的统计机器翻译一般在做的过程中会专门有一个短语对齐的步骤,而注意力模型其实起的是相同的作用。

    将Source中的构成元素想象成是由一系列的(Key,Value)数据对构成,此时给定Target中的某个元素Query,通过计算Query和各个Key的相关性,得到每个Key对应Value的权重系数,然后对Value进行加权求和,即得到了最终的Attention数值。

  • **所以本质上Attention机制是对Source中元素的Value值进行加权求和,而Query和Key用来计算对应Value的权重系数。**即可以将其本质思想改写为如下公式:

  • 在这里,Value可以理解为 $h_1,…h_{T_x}$,Query认为是 $s_{i-1}$,Key 同样是 $h_1,…h_{T_x}$

实践

Encoder

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input).view(1, 1, -1)
        output = embedded
        output, hidden = self.gru(output, hidden)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

Decoder

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

Attention

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

参考资料

  1. Sequence to Sequence Learning with Neural Networks
  2. Neural Machine Translation by Jointly Learning to Align and Translate
  3. Translate with seq2seq network and attention(pytorch)