
Commit 3e85a4f

committed
Working BiLSTM CRF but needs explanations
1 parent 5390694 commit 3e85a4f

File tree

1 file changed: +175 -29 lines changed


Deep Learning for Natural Language Processing with Pytorch.ipynb

+175 -29
@@ -13,18 +13,18 @@
 },
 {
 "cell_type": "code",
-"execution_count": 1,
+"execution_count": 2,
 "metadata": {
 "collapsed": false
 },
 "outputs": [
 {
 "data": {
 "text/plain": [
-"<torch._C.Generator at 0x7f8c23c1a678>"
+"<torch._C.Generator at 0x7f3085495af8>"
 ]
 },
-"execution_count": 1,
+"execution_count": 2,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -665,15 +665,14 @@
 "\n",
 "All network components should inherit from nn.Module and override the forward() method. That is about it, as far as the boilerplate is concerned. Inheriting from nn.Module provides functionality to your component. For example, it makes it keep track of its trainable parameters, you can swap it between CPU and GPU with the .cuda() or .cpu() functions, etc.\n",
 "\n",
-"Let's write an annotated example of a network that takes in a sparse bag-of-words representation and outputs a probability distribution over two labels: \"English\" and \"Spanish\".\n",
-"\n",
-"Note: This is just for demonstration, so that we can build Pytorch components in later sections and you will know what is going on. Handing in a sparse bag-of-words representation is not how you would actually want to do things. There are better ways to do text classification. I made up this model to be extremely simple and to not use word embeddings (which we don't introduce until the next section)."
+"Let's write an annotated example of a network that takes in a sparse bag-of-words representation and outputs a probability distribution over two labels: \"English\" and \"Spanish\". This model is just logistic regression."
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
+"### Example: Logistic Regression Bag-of-Words classifier\n",
 "Our model will map a sparse BOW representation to log probabilities over labels. We assign each word in the vocab an index. For example, say our entire vocab is two words \"hello\" and \"world\", with indices 0 and 1 respectively.\n",
 "The BoW vector for the sentence \"hello hello hello hello\" is\n",
 "$$ \\left[ 4, 0 \\right] $$\n",
@@ -691,7 +690,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 89,
+"execution_count": 68,
 "metadata": {
 "collapsed": false
 },
@@ -728,7 +727,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 90,
+"execution_count": 69,
 "metadata": {
 "collapsed": false
 },
@@ -760,14 +759,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 91,
+"execution_count": 70,
 "metadata": {
 "collapsed": true
 },
 "outputs": [],
 "source": [
 "def make_bow_vector(sentence, word_to_ix):\n",
-" vec = torch.Tensor([0] * len(word_to_ix))\n",
+" vec = torch.zeros(len(word_to_ix))\n",
 " for word in sentence:\n",
 " vec[word_to_ix[word]] += 1\n",
 " return vec.view(1, -1)\n",
@@ -778,7 +777,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 92,
+"execution_count": 71,
 "metadata": {
 "collapsed": false
 },
@@ -790,21 +789,21 @@
 "Parameter containing:\n",
 "\n",
 "Columns 0 to 9 \n",
-" 0.0058 0.1147 0.1744 -0.1844 0.0339 0.1503 0.1582 0.0160 -0.1422 -0.0204\n",
-" 0.1503 0.0746 0.0485 0.0580 0.0984 -0.0573 -0.0593 0.1032 -0.0902 -0.0563\n",
+"-0.0650 0.0114 0.0729 0.0944 0.1127 -0.0445 -0.0844 -0.0186 -0.0314 0.1862\n",
+" 0.1257 -0.1513 0.0712 -0.0120 -0.0240 -0.0998 -0.0409 0.0621 0.1474 -0.1760\n",
 "\n",
 "Columns 10 to 19 \n",
-"-0.1415 0.1538 0.1206 -0.0480 -0.0401 0.0151 -0.1313 0.0597 0.1677 -0.0544\n",
-" 0.1553 0.0992 -0.0282 0.1496 0.1823 -0.1915 0.0641 -0.0007 0.0477 -0.1672\n",
+" 0.1500 0.0017 0.1036 -0.0334 -0.0844 -0.0557 0.0878 -0.0663 -0.1855 0.0172\n",
+"-0.1025 -0.1243 -0.0416 0.0733 -0.1398 -0.0462 0.0676 0.1198 -0.1272 -0.1517\n",
 "\n",
 "Columns 20 to 25 \n",
-"-0.0597 0.0279 0.0984 0.0541 0.0886 -0.1466\n",
-"-0.1511 0.1126 0.1763 -0.1710 -0.0196 -0.0568\n",
+" 0.0248 -0.1245 -0.1800 -0.1680 -0.1467 -0.0838\n",
+" 0.1960 -0.1035 0.1822 -0.0159 -0.1695 0.1666\n",
 "[torch.FloatTensor of size 2x26]\n",
 "\n",
 "Parameter containing:\n",
-" 0.0307\n",
-" 0.1733\n",
+"-0.1683\n",
+"-0.0214\n",
 "[torch.FloatTensor of size 2]\n",
 "\n"
 ]
@@ -871,7 +870,10 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"So lets train! To do this, we pass instances through to get log probabilities, compute a loss function, compute the gradient of the loss function, and then update the parameters with a gradient step. Loss functions are provided by Torch in the nn package. nn.NLLLoss() is the negative log likelihood loss we want. It also defines optimization functions in torch.optim. Here, we will just use SGD."
+"So let's train! To do this, we pass instances through to get log probabilities, compute a loss function, compute the gradient of the loss function, and then update the parameters with a gradient step. Loss functions are provided by Torch in the nn package. nn.NLLLoss() is the negative log likelihood loss we want. Torch also defines optimization functions in torch.optim. Here, we will just use SGD.\n",
+"\n",
+"Note that the *input* to NLLLoss is a vector of log probabilities, and a target label. It doesn't compute the log probabilities for us. This is why the last layer of our network is log softmax.\n",
+"The loss function nn.CrossEntropyLoss() is the same as NLLLoss(), except it does the log softmax for you."
 ]
 },
 {
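To make the NLLLoss note above concrete, here is a minimal sketch in the notebook's Variable-era style, assuming its existing imports (torch, torch.nn as nn, torch.nn.functional as F, torch.autograd as autograd); it checks that NLLLoss on log-softmax output matches CrossEntropyLoss on the raw scores:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

scores = autograd.Variable(torch.randn(1, 2))      # raw scores for the two labels
target = autograd.Variable(torch.LongTensor([1]))  # gold label index

log_probs = F.log_softmax(scores)                  # what our network's last layer produces
nll = nn.NLLLoss()(log_probs, target)              # NLLLoss expects log probabilities
ce = nn.CrossEntropyLoss()(scores, target)         # CrossEntropyLoss applies log softmax itself

print(nll, ce)                                     # the two losses agree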
@@ -1193,10 +1195,9 @@
 " \n",
 " def forward(self, inputs):\n",
 " embeds = self.embeddings(inputs).view((1, -1))\n",
-" log_probs = F.log_softmax(\n",
-" self.linear2(\n",
-" F.relu(\n",
-" self.linear1(embeds))))\n",
+" out = F.relu(self.linear1(embeds))\n",
+" out = self.linear2(out)\n",
+" log_probs = F.log_softmax(out)\n",
 " return log_probs"
 ]
 },
@@ -1706,7 +1707,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# 8. Making Dynamic Decisions: Structure Prediction (Coming Soon!)"
+"# 8. Advanced: Making Dynamic Decisions and the Bi-LSTM CRF (WIP)"
 ]
 },
 {
@@ -1742,11 +1743,156 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"For this section, we will see a full, complicated example of a Bi-LSTM Conditional Random Field for named-entity recognition. Familiarity with CRF's is assumed. Although this name sounds scary, all the model is is a CRF but where an LSTM provides the features.\n",
+"For this section, we will see a full, complicated example of a Bi-LSTM Conditional Random Field for named-entity recognition. Familiarity with CRFs is assumed. Although the name sounds scary, the model is just a CRF in which an LSTM provides the features. This is an advanced model, though, far more complicated than any earlier model in this tutorial. If you want to skip it, that is fine.\n",
+"\n",
+"TODO explain BiLSTM CRF Here"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 67,
+"metadata": {
+"collapsed": false
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"(Variable containing:\n",
+" 1.8765\n",
+" [torch.FloatTensor of size 1], [2, 1, 2])"
+]
+},
+"execution_count": 67,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"# Work in progress. Needs extensive commenting but it runs.\n",
+"\n",
+"\n",
+"def to_scalar(var):\n",
+" return var.view(-1).data.tolist()[0]\n",
 "\n",
-"Let $\\textbf{y}$ be a tag sequence, and $\\textbf{w}$ a sequence of words. Recall that the CRF wants to compute\n",
-"$$ P(\\textbf{y} | \\textbf{w}) = \\frac{ \\exp{ ( \\sum_i f(y_{i-1}, y_i, i, \\textbf{w}) \\cdot \\theta ) }}\n",
-"{\\sum_{\\textbf{y'}} \\exp{ ( \\sum_j f(y'_{j-1}, y'_j, j, \\textbf{w} \\cdot \\theta } ) } $$"
+"def argmax(vec):\n",
+" _, idx = torch.max(vec, 1)\n",
+" return to_scalar(idx)\n",
+"\n",
+"def log_sum_exp(vec):\n",
+" max_score = vec[0][argmax(vec)]\n",
+" max_score_broadcast = max_score.expand(vec.size()[1])\n",
+" return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))\n",
+" \n",
+"\n",
+"class BiLSTM_CRF(nn.Module):\n",
+" \n",
+" def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):\n",
+" super(BiLSTM_CRF, self).__init__()\n",
+" self.embedding_dim = embedding_dim\n",
+" self.hidden_dim = hidden_dim\n",
+" self.vocab_size = vocab_size\n",
+" self.tag_to_ix = tag_to_ix\n",
+" self.tagset_size = len(tag_to_ix)\n",
+" \n",
+" self.word_embeds = nn.Embedding(vocab_size, embedding_dim)\n",
+" self.lstm = nn.LSTM(embedding_dim, hidden_dim/2, num_layers=1, bidirectional=True)\n",
+" self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)\n",
+" self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))\n",
+" \n",
+" self.hidden = self.init_hidden()\n",
+" \n",
+" def init_hidden(self):\n",
+" return ( autograd.Variable( torch.randn(2, 1, self.hidden_dim)),\n",
+" autograd.Variable( torch.randn(2, 1, self.hidden_dim)) )\n",
+" \n",
+" \n",
+" def _forward_alg(self, feats):\n",
+" init_alphas = torch.Tensor(1, self.tagset_size).fill_(-10000.)\n",
+" init_alphas[0][self.tag_to_ix[START_TAG]] = 0.\n",
+" \n",
+" forward_var = autograd.Variable(init_alphas)\n",
+" \n",
+" for feat in feats:\n",
+" alphas_t = []\n",
+" for next_tag in xrange(self.tagset_size):\n",
+" emit_score = feat[next_tag].expand(self.tagset_size)\n",
+" trans_score = self.transitions[next_tag]\n",
+" next_tag_var = forward_var + trans_score + emit_score\n",
+" alphas_t.append(log_sum_exp(next_tag_var))\n",
+" forward_var = torch.cat(alphas_t).view(1, -1)\n",
+" terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n",
+" alpha = log_sum_exp(terminal_var)\n",
+" return alpha\n",
+" \n",
+" def _get_lstm_features(self, sentence):\n",
+" embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)\n",
+" lstm_out, self.hidden = self.lstm(embeds)\n",
+" lstm_out = lstm_out.view(len(sentence), self.hidden_dim)\n",
+" lstm_feats = self.hidden2tag(lstm_out)\n",
+" return lstm_feats\n",
+" \n",
+" def _score_sentence(self, feats, tags):\n",
+" score = autograd.Variable( torch.Tensor([0]) )\n",
+" tags = [self.tag_to_ix[START_TAG]] + tags\n",
+" for i, feat in enumerate(feats):\n",
+" score = score + self.transitions[tags[i+1], tags[i]] + feat[tags[i+1]]\n",
+" score = score + self.transitions[self.tag_to_ix[STOP_TAG]][tags[-1]]\n",
+" return score\n",
+" \n",
+" def _viterbi_decode(self, feats):\n",
+" backpointers = []\n",
+" init_vvars = torch.Tensor(1, self.tagset_size).fill_(-10000.)\n",
+" init_vvars[0][self.tag_to_ix[START_TAG]] = 0\n",
+" forward_var = autograd.Variable(init_vvars)\n",
+" for feat in feats:\n",
+" bptrs_t = []\n",
+" viterbivars_t = []\n",
+" \n",
+" for next_tag in range(self.tagset_size):\n",
+" next_tag_var = forward_var + self.transitions[next_tag]\n",
+" best_tag_id = argmax(next_tag_var)\n",
+" bptrs_t.append(best_tag_id)\n",
+" viterbivars_t.append(next_tag_var[0][best_tag_id])\n",
+" forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)\n",
+" backpointers.append(bptrs_t)\n",
+" \n",
+" terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]\n",
+" best_tag_id = argmax(terminal_var)\n",
+" path_score = terminal_var[0][best_tag_id]\n",
+" \n",
+" best_path = [best_tag_id]\n",
+" for bptrs_t in reversed(backpointers):\n",
+" best_tag_id = bptrs_t[best_tag_id]\n",
+" best_path.append(best_tag_id)\n",
+" start = best_path.pop()\n",
+" assert start == self.tag_to_ix[START_TAG]\n",
+" best_path.reverse()\n",
+" return best_path, path_score\n",
+" \n",
+" def log_likelihood(self, sentence, tags):\n",
+" feats = self._get_lstm_features(sentence)\n",
+" forward_score = self._forward_alg(feats)\n",
+" gold_score = self._score_sentence(feats, tags)\n",
+" return gold_score - forward_score\n",
+" \n",
+" def forward(self, sentence):\n",
+" embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)\n",
+" lstm_out, self.hidden = self.lstm(embeds)\n",
+" lstm_out = lstm_out.view(len(sentence), self.hidden_dim)\n",
+" lstm_feats = self.hidden2tag(lstm_out)\n",
+" \n",
+" tag_seq, score = self._viterbi_decode(lstm_feats)\n",
+" return score, tag_seq\n",
+"\n",
+"START_TAG = \"<START>\"\n",
+"STOP_TAG = \"<STOP>\"\n",
+"tag_to_ix = { START_TAG: 0, STOP_TAG: 1, \"NN\": 2, \"V\": 3 }\n",
+"word_to_ix = {\"hello\": 1, \"word\": 0}\n",
+"sentence = \"hello hello word\".split()\n",
+"idxs = autograd.Variable( torch.LongTensor(map(lambda w: word_to_ix[w], sentence) ))\n",
+"model = BiLSTM_CRF(2, tag_to_ix, 4, 6)\n",
+"model(idxs)\n"
 ]
 },
 {
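The new cell only builds the model and runs untrained Viterbi decoding via model(idxs). As a minimal training sketch, assuming the notebook's torch.optim import and the model, idxs, and tag_to_ix defined in the cell above (gold_tags here is an assumed toy labeling of "hello hello word"), training maximizes log_likelihood, i.e. minimizes the negative log likelihood of the gold tag sequence, with plain SGD just like the earlier models:

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=0.01)
gold_tags = [tag_to_ix["NN"], tag_to_ix["NN"], tag_to_ix["V"]]  # assumed gold tags for "hello hello word"

for epoch in range(20):
    model.zero_grad()
    # log_likelihood returns gold_score - forward_score, the log probability of the gold path
    neg_log_lik = -model.log_likelihood(idxs, gold_tags)
    neg_log_lik.backward()
    optimizer.step()

print(model(idxs))  # (Viterbi path score, decoded tag sequence) after training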
