Stephen Merity, Caiming Xiong,
James Bradbury, Richard Socher
Salesforce Research


 

What I want you to learn

  • Many RNN encoder-decoder models need to (poorly) relearn one-to-one token mappings

  • Recurrent dropout - use it (Gal or zoneout)

  • Pointer networks are really promising + interesting
    (pointer sentinel helps pull off joint learning)

  • Ensembles can do better via joint learning

  • Mikolov PTB has issues - know what they are

  • LM + BPTT training has issues

By accurately assigning probability to a natural sequence, you can improve many tasks:

Machine Translation
p(strong tea) > p(powerful tea)

Speech Recognition
p(speech recognition) > p(speech wreck ignition)

Question Answering / Summarization
p(President X attended ...) is higher for X=Obama

Query Completion
p(Michael Jordan Berkeley) > p(Michael Jordan basketball)

Regularizing RNN LMs

Zoneout (Krueger et al. 2016)
Offers two main advantages over "locked" dropout

+ "Locked" dropout means |h| x p of the hidden units are not usable (dropped out)
+ With zoneout, all of h remains usable
(important for lengthy stateful RNNs in LM)
+ Zoneout can be considered related to stochastic depth, which might help train deep networks

Regularizing RNN LMs

Zoneout (Krueger et al. 2016)
Stochastically forces some of the recurrent units in h to maintain their previous values

Imagine a faulty update mechanism:

h_t = h_{t-1} + m \odot \delta_t

where \delta_t is the update and m the dropout mask
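The faulty update can be sketched in a few lines of pure Python. This is a minimal illustration, not the Krueger et al. implementation. It assumes h_prev and the candidate state h_cand are plain lists of floats, that z is the zoneout probability (the chance a unit keeps its previous value), and that the helper name zoneout_update is hypothetical. At evaluation time it uses the expected mask, following the usual dropout convention.

```python
import random

# Hypothetical helper: a sketch of zoneout's "faulty update", not a library API.
def zoneout_update(h_prev, h_cand, z, training=True, rng=random):
    """Apply h_t = h_{t-1} + m * delta_t elementwise, with delta_t = h_cand - h_prev.

    z is the probability that a unit is "zoned out" (keeps its previous value),
    so each mask entry m_i is 1 (update applied) with probability 1 - z.
    """
    h_new = []
    for prev, cand in zip(h_prev, h_cand):
        delta = cand - prev
        if training:
            # Fresh Bernoulli sample per unit, per timestep.
            m = 1.0 if rng.random() >= z else 0.0
        else:
            # Assumption: use the expected mask value at evaluation time.
            m = 1.0 - z
        h_new.append(prev + m * delta)
    return h_new
```

A zoned-out unit carries h_{t-1} forward instead of being zeroed, so every unit of h stays usable, which is the contrast with locked dropout drawn on the earlier slide.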

Regularizing RNN LMs

IMHO, though Gal et al. applied it to LM, locked dropout is best for
recurrent connections on short sequences
(portions of the hidden state h are blacked out until the dropout mask is reset)

Regularizing RNN LMs

Variational LSTM (Gal et al. 2015)
Gives a major improvement over the standard LSTM and immensely simplifies preventing overfitting

Regularizing RNN LMs

Variational LSTM (Gal et al. 2015)
I informally refer to it as "locked dropout":
you lock the dropout mask once it's used
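"Locking" the mask can be sketched as sampling it once per sequence and reusing it at every timestep. This is a minimal pure-Python illustration, not Gal et al.'s code. It assumes a sequence is a list of equal-width float vectors, uses inverted dropout scaling by 1 / (1 - p), and the name locked_dropout is hypothetical.

```python
import random

# Hypothetical helper: a sketch of "locked" (variational) dropout.
def locked_dropout(sequence, p, training=True, rng=random):
    """Apply one dropout mask, sampled once, to every timestep of a sequence.

    sequence: list of timesteps, each a list of floats of the same width.
    p: probability of dropping a unit. Kept units are scaled by 1 / (1 - p)
       so the expected activation is unchanged (inverted dropout).
    """
    if not training or p == 0.0:
        return [list(x) for x in sequence]
    width = len(sequence[0])
    # Sample the mask ONCE for the whole sequence...
    mask = [0.0 if rng.random() < p else 1.0 / (1.0 - p) for _ in range(width)]
    # ...then reuse ("lock") it at every timestep.
    return [[v * m for v, m in zip(x, mask)] for x in sequence]
```

Because the same units stay dropped for the whole sequence, |h| x p of the units are unusable until the mask is resampled.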

Progression of language models

h_t, y_t = \text{RNN}(x_t, h_{t-1})

Encode previous relevant context from x,
store it for many time steps in h,
decode relevant context to y when appropriate

Progression of language models

h_t, y_t = \text{RNN}(x_t, h_{t-1})

N-grams used to be all the rage

RNN variants, although computationally intensive, now hold the state of the art

Natural sequence ⇒ probability

p(S) = p(w_1, w_2, w_3, \ldots, w_n)

Break this down into the probability of the next word via the chain rule of probability

p(w_n|w_1, w_2, w_3, \ldots, w_{n-1})

Here S is a sequence: in our case, a sentence composed of the words W = [w1, w2, ..., wn]
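To make the chain rule concrete, here is a sketch that scores a sentence by summing log conditional probabilities. The function cond_prob stands in for whatever model supplies p(w_i | w_1, ..., w_{i-1}) (an RNN LM in this talk); the toy bigram table and all names are hypothetical and made up purely for illustration.

```python
import math

def sentence_log_prob(words, cond_prob):
    """log p(w_1..w_n) = sum_i log p(w_i | w_1..w_{i-1})  (chain rule)."""
    total = 0.0
    for i, w in enumerate(words):
        total += math.log(cond_prob(w, words[:i]))
    return total

# Hypothetical toy model: conditions only on the previous word (a bigram model).
TOY_BIGRAMS = {
    ("<s>", "strong"): 0.1, ("strong", "tea"): 0.5,
    ("<s>", "powerful"): 0.1, ("powerful", "tea"): 0.01,
}

def toy_cond_prob(word, history):
    prev = history[-1] if history else "<s>"
    return TOY_BIGRAMS.get((prev, word), 1e-6)  # tiny floor for unseen pairs

strong = sentence_log_prob(["strong", "tea"], toy_cond_prob)
powerful = sentence_log_prob(["powerful", "tea"], toy_cond_prob)
```

With these made-up numbers, p(strong tea) = 0.1 x 0.5 = 0.05 and p(powerful tea) = 0.1 x 0.01 = 0.001, mirroring the machine translation example. Summing logs rather than multiplying probabilities avoids underflow on long sequences.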

Issues with RNNs for LM

Hidden state is limited in capacity
Vanishing gradient still hinders learning
Encoding / decoding rare words is problematic

Issues with RNNs for LM

The hidden state is limited in capacity
RNNs for LM do best with large hidden states
... but parameters increase quadratically with h
(attention tries to help with this but attention hasn't been highly effective for LM)
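The quadratic growth is easy to see by counting the parameters of a single LSTM layer. This is a sketch under two assumptions: one bias vector per gate, and an input size equal to h. Some frameworks, PyTorch's nn.LSTM for example, keep two bias vectors per gate, which adds a small linear term but leaves the h² term unchanged. The function name is hypothetical.

```python
# Hypothetical helper: parameter count for one standard LSTM layer.
def lstm_params(input_size, hidden_size):
    """Each of the 4 gates has an input matrix (hidden x input),
    a recurrent matrix (hidden x hidden), and one bias vector (hidden)."""
    return 4 * (hidden_size * (input_size + hidden_size) + hidden_size)

for h in (200, 650, 1500):
    # Assumption: input (word vector) size equals the hidden size.
    print(h, lstm_params(input_size=h, hidden_size=h))
```

This prints 320800, 3382600 and 18006000. Going from h = 650 to h = 1500 (about 2.3x) grows the layer about 5.3x.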

Issues with RNNs for LM

Vanishing gradient still hinders learning
"LSTMs capture long term dependencies"
... yet we only train BPTT for 35 timesteps?

Issues with RNNs for LM

Encoding / decoding rare words is problematic
+ In real life, the majority of words are in the long tail

+ The word vectors need to get in sync with the softmax weights (learn a one-to-one mapping)
+ The RNN needs to learn to encode and decode rare words given very few examples
+ Nothing can be done when encountering out-of-vocabulary (OoV) words

Pointer Networks (Vinyals et al. 2015)

Convex Hull

Delaunay Triangulation

The challenge: we don't have a good "vocabulary" to refer to our points (e.g. P1 = [42, 19.5])
We want to reproduce the exact point

Pointer Networks

Pointer networks avoid needing to learn to store the identity of the token to be produced
(works well on limited data and rare tokens!)

 

Why is this important?
This may help solve our rare / OoV problem

... but only if the word is in the input
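A sketch of how a pointer turns attention over input positions into a distribution over words. It scores each stored position against a query (a dot product here, which is an assumption), applies a softmax over positions, then sums the probability of every position holding the same word, as in the attention-sum idea of Kadlec et al. All names are hypothetical. This is an illustration, not the authors' code.

```python
import math
from collections import defaultdict

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical helper: pointer distribution over the words in the input.
def pointer_distribution(query, memory, tokens):
    """Return p_ptr(word) from attention over input positions.

    query:  list of floats (e.g. the current RNN output)
    memory: one vector per input position (e.g. earlier RNN outputs)
    tokens: the word at each input position
    """
    scores = [sum(q * k for q, k in zip(query, key)) for key in memory]
    attn = softmax(scores)
    p_ptr = defaultdict(float)
    for word, a in zip(tokens, attn):
        p_ptr[word] += a  # repeated words pool their attention
    return dict(p_ptr)
```

Any word absent from tokens gets probability zero, however rare or common it is. That is both the strength (no need to store the word's identity) and the limitation (the word must be in the input).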

Pointer Networks

Pointer networks also provide direct supervision

The pointer is directed to attend to any words in the input that are the correct answer
(explicitly instructed re: one-to-one mapping)

This helps gradient flow by providing a shortcut rather than having to traverse the full RNN
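As a sketch of that direct supervision, the pointer's training loss can be written as the negative log of all attention that lands on positions holding the correct word. The error signal goes straight to those attention weights rather than back through many RNN steps. The name pointer_loss is hypothetical, and this is an illustration of the idea, not the authors' implementation.

```python
import math

# Hypothetical helper: loss for a pointer given its attention weights.
def pointer_loss(attn, tokens, target):
    """Negative log of the total attention on positions holding the target word.

    attn:   attention weights over input positions (sum to 1)
    tokens: the word at each input position
    target: the correct next word
    """
    mass = sum(a for a, w in zip(attn, tokens) if w == target)
    if mass == 0.0:
        # Target not in the input: no attention pattern can make the pointer correct.
        return math.inf
    return -math.log(mass)
```

Attending to any occurrence of the target reduces the loss, so the pointer is explicitly taught the one-to-one mapping instead of having to infer it.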

Attention Sum (Kadlec et al. 2016)

Pointer Networks


Big issue:
What if the correct answer isn't in our input..?
... then the pointer network can never be correct!

 

We really want to merge the rare / OoV advantages of the pointer with the vocabulary of the RNN

Joining RNN and Pointer

General concept is to combine two vocabularies:
vocabulary softmax (RNN) and positional softmax (ptr)

Gulcehre et al. (2016) use a "switch" that uses the hidden state of the RNN to decide between vocabulary and pointer

Major issue: all relevant info must be in the hidden state
(including (word, position) pairs for full window)

Pointer Sentinel

Use the pointer to determine how to mix the vocabulary and pointer softmax together

Pointer Sentinel

Why is the sentinel important?

The pointer will back off to the vocabulary if it's uncertain
(specifically, if nothing in the pointer window attracts interest)

"Degrades" to a straight up mixture model:
If g = 1, only vocabulary
If g = 0, only pointer

Pointer Sentinel

Any probability mass the pointer gives to the sentinel is passed to the vocabulary softmax
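A minimal sketch of the sentinel mixture. It assumes the pointer scores for the L window positions and one extra sentinel score are already computed. In the model, the sentinel score comes from the query against a learned sentinel vector; here it is just a number. The sentinel's share of one joint softmax is the gate g, and p(w) = g * p_vocab(w) + (sum of attention on positions holding w). All names are hypothetical.

```python
import math
from collections import defaultdict

# Hypothetical helper: mix a vocabulary softmax with a pointer via a sentinel.
def pointer_sentinel_mix(ptr_scores, sentinel_score, tokens, p_vocab):
    """Return (g, p) where p maps word -> mixed probability.

    ptr_scores:     one score per position in the pointer window (a list)
    sentinel_score: score for the sentinel (the "back off" option)
    tokens:         the word at each window position
    p_vocab:        dict word -> probability from the RNN's vocabulary softmax
    """
    scores = ptr_scores + [sentinel_score]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    attn = [e / z for e in exps]  # one softmax over window positions + sentinel
    g = attn[-1]  # mass given to the sentinel is handed to the vocabulary softmax
    p = defaultdict(float)
    for word, prob in p_vocab.items():
        p[word] += g * prob
    for word, a in zip(tokens, attn[:-1]):
        p[word] += a  # equals (1 - g) * p_ptr(word)
    return g, dict(p)
```

If the sentinel dominates the softmax (g close to 1) the output is essentially the vocabulary softmax. If it gets almost no mass (g close to 0) the output is essentially the pointer.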

Why is the sentinel important?

The pointer decides when to back off to the vocab softmax

This is helpful as the RNN hidden state has limited capacity and can't accurately recall what's in the pointer,
especially after having seen large amounts of content

The longer the pointer's window becomes,
the more the issue of limited capacity may hurt us

We keep saying LSTMs do well for long term dependencies but ...

we train them with BPTT
for only 35 timesteps

Training the pointer sentinel model

RNN LMs by default use short windows for BPTT
(the default is only 35 timesteps!)

The pointer network wants as large a window as possible - it's not constrained by storing information in a hidden state
(the stored RNN outputs are essentially a memory network)

Issues in standard BPTT

BPTT for 0 ts

In the following work, the pointer sentinel model uses BPTT over 100 past timesteps

Slower to train, similar prediction time

It's important to note that, while pointer sentinel mixture models are slower to train, their run time at prediction is very similar

Previous RNN outputs are fully cached

Total additional computation is minimal compared to standard RNN LM
(especially when factoring in the larger h a standard RNN LM needs to reach similar perplexity)
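At prediction time the cache can be as simple as a fixed-length queue of (RNN output, token) pairs. This is a minimal sketch that assumes a window of 100 to match the training setup described earlier; the class and method names are hypothetical.

```python
from collections import deque

# Hypothetical class: rolling cache of past RNN outputs for the pointer.
class OutputCache:
    """Fixed-length window of (RNN output, token) pairs for the pointer to attend over.

    Each output is computed once and appended; nothing in the window is
    recomputed, so the pointer's extra work per step is mainly scoring the
    query against the cached outputs.
    """

    def __init__(self, window=100):
        self.items = deque(maxlen=window)  # oldest entries fall off automatically

    def append(self, output, token):
        self.items.append((output, token))

    def memory(self):
        return [out for out, _ in self.items]

    def tokens(self):
        return [tok for _, tok in self.items]
```

Using deque(maxlen=...) keeps each append O(1) and bounds memory at the window size.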

Training the pointer sentinel model

A relatively undiscussed issue with training RNN LMs ...
BPTT(35) means backprop for 35 ts, leap 35 ts, repeat...

When you use BPTT for 35 timesteps,
the majority of words have far less history

The word at the start of the window
only backprops into the word vector
(exacerbated by not shuffling batch splits in LM data)
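A quick count makes the point. With non-overlapping windows (backprop 35, leap 35), the loss at a position can only send gradient back to the start of its own window. This is a minimal sketch with a hypothetical function name.

```python
# Hypothetical helper: gradient "history" per token under truncated BPTT.
def bptt_history_lengths(num_tokens, bptt=35):
    """For each token, how many earlier timesteps its gradient can reach
    under truncated BPTT with non-overlapping windows of length `bptt`."""
    return [i % bptt for i in range(num_tokens)]

lengths = bptt_history_lengths(35 * 1000, bptt=35)
mean_history = sum(lengths) / len(lengths)
```

Although the window is 35 steps, the average word gets gradient from only 17 earlier steps (mean_history == 17.0). One word in every 35 gets none and only updates its word vector.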

The pointer helps the RNN's gradients

The pointer sentinel mixture model changes the gradient flow
The pointer helps the RNN as the gradient doesn't need to traverse many previous timesteps

The pointer helps the RNN's gradients

For the pointer to point to the correct output,
the RNN is encouraged to make the hidden states at the two locations (end and input target) similar

Training the pointer sentinel model

+ The safest tactic is to regenerate all stale outputs
... but regenerating RNN outputs for the last 100 words is slow
(working with stale outputs / gradients may be interesting ..?)

+ Aggressive gradient clipping is important early
(100 steps before any error signal ⇒ gradient explosion)

+ Plain SGD works better than Adam etc. (rarely discussed)
(GNMT, for example, starts with Adam then switches to SGD)
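Clipping the global gradient norm is one common way to implement the gradient clipping mentioned above. This is a minimal pure-Python sketch with a hypothetical helper name. "Aggressive" just means a small max_norm; the talk does not state a value, so the 1.0 in the checks below is only a placeholder.

```python
import math

# Hypothetical helper: clip a set of gradients by their combined L2 norm.
def clip_by_global_norm(grads, max_norm):
    """Scale all gradients together so their combined L2 norm is at most max_norm.

    grads: list of gradient vectors (lists of floats), one per parameter.
    Returns the clipped gradients and the norm before clipping.
    """
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Small epsilon avoids division by zero when all gradients are zero.
    scale = min(1.0, max_norm / (total + 1e-12))
    return [[g * scale for g in vec] for vec in grads], total
```

Scaling every gradient by the same factor preserves the update direction while bounding its size, which matters when an error signal has been accumulated over 100 steps.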

Mikolov Penn Treebank (PTB)

Why is it used?
+ A standard benchmark dataset for LM
+ Article level rather than sentence level
(One Billion Word LM is sentence level)

Mikolov Penn Treebank (PTB)

+ Small, both in size and vocab (10k words)
+ All lowercase
+ No punctuation
+ Numbers replaced with N

Mikolov Penn Treebank (PTB)

Aside: character level Mikolov PTB is ... odd ...

Literally the word level dataset in character form

WikiText

+ Entries from high-quality Wikipedia articles
+ Larger vocabulary (only unk'ed if freq < 3)
+ Numbers, caps, and punctuation are retained
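The "only unk'ed if freq < 3" rule can be sketched as follows. This is a minimal version with hypothetical helper names that counts already-tokenized words; the actual WikiText preprocessing, including tokenization, is not shown here. Case is preserved, matching the point that caps are retained.

```python
from collections import Counter

# Hypothetical helpers: frequency-thresholded vocabulary with an <unk> token.
def build_vocab(tokens, min_freq=3, unk="<unk>"):
    """Keep every token seen at least min_freq times; everything rarer maps to unk."""
    counts = Counter(tokens)
    vocab = {tok for tok, c in counts.items() if c >= min_freq}
    vocab.add(unk)
    return vocab

def apply_vocab(tokens, vocab, unk="<unk>"):
    return [tok if tok in vocab else unk for tok in tokens]
```

A low threshold keeps far more of the long tail than a fixed 10k vocabulary does, which is what makes rare-word handling measurable on this data.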

Dataset comparison

+ WikiText-2 and PTB are of similar size
+ WikiText-2 has a more realistic vocab long tail
+ WikiText-103 is over 1/10th the size of the One Billion Word benchmark
(this makes working with it non-trivial)
+ WikiText-2 & WikiText-103 have the same valid/test

Results (PTB)

10+% drop

Results (WikiText-2)

With a more realistic vocabulary, the results are even better

Pointer sentinel really helps rare words

[Figure: improvement by word frequency, from frequent to rare words]

Progress continues..!

Improving Neural Language Models with a Continuous Cache (Grave, Joulin, Usunier) independently applies a mechanism similar to the pointer sentinel to RNN outputs
(they report results on PTB, WikiText-2, and WikiText-103!)

 

Tying word vectors helps with rare words and avoids wastefully learning a one-to-one mapping
(major perplexity improvement for essentially all models)
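A toy sketch of weight tying, assuming the RNN's output h has the same dimension as the word vectors so the embedding matrix can double as the softmax weight matrix. The class and method names are hypothetical.

```python
import random

# Hypothetical class: toy language-model layers with tied input/output weights.
class TiedLM:
    """Reuses the input embedding matrix as the output softmax weights.

    Row i of `embedding` is both the input vector for word i and the output
    weight vector that scores word i, so no separate one-to-one mapping
    between the two has to be learned.
    """

    def __init__(self, vocab_size, dim, seed=0):
        rng = random.Random(seed)
        self.embedding = [[rng.gauss(0.0, 0.1) for _ in range(dim)]
                          for _ in range(vocab_size)]

    def embed(self, word_id):
        return self.embedding[word_id]

    def logits(self, h):
        # Assumption: len(h) == dim. Score each word by a dot product with its own row.
        return [sum(e * x for e, x in zip(row, h)) for row in self.embedding]
```

An untied model would hold a second vocab_size x dim matrix for the softmax. Tying removes it, and every gradient step on a rare word's output weights also updates its input vector.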


Recurrent Highway Network (Zilly, Srivastava, Koutník, Schmidhuber) & Neural Architecture Search with RL (Zoph, Le) are improving basic RNN cells with new SotA on PTB

Examples

the pricing will become more realistic which should help management said bruce rosenthal a new york investment banker
...
they assumed there would be major gains in both probability and sales mr. ???

Examples

gen. douglas <unk> demanded and got in addition to his u.n. command
...
<unk> 's ghost sometimes runs through the e ring dressed like gen. ???

Examples

the sales force is viewed as a critical asset in integrated 's attempt to sell its core companies <eos>
<unk> cited concerns about how long integrated would be able to hold together the sales force as one reason its talks with integrated failed <eos>
in composite trading on the new york stock exchange yesterday ???

Examples

federal researchers said lung-cancer mortality rates for people under N years of age have begun to decline particularly for white males <eos>

the national cancer institute also projected that overall u.s. ???

Examples

a merc spokesman said the plan has n't made much difference in liquidity in the pit <eos>

it 's too soon to tell but people do n't seem to be unhappy with it he ???

Examples

the fha alone lost $ N billion in fiscal N the government 's equity in the agency essentially its reserve fund fell to minus $ N billion <eos>

the federal government has had to pump in $ N ???

Conclusion

  • The pointer sentinel mixture model allows for improved perplexity and rare/OoV word handling at minimal cost

  • Consider the potential issues in using standard RNN LMs, both in training (BPTT) and data (Mikolov PTB)

  • WikiText-2 and WikiText-103 are promising datasets for long term dependency + long tail language modeling

Conclusion

  • Find and fix information / gradient flow bottlenecks

  • Help your RNN encoder-decoder models relearn
    one-to-one token mappings (or tie the word vectors)

  • Recurrent dropout - use it! (Gal variational, zoneout, ...)

  • Pointer networks are really promising + interesting for OoV

  • Try joint learning of models that may usually be ensembles
    (pointer sentinel helps pull off joint learning)

  • Mikolov PTB has issues - at least know what they are

  • Standard LM + BPTT training is likely not optimal

NIPS Extreme Classification - Pointer Sentinel Mixture Models

By smerity
