{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2 - Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation\n",
    "\n",
    "In this second notebook on sequence-to-sequence models using PyTorch and TorchText, we'll be implementing the model from [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078). This model will achieve improved test perplexity whilst only using a single layer RNN in both the encoder and the decoder.\n",
    "\n",
    "## Introduction\n",
    "\n",
    "Let's remind ourselves of the general encoder-decoder model.\n",
    "\n",
    "![](assets/seq2seq1.png)\n",
    "\n",
    "We use our encoder (green) over the source sequence to create a context vector (red). We then use that context vector with the decoder (blue) and a linear layer (purple) to generate the target sentence.\n",
    "\n",
    "In the previous model, we used an multi-layered LSTM as the encoder and decoder.\n",
    "\n",
    "![](assets/seq2seq4.png)\n",
    "\n",
    "One downside of the previous model is that the decoder is trying to cram lots of information into the hidden states. Whilst decoding, the hidden state will need to contain information about the whole of the source sequence, as well as all of the tokens have been decoded so far. By alleviating some of this information compression, we can create a better model!\n",
    "\n",
    "We'll also be using a GRU (Gated Recurrent Unit) instead of an LSTM (Long Short-Term Memory). Why? Mainly because that's what they did in the paper (this paper also introduced GRUs) and also because we used LSTMs last time. If you want to understand how GRUs (and LSTMs) differ from standard RNNS, check out [this](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) link. Is a GRU better than an LSTM? [Research](https://arxiv.org/abs/1412.3555) has shown they're pretty much the same, and both are better than standard RNNs. \n",
    "\n",
    "## Preparing Data\n",
    "\n",
    "All of the data preparation will be (almost) the same as last time, so I'll very briefly detail what each code block does. See the previous notebook if you've forgotten.\n",
    "\n",
    "We'll import PyTorch, TorchText, spaCy and a few standard modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "from torchtext.datasets import TranslationDataset, Multi30k\n",
    "from torchtext.data import Field, BucketIterator\n",
    "\n",
    "import spacy\n",
    "\n",
    "import random\n",
    "import math\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then set a random seed for deterministic results/reproducability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 1234\n",
    "\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate our German and English spaCy models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "spacy_de = spacy.load('de')\n",
    "spacy_en = spacy.load('en')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Previously we reversed the source (German) sentence, however in the paper we are implementing they don't do this, so neither will we."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_de(text):\n",
    "    \"\"\"\n",
    "    Tokenizes German text from a string into a list of strings\n",
    "    \"\"\"\n",
    "    return [tok.text for tok in spacy_de.tokenizer(text)]\n",
    "\n",
    "def tokenize_en(text):\n",
    "    \"\"\"\n",
    "    Tokenizes English text from a string into a list of strings\n",
    "    \"\"\"\n",
    "    return [tok.text for tok in spacy_en.tokenizer(text)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create our fields to process our data. This will append the \"start of sentence\" and \"end of sentence\" tokens as well as converting all words to lowercase."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "SRC = Field(tokenize=tokenize_de, \n",
    "            init_token='<sos>', \n",
    "            eos_token='<eos>', \n",
    "            lower=True)\n",
    "\n",
    "TRG = Field(tokenize = tokenize_en, \n",
    "            init_token='<sos>', \n",
    "            eos_token='<eos>', \n",
    "            lower=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load our data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'), \n",
    "                                                    fields = (SRC, TRG))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also print out an example just to double check they're not reversed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'src': ['zwei', 'junge', 'weiße', 'männer', 'sind', 'im', 'freien', 'in', 'der', 'nähe', 'vieler', 'büsche', '.'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}\n"
     ]
    }
   ],
   "source": [
    "print(vars(train_data.examples[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then create our vocabulary, converting all tokens appearing less than twice into `<unk>` tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "SRC.build_vocab(train_data, min_freq = 2)\n",
    "TRG.build_vocab(train_data, min_freq = 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, define the `device` and create our iterators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 128\n",
    "\n",
    "train_iterator, valid_iterator, test_iterator = BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size = BATCH_SIZE, \n",
    "    device = device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building the Seq2Seq Model\n",
    "\n",
    "### Encoder\n",
    "\n",
    "The encoder is similar to the previous one, with the multi-layer LSTM swapped for a single-layer GRU. We also don't pass the dropout as an argument to the GRU as that dropout is used between each layer of a multi-layered RNN. As we only have a single layer, PyTorch will display a warning if we try and use pass a dropout value to it.\n",
    "\n",
    "Another thing to note about the GRU is that it only requires and returns a hidden state, there is no cell state like in the LSTM.\n",
    "\n",
    "$$\\begin{align*}\n",
    "h_t &= \\text{GRU}(x_t, h_{t-1})\\\\\n",
    "(h_t, c_t) &= \\text{LSTM}(x_t, (h_{t-1}, c_{t-1}))\\\\\n",
    "h_t &= \\text{RNN}(x_t, h_{t-1})\n",
    "\\end{align*}$$\n",
    "\n",
    "From the equations above, it looks like the RNN and the GRU are identical. Inside the GRU, however, is a number of *gating mechanisms* that control the information flow in to and out of the hidden state (similar to an LSTM). Again, for more info, check out [this](https://colah.github.io/posts/2015-08-Understanding-LSTMs/) excellent post. \n",
    "\n",
    "The rest of the encoder should be very familar from the last tutorial, it takes in a sequence, $X = \\{x_1, x_2, ... , x_T\\}$, recurrently calculates hidden states, $H = \\{h_1, h_2, ..., h_T\\}$, and returns a context vector (the final hidden state), $z=h_T$.\n",
    "\n",
    "$$h_t = \\text{EncoderGRU}(x_t, h_{t-1})$$\n",
    "\n",
    "This is identical to the encoder of the general seq2seq model, with all the \"magic\" happening inside the GRU (green squares).\n",
    "\n",
    "![](assets/seq2seq5.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_dim, emb_dim, hid_dim, dropout):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.input_dim = input_dim\n",
    "        self.emb_dim = emb_dim\n",
    "        self.hid_dim = hid_dim\n",
    "        self.dropout = dropout\n",
    "        \n",
    "        self.embedding = nn.Embedding(input_dim, emb_dim) #no dropout as only one layer!\n",
    "        \n",
    "        self.rnn = nn.GRU(emb_dim, hid_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, src):\n",
    "        \n",
    "        #src = [src sent len, batch size]\n",
    "        \n",
    "        embedded = self.dropout(self.embedding(src))\n",
    "        \n",
    "        #embedded = [src sent len, batch size, emb dim]\n",
    "        \n",
    "        outputs, hidden = self.rnn(embedded) #no cell state!\n",
    "        \n",
    "        #outputs = [src sent len, batch size, hid dim * n directions]\n",
    "        #hidden = [n layers * n directions, batch size, hid dim]\n",
    "        \n",
    "        #outputs are always from the top hidden layer\n",
    "        \n",
    "        return hidden"
   ]
  },
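  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the interface difference concrete, here is a minimal sketch (using made-up toy dimensions, separate from the model above) showing that `nn.GRU` returns only a hidden state, whereas `nn.LSTM` also returns a cell state:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "seq_len, batch_size, emb_dim, hid_dim = 5, 2, 8, 16 #toy dimensions for illustration\n",
    "x = torch.randn(seq_len, batch_size, emb_dim)\n",
    "\n",
    "gru = nn.GRU(emb_dim, hid_dim)\n",
    "outputs, hidden = gru(x) #GRU returns the outputs and a hidden state only\n",
    "print(outputs.shape, hidden.shape) #[5, 2, 16] and [1, 2, 16]\n",
    "\n",
    "lstm = nn.LSTM(emb_dim, hid_dim)\n",
    "outputs, (hidden, cell) = lstm(x) #LSTM additionally returns a cell state\n",
    "print(hidden.shape, cell.shape) #[1, 2, 16] and [1, 2, 16]\n",
    "```"
   ]
  },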
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoder\n",
    "\n",
    "The decoder is where the implementation differs significantly from the previous model and we alleviate some of the information compression.\n",
    "\n",
    "Instead of the GRU in the decoder taking just the target token, $y_t$ and the previous hidden state $s_{t-1}$ as inputs, it also takes the context vector $z$. \n",
    "\n",
    "$$s_t = \\text{DecoderGRU}(y_t, s_{t-1}, z)$$\n",
    "\n",
    "Note how this context vector, $z$, does not have a $t$ subscript, meaning we re-use the same context vector returned by the encoder for every time-step in the decoder. \n",
    "\n",
    "Before, we predicted the next token, $\\hat{y}_{t+1}$, with the linear layer, $f$, only using the top-layer decoder hidden state at that time-step, $s_t$, as $\\hat{y}_{t+1}=f(s_t^L)$. Now, we also pass the current token, $\\hat{y}_t$ and the context vector, $z$ to the linear layer.\n",
    "\n",
    "$$\\hat{y}_{t+1} = f(y_t, s_t, z)$$\n",
    "\n",
    "Thus, our decoder now looks something like this:\n",
    "\n",
    "![](assets/seq2seq6.png)\n",
    "\n",
    "Note, the initial hidden state, $s_0$, is still the context vector, $z$, so when generating the first token we are actually inputting two identical context vectors into the GRU.\n",
    "\n",
    "How do these two changes reduce the information compression? Well, hypothetically the decoder hidden states, $s_t$, no longer need to contain information about the source sequence as it is always available as an input. Thus, it only needs to contain information about what tokens it has generated so far. The addition of $y_t$ to the linear layer also means this layer can directly see what the token is, without having to get this information from the hidden state. \n",
    "\n",
    "However, this hypothesis is just a hypothesis, it is impossible to determine how the model actually uses the information provided to it (don't listen to anyone that tells you differently). Nevertheless, it is a solid intuition and the results seem to indicate that this modifications are a good idea!\n",
    "\n",
    "Within the implementation, we will pass $y_t$ and $z$ to the GRU by concatenating them together, so the input dimensions to the GRU are now `emb_dim + hid_dim` (as context vector will be of size `hid_dim`). The linear layer will take $y_t, s_t$ and $z$ also by concatenating them together, hence the input dimensions are now `emb_dim + hid_dim*2`. We also don't pass a value of dropout to the GRU as it only uses a single layer.\n",
    "\n",
    "`forward` now takes a `context` argument. Inside of `forward`, we concatenate $y_t$ and $z$ as `emb_con` before feeding to the GRU, and we concatenate $y_t$, $s_t$ and $z$ together as `output` before feeding it through the linear layer to receive our predictions, $\\hat{y}_{t+1}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    def __init__(self, output_dim, emb_dim, hid_dim, dropout):\n",
    "        super().__init__()\n",
    "\n",
    "        self.emb_dim = emb_dim\n",
    "        self.hid_dim = hid_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.dropout = dropout\n",
    "        \n",
    "        self.embedding = nn.Embedding(output_dim, emb_dim)\n",
    "        \n",
    "        self.rnn = nn.GRU(emb_dim + hid_dim, hid_dim)\n",
    "        \n",
    "        self.out = nn.Linear(emb_dim + hid_dim * 2, output_dim)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, input, hidden, context):\n",
    "        \n",
    "        #input = [batch size]\n",
    "        #hidden = [n layers * n directions, batch size, hid dim]\n",
    "        #context = [n layers * n directions, batch size, hid dim]\n",
    "        \n",
    "        #n layers and n directions in the decoder will both always be 1, therefore:\n",
    "        #hidden = [1, batch size, hid dim]\n",
    "        #context = [1, batch size, hid dim]\n",
    "        \n",
    "        input = input.unsqueeze(0)\n",
    "        \n",
    "        #input = [1, batch size]\n",
    "        \n",
    "        embedded = self.dropout(self.embedding(input))\n",
    "        \n",
    "        #embedded = [1, batch size, emb dim]\n",
    "                \n",
    "        emb_con = torch.cat((embedded, context), dim = 2)\n",
    "            \n",
    "        #emb_con = [1, batch size, emb dim + hid dim]\n",
    "            \n",
    "        output, hidden = self.rnn(emb_con, hidden)\n",
    "        \n",
    "        #output = [sent len, batch size, hid dim * n directions]\n",
    "        #hidden = [n layers * n directions, batch size, hid dim]\n",
    "        \n",
    "        #sent len, n layers and n directions will always be 1 in the decoder, therefore:\n",
    "        #output = [1, batch size, hid dim]\n",
    "        #hidden = [1, batch size, hid dim]\n",
    "        \n",
    "        output = torch.cat((embedded.squeeze(0), hidden.squeeze(0), context.squeeze(0)), \n",
    "                           dim = 1)\n",
    "        \n",
    "        #output = [batch size, emb dim + hid dim * 2]\n",
    "        \n",
    "        prediction = self.out(output)\n",
    "        \n",
    "        #prediction = [batch size, output dim]\n",
    "        \n",
    "        return prediction, hidden"
   ]
  },
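  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to verify the shape comments in the decoder, here is a quick sketch (with arbitrary toy dimensions, assuming the `Decoder` class above has been defined) that runs a single decoding step:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "output_dim, emb_dim, hid_dim, batch_size = 10, 8, 16, 4 #toy dimensions\n",
    "\n",
    "dec = Decoder(output_dim, emb_dim, hid_dim, dropout = 0.5)\n",
    "\n",
    "input = torch.randint(0, output_dim, (batch_size,)) #[batch size]\n",
    "hidden = torch.randn(1, batch_size, hid_dim) #[1, batch size, hid dim]\n",
    "context = torch.randn(1, batch_size, hid_dim) #[1, batch size, hid dim]\n",
    "\n",
    "prediction, hidden = dec(input, hidden, context)\n",
    "\n",
    "print(prediction.shape) #[batch size, output dim] = [4, 10]\n",
    "print(hidden.shape) #[1, batch size, hid dim] = [1, 4, 16]\n",
    "```"
   ]
  },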
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seq2Seq Model\n",
    "\n",
    "Putting the encoder and decoder together, we get:\n",
    "\n",
    "![](assets/seq2seq7.png)\n",
    "\n",
    "Again, in this implementation we need to ensure the hidden dimensions in both the encoder and the decoder are the same.\n",
    "\n",
    "Briefly going over all of the steps:\n",
    "- the `outputs` tensor is created to hold all predictions, $\\hat{Y}$\n",
    "- the source sequence, $X$, is fed into the encoder to receive a `context` vector\n",
    "- the initial decoder hidden state is set to be the `context` vector, $s_0 = z = h_T$\n",
    "- we use a batch of `<sos>` tokens as the first `input`, $y_1$\n",
    "- we then decode within a loop:\n",
    "  - inserting the input token $y_t$, previous hidden state, $s_{t-1}$, and the context vector, $z$, into the decoder\n",
    "  - receiving a prediction, $\\hat{y}_{t+1}$, and a new hidden state, $s_t$\n",
    "  - we then decide if we are going to teacher force or not, setting the next input as appropriate (either the ground truth next token in the target sequence or the highest predicted next token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder, device):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        self.device = device\n",
    "        \n",
    "        assert encoder.hid_dim == decoder.hid_dim, \\\n",
    "            \"Hidden dimensions of encoder and decoder must be equal!\"\n",
    "        \n",
    "    def forward(self, src, trg, teacher_forcing_ratio = 0.5):\n",
    "        \n",
    "        #src = [src sent len, batch size]\n",
    "        #trg = [trg sent len, batch size]\n",
    "        #teacher_forcing_ratio is probability to use teacher forcing\n",
    "        #e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time\n",
    "        \n",
    "        batch_size = trg.shape[1]\n",
    "        max_len = trg.shape[0]\n",
    "        trg_vocab_size = self.decoder.output_dim\n",
    "        \n",
    "        #tensor to store decoder outputs\n",
    "        outputs = torch.zeros(max_len, batch_size, trg_vocab_size).to(self.device)\n",
    "        \n",
    "        #last hidden state of the encoder is the context\n",
    "        context = self.encoder(src)\n",
    "        \n",
    "        #context also used as the initial hidden state of the decoder\n",
    "        hidden = context\n",
    "        \n",
    "        #first input to the decoder is the <sos> tokens\n",
    "        input = trg[0,:]\n",
    "        \n",
    "        for t in range(1, max_len):\n",
    "            \n",
    "            #insert input token embedding, previous hidden state and the context state\n",
    "            #receive output tensor (predictions) and new hidden state\n",
    "            output, hidden = self.decoder(input, hidden, context)\n",
    "            \n",
    "            #place predictions in a tensor holding predictions for each token\n",
    "            outputs[t] = output\n",
    "            \n",
    "            #decide if we are going to use teacher forcing or not\n",
    "            teacher_force = random.random() < teacher_forcing_ratio\n",
    "            \n",
    "            #get the highest predicted token from our predictions\n",
    "            top1 = output.argmax(1) \n",
    "            \n",
    "            #if teacher forcing, use actual next token as next input\n",
    "            #if not, use predicted token\n",
    "            input = trg[t] if teacher_force else top1\n",
    "\n",
    "        return outputs"
   ]
  },
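  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, a sketch with toy tensors (assuming the `Encoder`, `Decoder` and `Seq2Seq` classes above have been defined) confirms that `forward` returns a tensor of predictions of shape `[trg sent len, batch size, output dim]`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "input_dim, output_dim, emb_dim, hid_dim = 10, 12, 8, 16 #toy dimensions\n",
    "src_len, trg_len, batch_size = 7, 5, 4\n",
    "\n",
    "toy_model = Seq2Seq(Encoder(input_dim, emb_dim, hid_dim, 0.5), \n",
    "                    Decoder(output_dim, emb_dim, hid_dim, 0.5), \n",
    "                    torch.device('cpu'))\n",
    "\n",
    "src = torch.randint(0, input_dim, (src_len, batch_size)) #[src sent len, batch size]\n",
    "trg = torch.randint(0, output_dim, (trg_len, batch_size)) #[trg sent len, batch size]\n",
    "\n",
    "outputs = toy_model(src, trg)\n",
    "\n",
    "print(outputs.shape) #[trg sent len, batch size, output dim] = [5, 4, 12]\n",
    "```"
   ]
  },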
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the Seq2Seq Model\n",
    "\n",
    "The rest of this tutorial is very similar to the previous one. \n",
    "\n",
    "We initialise our encoder, decoder and seq2seq model (placing it on the GPU if we have one). As before, the embedding dimensions and the amount of dropout used can be different between the encoder and the decoder, but the hidden dimensions must remain the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = len(SRC.vocab)\n",
    "OUTPUT_DIM = len(TRG.vocab)\n",
    "ENC_EMB_DIM = 256\n",
    "DEC_EMB_DIM = 256\n",
    "HID_DIM = 512\n",
    "ENC_DROPOUT = 0.5\n",
    "DEC_DROPOUT = 0.5\n",
    "\n",
    "enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, ENC_DROPOUT)\n",
    "dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, DEC_DROPOUT)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model = Seq2Seq(enc, dec, device).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize our parameters. The paper states the parameters are initialized from a normal distribution with a mean of 0 and a standard deviation of 0.01, i.e. $\\mathcal{N}(0, 0.01)$. \n",
    "\n",
    "It also states we should initialize the recurrent parameters to a special initialization, however to keep things simple we'll also initialize them to $\\mathcal{N}(0, 0.01)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Seq2Seq(\n",
       "  (encoder): Encoder(\n",
       "    (embedding): Embedding(7855, 256)\n",
       "    (rnn): GRU(256, 512)\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       "  (decoder): Decoder(\n",
       "    (embedding): Embedding(5893, 256)\n",
       "    (rnn): GRU(768, 512)\n",
       "    (out): Linear(in_features=1280, out_features=5893, bias=True)\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_weights(m):\n",
    "    for name, param in m.named_parameters():\n",
    "        nn.init.normal_(param.data, mean=0, std=0.01)\n",
    "        \n",
    "model.apply(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We print out the number of parameters.\n",
    "\n",
    "Even though we only have a single layer RNN for our encoder and decoder we actually have **more** parameters  than the last model. This is due to the increased size of the inputs to the GRU and the linear layer. However, it is not a significant amount of parameters and causes a minimal amount of increase in training time (~3 seconds per epoch extra)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model has 14,220,293 trainable parameters\n"
     ]
    }
   ],
   "source": [
    "def count_parameters(model):\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },
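  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see roughly where the extra parameters come from, we can count the parameters of the decoder sub-modules whose inputs have grown. This is just an illustrative check; the exact numbers depend on the dimensions chosen above.\n",
    "\n",
    "```python\n",
    "def count_module_parameters(module):\n",
    "    return sum(p.numel() for p in module.parameters())\n",
    "\n",
    "#the decoder GRU now takes emb_dim + hid_dim = 768 inputs instead of emb_dim = 256\n",
    "print(f'decoder GRU: {count_module_parameters(model.decoder.rnn):,} parameters')\n",
    "\n",
    "#the linear layer now takes emb_dim + hid_dim * 2 = 1280 inputs instead of hid_dim = 512\n",
    "print(f'decoder linear layer: {count_module_parameters(model.decoder.out):,} parameters')\n",
    "```"
   ]
  },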
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We initiaize our optimizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also initialize the loss function, making sure to ignore the loss on `<pad>` tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "PAD_IDX = TRG.vocab.stoi['<pad>']\n",
    "\n",
    "criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then create the training loop..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion, clip):\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    \n",
    "    for i, batch in enumerate(iterator):\n",
    "        \n",
    "        src = batch.src\n",
    "        trg = batch.trg\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        output = model(src, trg)\n",
    "        \n",
    "        #trg = [trg sent len, batch size]\n",
    "        #output = [trg sent len, batch size, output dim]\n",
    "        \n",
    "        output = output[1:].view(-1, output.shape[-1])\n",
    "        trg = trg[1:].view(-1)\n",
    "        \n",
    "        #trg = [(trg sent len - 1) * batch size]\n",
    "        #output = [(trg sent len - 1) * batch size, output dim]\n",
    "        \n",
    "        loss = criterion(output, trg)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "...and the evaluation loop, remembering to set the model to `eval` mode and turn off teaching forcing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for i, batch in enumerate(iterator):\n",
    "\n",
    "            src = batch.src\n",
    "            trg = batch.trg\n",
    "\n",
    "            output = model(src, trg, 0) #turn off teacher forcing\n",
    "\n",
    "            #trg = [trg sent len, batch size]\n",
    "            #output = [trg sent len, batch size, output dim]\n",
    "\n",
    "            output = output[1:].view(-1, output.shape[-1])\n",
    "            trg = trg[1:].view(-1)\n",
    "\n",
    "            #trg = [(trg sent len - 1) * batch size]\n",
    "            #output = [(trg sent len - 1) * batch size, output dim]\n",
    "\n",
    "            loss = criterion(output, trg)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also define the function that calculates how long an epoch takes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def epoch_time(start_time, end_time):\n",
    "    elapsed_time = end_time - start_time\n",
    "    elapsed_mins = int(elapsed_time / 60)\n",
    "    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
    "    return elapsed_mins, elapsed_secs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, we train our model, saving the parameters that give us the best validation loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 01 | Time: 0m 28s\n",
      "\tTrain Loss: 5.072 | Train PPL: 159.561\n",
      "\t Val. Loss: 5.065 |  Val. PPL: 158.309\n",
      "Epoch: 02 | Time: 0m 27s\n",
      "\tTrain Loss: 4.446 | Train PPL:  85.290\n",
      "\t Val. Loss: 5.474 |  Val. PPL: 238.400\n",
      "Epoch: 03 | Time: 0m 27s\n",
      "\tTrain Loss: 4.148 | Train PPL:  63.299\n",
      "\t Val. Loss: 4.868 |  Val. PPL: 130.008\n",
      "Epoch: 04 | Time: 0m 27s\n",
      "\tTrain Loss: 3.830 | Train PPL:  46.041\n",
      "\t Val. Loss: 4.617 |  Val. PPL: 101.215\n",
      "Epoch: 05 | Time: 0m 27s\n",
      "\tTrain Loss: 3.535 | Train PPL:  34.283\n",
      "\t Val. Loss: 4.411 |  Val. PPL:  82.338\n",
      "Epoch: 06 | Time: 0m 27s\n",
      "\tTrain Loss: 3.272 | Train PPL:  26.362\n",
      "\t Val. Loss: 4.143 |  Val. PPL:  62.964\n",
      "Epoch: 07 | Time: 0m 27s\n",
      "\tTrain Loss: 3.044 | Train PPL:  20.994\n",
      "\t Val. Loss: 3.923 |  Val. PPL:  50.562\n",
      "Epoch: 08 | Time: 0m 27s\n",
      "\tTrain Loss: 2.817 | Train PPL:  16.733\n",
      "\t Val. Loss: 3.872 |  Val. PPL:  48.041\n",
      "Epoch: 09 | Time: 0m 27s\n",
      "\tTrain Loss: 2.610 | Train PPL:  13.601\n",
      "\t Val. Loss: 3.777 |  Val. PPL:  43.706\n",
      "Epoch: 10 | Time: 0m 27s\n",
      "\tTrain Loss: 2.453 | Train PPL:  11.625\n",
      "\t Val. Loss: 3.732 |  Val. PPL:  41.772\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 10\n",
    "CLIP = 1\n",
    "\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "    \n",
    "    start_time = time.time()\n",
    "    \n",
    "    train_loss = train(model, train_iterator, optimizer, criterion, CLIP)\n",
    "    valid_loss = evaluate(model, valid_iterator, criterion)\n",
    "    \n",
    "    end_time = time.time()\n",
    "    \n",
    "    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
    "    \n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        torch.save(model.state_dict(), 'tut2-model.pt')\n",
    "    \n",
    "    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')\n",
    "    print(f'\\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')\n",
    "    print(f'\\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we test the model on the test set using these \"best\" parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Test Loss: 3.672 | Test PPL:  39.311 |\n"
     ]
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('tut2-model.pt'))\n",
    "\n",
    "test_loss = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just looking at the test loss, we get better performance. This is a pretty good sign that this model architecture is doing something right! Relieving the information compression seems like the way forard, and in the next tutorial we'll expand on this even further with *attention*."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}