pytorch -- a next generation tensor / deep learning framework.

While I do not like the idea of asking you to do an activity just to teach you a tool, I feel strongly enough about pytorch that I think you should know how to use it. I wish I had designed the course around pytorch, but it was released just around the time we started this class. This optional lab will help you translate your deep learning skills from whatever framework you are currently using for your projects to pytorch (kudos if you are already working in pytorch for your project!). There are two main sources to learn more about pytorch:

  1. Their documentation http://pytorch.org/docs/
  2. Their examples directory https://github.com/pytorch/examples

Keep these links close to you. The second link contains a text generation example, a single-label image classification example, and a generative adversarial network (GAN) example, among others.

1. Automatic differentiation in pytorch

Pytorch implements a tensor object just like keras and tensorflow; however, unlike tensorflow, these tensor objects actually contain values (they are not symbolic references), and the operations actually modify the data (they do not just define a computation graph). This makes debugging and trying things out in pytorch much easier. Additionally, tensors can be accessed and sliced using numpy-like operations, since the authors of pytorch replicated much of numpy's functionality (along with the backward passes for most of these operations). Let's implement some layers in pytorch similar to the layers we implemented earlier in the class in our deep learning lab.
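Before implementing those layers, here is a minimal sketch of this eager, numpy-like behavior (the values are random, so your output will differ):

    import torch

    a = torch.rand(3, 4)   # A 3x4 tensor filled with random values, available immediately.
    print(a)               # No session or graph compilation needed to inspect the values.
    print(a[0])            # numpy-like indexing: the first row.
    print(a[:, 1:3])       # numpy-like slicing: columns 1 and 2.
    print(a.numpy())       # Tensors convert to numpy arrays (and back with torch.from_numpy).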

In [ ]:
import torch
from torch import nn
from torch.nn.parameter import Parameter
from torch.autograd import Variable

# Let's define a linear layer.
class nn_Linear(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(nn_Linear, self).__init__()
        # Create the layer parameters.
        self.weight = Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = Parameter(torch.Tensor(1, output_dim))
        
        # Initialize the weight and bias parameters using random values.
        self.weight.data.uniform_(-0.001, 0.001)  # Parameters have .data and .grad values.
        self.bias.data.uniform_(-0.001, 0.001)    # Parameters have .data and .grad values.
        
    # y = Wx + b
    def forward(self, x):
        # Here you could inspect the values or sizes of these inputs.
        # print(self.weight.size())
        # print(x.size())
        
        # Note that this type of debugging is not usually possible in tensorflow/keras because
        # in those frameworks these operations only define a computation graph and do not operate
        # directly on values.
        
        # addmm(beta, M, alpha, A, B) computes beta * M + alpha * A * B, so this is Wx + b
        # (with a couple of transposes so that each row of x is one example in the batch).
        batch_expanded_bias = self.bias.expand(x.size(0), self.bias.size(1))
        return torch.addmm(1, batch_expanded_bias.t(), 1, self.weight, x.t()).t()
    

# Let's create an instance of nn_Linear.
linear = nn_Linear(4, 2)

# Let's define some input variable.
inputVar = Variable(torch.Tensor([[0.2, 0.3, -0.1, 0.2],
                                  [0.3, 0.1, 0.3, -0.4],
                                  [0.1, 0.2, 0.4, -0.4]]))

# Let's print some outputs of the linear layer.
outputVar = linear(inputVar)
print(outputVar.data)   # This will contain y = Wx + b
print(outputVar.grad)   # This will contain dy, the gradient of the output after backpropagation.

# This shows some of pytorch's magic: it registers parameters so you can easily traverse them.
print([param.size() for param in linear.parameters()])

Notice the following things about the above code:

  1. Parameters are special variables that make sure tensors are registered as parameters of the module and are returned when calling module.parameters(). This is useful for optimization: we can simply iterate over the parameters and perform an SGD step on each one:
    param.data.add_(-0.001 * param.grad.data)
    
  2. The above line takes us to the second observation: Variables (and Parameters) hold two values, the actual value of the variable (data) and the gradient of the variable (grad). This allows us to obtain gradients with respect to any variable we want in our models, including inputs, outputs, and parameters, since they all have to be Variables (see the short sketch right after this list).
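As a minimal sketch of this (a toy scalar function, separate from the lab code above), we can ask for the gradient of an output with respect to an input Variable:

    x = Variable(torch.Tensor([1.0, 2.0, 3.0]), requires_grad = True)
    y = (x * x).sum()   # A scalar output: y = sum of x_i squared.
    y.backward()        # Backpropagate from y.
    print(x.grad)       # dy/dx = 2x, so this should show [2, 4, 6].
    print(x.data)       # The values of x themselves are untouched.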

Below I'm showing an example where we move the parameters of the linear layer using SGD to minimize a mean squared error loss function for a single dummy batch.

In [ ]:
class nn_MSECriterion(nn.Module):  # MSE = mean squared error.
    def forward(self, predictions, labels):
        return (predictions - labels).pow(2).sum()
    
inputs = Variable(torch.Tensor([[0.2, 0.3, -0.1, 0.2],
                               [0.3, 0.1, 0.3, -0.4],
                               [0.1, 0.2, 0.4, -0.4]]))

labels = Variable(torch.Tensor([[1, 1],
                                [2, 2],
                                [3, 3]]))

# Now optimize until the loss becomes small.
linear = nn_Linear(4, 2)
linear.train()  # Makes a difference when the module has dropout or batchnorm layers, which behave differently during testing.
for iteration in range(0, 50):
    linear.zero_grad()  # Clear the gradients accumulated in the previous iteration.
    predictions = linear(inputs)  # Forward pass.
    loss = nn_MSECriterion()(predictions, labels)  # Loss function.
    loss.backward()  # This backpropagates errors all the way.
    linear.weight.data.add_(-0.0001 * linear.weight.grad.data)  # SGD step.
    linear.bias.data.add_(-0.0001 * linear.bias.grad.data)  # SGD step.
    print(iteration, loss.data[0])
    

The first thing to notice is that we do not need to write a backward function to compute gradients; as long as we implement all the operations using pytorch, we get the backward pass for free. Also, in pytorch we do not need to implement basic layers such as nn_Linear since it already provides all the basic layers (and some advanced ones) inside torch.nn (e.g. nn.Sequential, nn.Linear, nn.Conv2d, nn.ReLU, nn.Sigmoid), and inside torch.nn.functional (e.g. available as functions F.relu, F.sigmoid, etc., which is convenient when the layer does not have parameters). However, you can implement new layers by creating your own module and overriding its forward function using torch operations. A short sketch of the same toy regression using only built-in pieces follows.
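For reference, here is a minimal sketch of how the regression above could be written with the built-in nn.Linear, nn.MSELoss, and torch.optim.SGD (the learning rate and number of iterations are arbitrary choices):

    import torch.optim as optim

    model = nn.Linear(4, 2)                       # Built-in replacement for our nn_Linear.
    criterion = nn.MSELoss(size_average = False)  # Sum of squared errors, like nn_MSECriterion.
    optimizer = optim.SGD(model.parameters(), lr = 0.0001)

    for iteration in range(0, 50):
        optimizer.zero_grad()                     # Clear gradients from the previous iteration.
        loss = criterion(model(inputs), labels)   # Forward pass and loss.
        loss.backward()                           # Compute gradients.
        optimizer.step()                          # Update all registered parameters.
        print(iteration, loss.data[0])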

2. Convolutional Neural Networks in pytorch

Pytorch implements convolutional layers, and also has easy access to pretrained models (VGG, Resnet, etc) [http://pytorch.org/docs/torchvision/models.html]. So it is just as convenient as Keras or lua-torch.

In [178]:
import torchvision.models as models
alexnet = models.alexnet(pretrained = True)
print(alexnet)
AlexNet (
  (features): Sequential (
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU (inplace)
    (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU (inplace)
    (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU (inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU (inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU (inplace)
    (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1))
  )
  (classifier): Sequential (
    (0): Dropout (p = 0.5)
    (1): Linear (9216 -> 4096)
    (2): ReLU (inplace)
    (3): Dropout (p = 0.5)
    (4): Linear (4096 -> 4096)
    (5): ReLU (inplace)
    (6): Linear (4096 -> 1000)
  )
)

Now let's try the model on a sample image and show its predictions. Notice that we can directly use only the features part of this model via alexnet.features(image). This would output the activations of the last MaxPool2d layer, as opposed to the output of the last linear layer in the classifier part if we call alexnet(image). (A short sketch of this follows the next cell.)

In [274]:
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import json, string
%matplotlib inline

# 1. Define the appropriate image pre-processing function.
preprocessFn = transforms.Compose([transforms.Scale(256), 
                                   transforms.CenterCrop(224), 
                                   transforms.ToTensor(), 
                                   transforms.Normalize(mean = [0.485, 0.456, 0.406], 
                                                        std=[0.229, 0.224, 0.225])])

# 2. Load the imagenet class names.
imagenetClasses = {int(idx): entry[1] for (idx, entry) in json.load(open('imagenet_class_index.json')).items()}

# 3. Forward a test image of the toaster.
# Never forget to set the model in evaluation mode so Dropout layers don't add randomness.
alexnet.eval()
# unsqueeze(0) adds a dummy batch dimension which is required for all models in pytorch.
image = Image.open('test_image.jpg').convert('RGB')
inputVar =  Variable(preprocessFn(image).unsqueeze(0))
predictions = alexnet(inputVar)

# 4. Decode the top 10 classes predicted for this image.
# We need to apply softmax because the model outputs the last linear layer activations and not softmax scores.
probs, indices = (-nn.Softmax()(predictions).data).sort()
probs = (-probs).numpy()[0][:10]; indices = indices.numpy()[0][:10]
preds = [imagenetClasses[idx] + ': ' + str(prob) for (prob, idx) in zip(probs, indices)]

# 5. Show image and predictions
plt.title(string.join(preds, '\n'))
plt.imshow(image);
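As mentioned above, we can also use only the convolutional part of the model. Here is a small sketch reusing the inputVar built in the previous cell (the expected output size assumes the standard 224x224 input):

    features = alexnet.features(inputVar)
    print(features.size())  # Expected: (1, 256, 6, 6), the activations of the last MaxPool2d layer,
                            # i.e. the 9216 values that feed the first Linear layer of the classifier.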

3. Recurrent Neural Networks in pytorch

Pytorch implements recurrent neural networks, and unlike the current Keras/Tensorflow, there is no need to specify the length of the sequence: if you review the documentation of the RNN class in pytorch, its arguments only concern the sizes of the input and of the hidden state. However, in order to batch sequences where each element might have a different length, we still need to pad each batch, although different batches can have different lengths.
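As a quick sketch of the nn.LSTM interface (the sizes here are arbitrary), note that the sequence length only shows up in the input tensor, not in the module definition:

    rnn = nn.LSTM(10, 20, batch_first = True)        # Input size 10, hidden state size 20.
    inputSequence = Variable(torch.randn(3, 5, 10))  # A batch of 3 sequences, each of length 5.
    outputSequence, (lastHidden, lastCellState) = rnn(inputSequence)
    print(outputSequence.size())  # (3, 5, 20): a hidden state for every time step.
    print(lastHidden.size())      # (1, 3, 20): only the hidden state at the last time step.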

In [290]:
# 1. Let's read this file with a bunch of text from Edgar Allan Poe's work.
# This file was compiled by Evan Otero in this repository: https://github.com/evanotero/edgar-alan-turing
text = open('poe.txt').read().lower()
print('corpus length:', len(text))
# Let's show a chunk of text.
print(text[11100:11500])

chars = sorted(list(set(text)))
print('total chars:', len(chars))
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))
('corpus length:', 2715838)
he school. i assign, from my recollection, this place to
   howard. poe, as i recall my impressions now, was self-willed,
   capricious, inclined to be imperious, and, though of generous
   impulses, not steadily kind, nor even amiable; and so what he would
   exact was refused to him. i add another thing which had its influence,
   i am sure. at the time of which i speak, richmond was one of
('total chars:', 83)

Now we will build a text generation model like the one presented in the Keras examples: https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py

In [344]:
import torch.nn.functional as F

class CharGenerator(nn.Module):
    def __init__(self):
        super(CharGenerator, self).__init__()
        self.nChars = len(chars) # One-hot encodings of characters.
        self.hiddenDim = 256
        
        self.rnn = nn.LSTM(self.nChars, self.hiddenDim, batch_first = True)
        self.decoder = nn.Linear(self.hiddenDim, self.nChars)
        
    def forward(self, charSequence):
        # LSTM signature is: allHiddenStates, (lastHiddenState, lastCellState) = rnn(inputSequence)
        # We only use the hidden state of the last time step.
        outputSequence, (lastHidden, lastCellState) = self.rnn(charSequence)
        # Then we need to pass that through another linear layer, and a softmax.
        # The view call below collapses the leading (num_layers, batch) dimensions of lastHidden
        # into a single batch dimension, since here we have one layer and a batch of one sequence.
        return F.softmax(self.decoder(lastHidden.view(1, self.hiddenDim)))
    
    
# Testing the above model.
charSequence = 'we do not like workin until the last minut'
outputChar = 'e'

# One-hot encoding of the character sequence.
oneHotCharSequence = torch.zeros(1, len(charSequence), len(chars))
for (i, x) in enumerate(charSequence): oneHotCharSequence[0, i, char_indices[x]] = 1
    
# Pass the sequence through the model.
model = CharGenerator()
predictedCharacterProbs = model(Variable(oneHotCharSequence))
predictedCharacterProb, predictedCharacterId = predictedCharacterProbs.squeeze().data.max(0)

print(predictedCharacterProbs.size())
print(predictedCharacterProb[0])
print(indices_char[predictedCharacterId[0]])
torch.Size([1, 83])
0.013111914508
!
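To train this model as in the Keras example mentioned above, one would cut the corpus into fixed-length windows and predict the character that follows each window. Here is a hedged sketch of that data preparation step (maxlen = 40 and step = 3 follow the Keras example; the encoding mirrors the one-hot code above, and batches are encoded on the fly since one-hot encoding the full corpus would not fit comfortably in memory):

    maxlen, step = 40, 3
    sentences, next_chars = [], []
    for i in range(0, len(text) - maxlen, step):
        sentences.append(text[i: i + maxlen])
        next_chars.append(text[i + maxlen])
    print(len(sentences))  # Number of training windows.

    # One-hot encode a small batch of windows and collect their target character ids.
    def encode_batch(batch_sentences, batch_next_chars):
        X = torch.zeros(len(batch_sentences), maxlen, len(chars))
        y = torch.LongTensor(len(batch_sentences))
        for s, sentence in enumerate(batch_sentences):
            for t, char in enumerate(sentence):
                X[s, t, char_indices[char]] = 1
            y[s] = char_indices[batch_next_chars[s]]
        return Variable(X), Variable(y)

    inputBatch, targetBatch = encode_batch(sentences[:32], next_chars[:32])
    print(inputBatch.size(), targetBatch.size())  # Expected: (32, 40, 83) and (32,).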

Lab Questions

This assignment adds points back to the previous assignment where you lost the most points (it only restores points on that one assignment, and the extra points do not spill over to other assignments). For instance, if you want to make up points on a 5pt assignment, then you only need to do 1) or 2) but not both, since the extra 5pts will not spill over to other assignments. If you want to make up points on a 10pt assignment you would need to do both 1) and 2). I expect that hardly anybody will really need to complete this lab, but I encourage you to complete it if possible. However, your project and final project report are the real priority.

  1. Train a Convolutional Neural Network to predict the 80 categories of MS-COCO as in the visual recognition lab. Notice that torchvision.datasets includes a data loader for MS-COCO, so try to reuse it, but you will still need to modify it so that the label space is an 80-dimensional vector. You should also use AlexNet or some other pretrained network from torchvision.models as your starting source of features. a) Show some sample predictions demonstrating that the model learned successfully. b) Compute how many times in the validation set the top predicted label is correct and report the number here. (5 pts) [Note: You might need to do things on the GPU for this task; find out from pytorch's documentation how to do this. You need to move your variables and models to the GPU using the .cuda() method; see the short sketch after these questions.]

  2. Train the CharGenerator model presented in this lab using the poe text. It should be trained as in https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py where the next character is predicted based on the previous 40 characters. Feel free to modify the model presented in this lab if needed. a) Report a loss history plot, and b) 5 samples of poe text using your trained model. (5 pts)
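For the GPU note in question 1, here is a minimal sketch of moving a model and its inputs to the GPU (inputTensor is a hypothetical preprocessed image batch):

    model = models.alexnet(pretrained = True).cuda()  # Move the model parameters to the GPU.
    inputVar = Variable(inputTensor.cuda())           # Move the input data to the GPU as well.
    outputVar = model(inputVar)                       # The forward pass now runs on the GPU.
    predictions = outputVar.cpu().data                # Bring results back to the CPU when needed.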

If you find any errors or omissions in this material please contact me at vicente@virginia.edu