2️⃣ Einops, Einsum & Tensor Manipulation
Learning Objectives
- Understand the basics of Einstein summation convention
- Learn how to use einops to perform basic tensor rearrangement, and einsum to perform standard linear algebra operations on tensors
Note - this section contains a large number of exercises. You should feel free to skim through them if you feel comfortable with the basic ideas.
Reading
- Read about the benefits of the einops library here.
- If you haven't already, then review the Einops basics tutorial (up to the "fancy examples" section).
- Read einsum is all you need (or watch it) for a brief overview of the einsum function and how it works. (You don't need to read past section 2.10.)
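Here is a quick preview of the convention covered in the einsum reading. In an einsum subscript string, any index that appears in the inputs but not in the output is summed over, and the output indices set the axis order of the result. The sketch below uses NumPy's np.einsum, which takes the same "ij,jk->ik"-style strings as torch.einsum. The arrays are made up for illustration.

```python
import numpy as np

# Toy inputs, made up for illustration
A = np.arange(6).reshape(2, 3)
B = np.arange(12).reshape(3, 4)
v = np.arange(3)

matmul = np.einsum("ij,jk->ik", A, B)  # j is summed: ordinary matrix product, shape (2, 4)
row_sums = np.einsum("ij->i", A)       # j is summed: sum of each row, shape (2,)
transpose = np.einsum("ij->ji", A)     # nothing summed: axes reordered, shape (3, 2)
outer = np.einsum("i,j->ij", v, v)     # nothing summed: outer product, shape (3, 3)
trace = np.einsum("ii->", outer)       # repeated index within one input: sum of the diagonal
```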
Setup
import math
import os
import sys
from pathlib import Path
import einops
import numpy as np
import torch as t
from torch import Tensor
# Make sure exercises are in the path
chapter = "chapter0_fundamentals"
section = "part0_prereqs"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))
import part0_prereqs.tests as tests
from part0_prereqs.utils import display_array_as_img, display_soln_array_as_img
MAIN = __name__ == "__main__"
Einops
arr = np.load(section_dir / "numbers.npy")
arr is a 4D numpy array. The first axis corresponds to the number (i.e. which digit image), and the next three axes are channels (i.e. RGB), height and width respectively. You have the function utils.display_array_as_img, which takes in a numpy array and displays it as an image. It accepts two kinds of input:
- If the input is three-dimensional, the dimensions are interpreted as (channel, height, width), in other words as an RGB image.
- If the input is two-dimensional, the dimensions are interpreted as (height, width), i.e. a monochrome image.
For example:
print(arr[0].shape)
display_array_as_img(arr[0]) # plotting the first image in the batch
(3, 150, 150)
print(arr[0, 0].shape)
display_array_as_img(arr[0, 0]) # plotting the first channel of the first image, as monochrome
(150, 150)
arr_stacked = einops.rearrange(arr, "b c h w -> c h (b w)")
print(arr_stacked.shape)
display_array_as_img(arr_stacked) # plotting all images, stacked in a row
(3, 150, 900)
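The parenthesized group (b w) is equivalent to moving b next to w and then merging the two axes, with b as the outer (slower-varying) one. Here is a NumPy sketch using transpose and reshape. A small random array stands in for arr so that it runs quickly.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((6, 3, 4, 5))  # stand-in for arr: (b, c, h, w), smaller for speed

# "b c h w -> c h (b w)": move b next to w, then merge (b w) with b as the outer axis
stacked = x.transpose(1, 2, 0, 3).reshape(3, 4, 6 * 5)

# Equivalent to laying the images side by side along the width axis
side_by_side = np.concatenate([x[i] for i in range(6)], axis=2)
```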
A series of images follow below, which have been created using einops functions performed on arr. You should work through these and try to produce each of the images yourself. This page also includes solutions, but you should only look at them after you've tried for at least five minutes.
Note - if you find you're comfortable with the first ~half of these, you can skip to later sections if you'd prefer, since these aren't particularly conceptually important.
Exercises - einops operations (match images)
```yaml
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to ~45 minutes on these exercises collectively. If you think you get the general idea, then you can skip to the next section. You shouldn't spend longer than ~10 mins per exercise.
```
(1) Column-stacking
# Your code here - define arr1
display_array_as_img(arr1)
Solution
arr1 = einops.rearrange(arr, "b c h w -> c (b h) w")
(2) Column-stacking and copying
In this example we take just the first digit, and copy it vertically (stacking two copies on top of each other) using einops.repeat.
# Your code here - define arr2
display_array_as_img(arr2)
Solution
arr2 = einops.repeat(arr[0], "c h w -> c (2 h) w")
(3) Row-stacking and double-copying
This example is pretty similar to the previous one, except that the part of the original image we need to slice and pass into einops.repeat also has a batch dimension of 2 (since it includes the first 2 digits).
# Your code here - define arr3
display_array_as_img(arr3)
Solution
arr3 = einops.repeat(arr[0:2], "b c h w -> c (b h) (2 w)")
(4) Stretching
The image below was stretched vertically by a factor of 2.
# Your code here - define arr4
display_array_as_img(arr4)
Solution
arr4 = einops.repeat(arr[0], "c h w -> c (h 2) w")
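The difference between "(2 h)" in arr2 and "(h 2)" here comes down to ordering: within a parenthesized group, the leftmost name is the outer (slowest-varying) axis. Here is a small NumPy sketch of both patterns on a toy array. It uses np.tile and np.repeat as stand-ins for einops.repeat, and also rebuilds each result by inserting a new axis, broadcasting it to size 2, and merging.

```python
import numpy as np

img = np.arange(3 * 2 * 2).reshape(3, 2, 2)  # tiny (c, h, w) stand-in for arr[0]

# "c h w -> c (2 h) w": the new size-2 axis is the OUTER part of (2 h),
# so the whole image is copied and the copies are stacked vertically
tiled = np.tile(img, (1, 2, 1))

# "c h w -> c (h 2) w": the new axis is the INNER part of (h 2),
# so each row is copied in place, stretching the image vertically
stretched = np.repeat(img, 2, axis=1)

# Same results via "insert a new axis, broadcast it to size 2, then merge it with h"
tiled_again = np.broadcast_to(img[:, None, :, :], (3, 2, 2, 2)).reshape(3, 4, 2)
stretched_again = np.broadcast_to(img[:, :, None, :], (3, 2, 2, 2)).reshape(3, 4, 2)
```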
(5) Split channels
The image below was created by splitting out the 3 channels of the image (i.e. red, green, blue) and turning these into 3 stacked horizontal images. The output is 2D (the display function interprets this as a monochrome image).
# Your code here - define arr5
display_array_as_img(arr5)
Solution
arr5 = einops.rearrange(arr[0], "c h w -> h (c w)")
(6) Stack into rows & cols
This requires a rearrange operation with dimensions for row and column stacking.
# Your code here - define arr6
display_array_as_img(arr6)
Solution
arr6 = einops.rearrange(arr, "(b1 b2) c h w -> c (b1 h) (b2 w)", b1=2)
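Splitting "(b1 b2)" works in the opposite direction from merging. The batch axis is reshaped into (b1, b2) with b1 outer, so image i = b1_idx * b2 + b2_idx, and the axes are then reordered and merged. Here is a NumPy sketch on a random stand-in for arr, checked against a grid assembled explicitly with np.block.

```python
import numpy as np

b1, b2, c, h, w = 2, 3, 3, 4, 5
x = np.random.default_rng(0).random((b1 * b2, c, h, w))  # stand-in for arr

# "(b1 b2) c h w -> c (b1 h) (b2 w)", b1=2:
# 1. split the batch axis into (b1, b2), with b1 as the outer axis
# 2. reorder to (c, b1, h, b2, w)
# 3. merge (b1 h) and (b2 w)
grid = x.reshape(b1, b2, c, h, w).transpose(2, 0, 3, 1, 4).reshape(c, b1 * h, b2 * w)

# Same grid built explicitly: image i * b2 + j goes in row i, column j of the grid
expected = np.block([[x[i * b2 + j] for j in range(b2)] for i in range(b1)])
```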
(7) Transpose
Here, we've just flipped the image's horizontal and vertical dimensions. Transposing is a fairly common tensor operation.
# Your code here - define arr7
display_array_as_img(arr7)
Solution
arr7 = einops.rearrange(arr[1], "c h w -> c w h")
(8) Shrinking
Hint - for this one, you should use max pooling - i.e. each pixel value in the output is the maximum of the corresponding 2x2 square in the original image.
# Your code here - define arr8
display_array_as_img(arr8)
Solution
arr8 = einops.reduce(arr, "(b1 b2) c (h h2) (w w2) -> c (b1 h) (b2 w)", "max", h2=2, w2=2, b1=2)
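Without the batch grid, the reduction in arr8 is 2x2 max pooling. Splitting h into (h h2) with h2=2 and reducing over h2 takes the max over each pair of rows, and w works the same way. Here is a NumPy sketch on a single made-up (c, h, w) array, checked against an explicit loop over windows.

```python
import numpy as np

x = np.random.default_rng(0).random((3, 6, 8))  # (c, h, w), with h and w divisible by 2
c, H, W = x.shape

# "c (h h2) (w w2) -> c h w" with "max", h2=2, w2=2: split each spatial axis into
# (outer, inner=2) and take the max over both inner axes, i.e. 2x2 max pooling
pooled = x.reshape(c, H // 2, 2, W // 2, 2).max(axis=(2, 4))

# Reference: explicit loop over each 2x2 window
loop = np.empty((c, H // 2, W // 2))
for i in range(H // 2):
    for j in range(W // 2):
        loop[:, i, j] = x[:, 2 * i : 2 * i + 2, 2 * j : 2 * j + 2].max(axis=(1, 2))
```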
Broadcasting
Before we go through the next exercises, we'll need to address one important topic in tensor operations - broadcasting.
Both NumPy and PyTorch have the same rules for broadcasting. When two tensors are involved in an elementwise operation, NumPy/PyTorch tries to broadcast them (i.e. copying them along dimensions) so that they both have the same shape. The rules of broadcasting are as follows:
- You can prepend dummy dimensions (of size 1) to the start of a tensor until both have the same number of dimensions
- After this point, if some dimension has size 1 in one of the tensors, it can be repeated until it matches the size of the corresponding dimension in the other tensor
To give a simple example - suppose we have a 2D batch of data, of shape data.shape = (N, k) (i.e. we have N separate datapoints, each being a vector of length k). Suppose we want to add a vector vec of length k to each datapoint. This is a valid operation, because when we try and add these two objects together:
- vec gets prepended with a dummy dimension so it has shape (1, k), and both are 2D
- vec gets repeated along the first dimension so it has shape (N, k), matching the shape of data
Then, our output has shape (N, k), and elements output[i, j] = data[i, j] + vec[j].
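The two rules are simple enough to write down directly. Below is a hypothetical helper, broadcast_shape (not part of NumPy or PyTorch), that applies them to a pair of shapes. It is followed by the (N, k) + (k,) example in NumPy, whose broadcasting rules are the same as PyTorch's.

```python
import numpy as np

def broadcast_shape(a: tuple, b: tuple) -> tuple:
    """Hypothetical helper: apply the two broadcasting rules to two shapes.

    Raises ValueError if the shapes are incompatible.
    """
    # Rule 1: prepend size-1 dims to the shorter shape until both have the same length
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    # Rule 2: a size-1 dim is repeated to match the other; otherwise sizes must be equal
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"cannot broadcast {a} with {b}")
    return tuple(out)

# The (N, k) + (k,) example: output[i, j] = data[i, j] + vec[j]
N, k = 4, 3
data = np.arange(N * k).reshape(N, k)
vec = np.array([10, 20, 30])
output = data + vec
```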
Broadcasting can be a very easy place to make mistakes, because it's easy to lose track of the exact shape of your tensors involved. As a warm-up exercise, below are some examples of broadcasting. Can you figure out which are valid, and which will raise errors?
x = t.ones((3, 1, 5))
y = t.ones((1, 4, 5))
z = x + y
Answer
This is valid, because the 0th dimension of y and the 1st dimension of x can both be copied so that x and y have the same shape: (3, 4, 5). The resulting array z will also have shape (3, 4, 5).
This example illustrates an important point - it's not always the case that one of the tensors is strictly smaller and the other is strictly bigger. Sometimes, both tensors will get expanded.
x = t.ones((8, 2, 6))
y = t.ones((8, 2))
z = x + y
Answer
This is not valid. PyTorch first expands y by prepending a dummy dimension, giving shape (1, 8, 2). The last two dimensions of x are (2, 6), which won't broadcast with y's (8, 2).
x = t.ones((8, 2, 6))
y = t.ones((2, 6))
z = x + y
Answer
This is valid. Once PyTorch expands y by prepending a dummy dimension, giving shape (1, 2, 6), it can then be broadcast with x.
x = t.ones((10, 20, 30))
y = t.ones((20, 1))
z = x + y
Answer
This is valid. Once PyTorch expands y by prepending a dummy dimension, giving shape (1, 20, 1), it can then be broadcast with x (this will involve copying along the first and last dimensions).
x = t.ones((4, 1))
y = t.ones((4,))
z = x + y
Answer
This is valid. PyTorch will expand the second one to (1, 4), then broadcast them both to (4, 4).
Although this won't raise an error, it's very possible that this isn't what the person adding these two arrays intended. A common source of mistakes is when you add 2 tensors thinking they're the same shape, but one actually has a dummy dimension you forgot about. Sadly this is something you'll just have to be vigilant for (e.g. adding assert statements where necessary, or making sure you aren't combining too many different tensor operations in a single line), because PyTorch doesn't have built-in ways of statically checking your tensor shapes.
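The (4, 1) + (4,) case is easy to reproduce. Here is a NumPy sketch (NumPy uses the same broadcasting rules as PyTorch) showing the silent expansion and the kind of shape assertion suggested above. x[:, 0] plays the role of x.squeeze(1).

```python
import numpy as np

x = np.ones((4, 1))
y = np.ones((4,))

z = x + y        # no error, but (4, 1) + (1, 4) broadcasts to (4, 4)
print(z.shape)   # (4, 4)

# A cheap guard: drop the dummy dim first if (4,) was intended, then assert the shape
z_intended = x[:, 0] + y
assert z_intended.shape == (4,), z_intended.shape
```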
Einops is a useful tool for reshaping tensors to enable broadcasting. If you just need to add or remove a dummy dimension, you don't need to use einops: tensor.unsqueeze(dim) will give you a new tensor with a dummy dimension of size 1 inserted at position dim in the new tensor, and tensor.squeeze(dim) will give you a tensor with the dimension at position dim removed (if it had size 1, otherwise nothing happens).
x = t.ones((3, 1, 5))
print(x.unsqueeze(3).shape) # (3, 1, 5, 1) because we add a new dummy dimension at idx 3 (the end) in the new tensor
print(x.squeeze(1).shape) # (3, 5) because we remove the dimension at idx 1 (it has size 1)
print(x.squeeze(0).shape) # (3, 1, 5) because we don't remove the leading dimension (it has size 3)
Exercises - einops operations & broadcasting
```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~45 minutes on these exercises collectively. These are more representative of the kinds of einops operations you'll use in practice.
```
Next, we have a series of functions which you should implement using einops. Each exercise is followed by a dropdown containing its solution.
First, let's define some functions to help us test our solutions:
def assert_all_equal(actual: Tensor, expected: Tensor) -> None:
    assert actual.shape == expected.shape, f"Shape mismatch, got: {actual.shape}"
    assert (actual == expected).all(), f"Value mismatch, got: {actual}"
    print("Tests passed!")
def assert_all_close(actual: Tensor, expected: Tensor, atol=1e-3) -> None:
    assert actual.shape == expected.shape, f"Shape mismatch, got: {actual.shape}"
    t.testing.assert_close(actual, expected, atol=atol, rtol=0.0)
    print("Tests passed!")
(A1) rearrange
We'll start with a simple rearrange operation - you're asked to return a particular tensor using only t.arange and einops.rearrange. The t.arange function is similar to the numpy equivalent: torch.arange(start, end) will return a 1D tensor containing all the values from start to end - 1 inclusive.
def rearrange_1() -> Tensor:
    """Return the following tensor using only t.arange and einops.rearrange:
    [[3, 4],
     [5, 6],
     [7, 8]]
    """
    raise NotImplementedError()
expected = t.tensor([[3, 4], [5, 6], [7, 8]])
assert_all_equal(rearrange_1(), expected)
Solution
def rearrange_1() -> Tensor:
    """Return the following tensor using only t.arange and einops.rearrange:
    [[3, 4],
     [5, 6],
     [7, 8]]
    """
    return einops.rearrange(t.arange(3, 9), "(h w) -> h w", h=3, w=2)
(A2) rearrange
This exercise has the same pattern as the previous one.
def rearrange_2() -> Tensor:
    """Return the following tensor using only t.arange and einops.rearrange:
    [[1, 2, 3],
     [4, 5, 6]]
    """
    raise NotImplementedError()
assert_all_equal(rearrange_2(), t.tensor([[1, 2, 3], [4, 5, 6]]))
Solution
def rearrange_2() -> Tensor:
    """Return the following tensor using only t.arange and einops.rearrange:
    [[1, 2, 3],
     [4, 5, 6]]
    """
    return einops.rearrange(t.arange(1, 7), "(h w) -> h w", h=2, w=3)
(B1) temperature average
Here you're given a 1D tensor containing temperatures for each day. You should return a 1D tensor containing the average temperature for each week.
This could be done in 2 separate operations (a reshape from 1D to 2D followed by taking the mean over one of the axes); however, we encourage you to try and find a solution with einops.reduce in just a single line.
def temperatures_average(temps: Tensor) -> Tensor:
    """Return the average temperature for each week.
    temps: a 1D tensor containing temperatures for each day.
    Length will be a multiple of 7 and the first 7 days are for the first week, second 7 days for the second week, etc.
    You can do this with a single call to reduce.
    """
    assert len(temps) % 7 == 0
    raise NotImplementedError()
temps = t.tensor([71, 72, 70, 75, 71, 72, 70, 75, 80, 85, 80, 78, 72, 83]).float()
expected = [71.571, 79.0]
assert_all_close(temperatures_average(temps), t.tensor(expected))
Solution
def temperatures_average(temps: Tensor) -> Tensor:
    """Return the average temperature for each week.
    temps: a 1D tensor containing temperatures for each day.
    Length will be a multiple of 7 and the first 7 days are for the first week, second 7 days for the second week, etc.
    You can do this with a single call to reduce.
    """
    assert len(temps) % 7 == 0
    return einops.reduce(temps, "(h 7) -> h", "mean")
(B2) temperature difference
Here, we're asking you to subtract each week's average temperature from that week's daily temperatures. You'll have to be careful of broadcasting here, since your temperatures tensor has shape (14,) while the weekly averages computed above have shape (2,), and these are not broadcastable.
def temperatures_differences(temps: Tensor) -> Tensor:
    """For each day, subtract the average for the week the day belongs to.
    temps: as above
    """
    assert len(temps) % 7 == 0
    raise NotImplementedError()
expected = [-0.571, 0.429, -1.571, 3.429, -0.571, 0.429, -1.571, -4.0, 1.0, 6.0, 1.0, -1.0, -7.0, 4.0]
actual = temperatures_differences(temps)
assert_all_close(actual, t.tensor(expected))
Hint - how to make the tensors broadcastable
Your averages tensor from earlier was (hopefully) computed using einops.reduce(temps, "(h 7) -> h", "mean"), giving it shape (2,). You have 2 options:
- Repeat this averages tensor back to shape (14,) so you can subtract it from temps
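- Rearrange the temps tensor to have shape (7, 2), i.e. (day, week), which makes it broadcastable with (2,), then flatten it back to (14,) after subtracting the average. (A plain temps.reshape(7, 2) won't do this, since it groups consecutive pairs of days rather than weeks.)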
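Here are both options sketched in NumPy on the test data above. reshape, .T and np.repeat stand in for the einops patterns named in the comments.

```python
import numpy as np

temps = np.array([71, 72, 70, 75, 71, 72, 70, 75, 80, 85, 80, 78, 72, 83], dtype=float)
avg = temps.reshape(-1, 7).mean(axis=1)  # (2,): like einops "(w 7) -> w" with "mean"

# Option 1: repeat each weekly average 7 times -> (14,), like einops "w -> (w 7)"
diff1 = temps - np.repeat(avg, 7)

# Option 2: arrange temps as (day, week) = (7, 2), like einops "(w d) -> d w" with d=7
by_day = temps.reshape(-1, 7).T            # (7, 2)
diff2 = (by_day - avg).T.reshape(-1)       # (2,) broadcasts along the week axis; flatten to (14,)
```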
Solution
def temperatures_differences(temps: Tensor) -> Tensor:
    """For each day, subtract the average for the week the day belongs to.
    temps: as above
    """
    assert len(temps) % 7 == 0
    avg = einops.reduce(temps, "(w 7) -> w", "mean")
    return temps - einops.repeat(avg, "w -> (w 7)")
(B3) temperature normalized
Lastly, you need to subtract the average and divide by the standard deviation. Note that you can pass t.std into the einops.reduce function to return the std dev of the values you're reducing over.
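Before comparing against the expected values, note one detail. t.std returns the sample standard deviation (dividing by n - 1) by default, whereas NumPy's np.std divides by n unless you pass ddof=1. The expected output for this exercise matches the n - 1 version, as this NumPy sketch on the second week of the test data shows.

```python
import numpy as np

week = np.array([75, 80, 85, 80, 78, 72, 83], dtype=float)  # second week of the test data

sample_std = week.std(ddof=1)      # divides by n - 1, matching t.std's default
population_std = week.std(ddof=0)  # divides by n, NumPy's default

print((week[0] - week.mean()) / sample_std)      # about -0.894, the expected value
print((week[0] - week.mean()) / population_std)  # about -0.966
```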
def temperatures_normalized(temps: Tensor) -> Tensor:
    """For each day, subtract the weekly average and divide by the weekly standard deviation.
    temps: as above
    Pass t.std to reduce.
    """
    raise NotImplementedError()
expected = [-0.333, 0.249, -0.915, 1.995, -0.333, 0.249, -0.915, -0.894, 0.224, 1.342, 0.224, -0.224, -1.565, 0.894]
actual = temperatures_normalized(temps)
assert_all_close(actual, t.tensor(expected))
Solution
def temperatures_normalized(temps: Tensor) -> Tensor:
    """For each day, subtract the weekly average and divide by the weekly standard deviation.
    temps: as above
    Pass t.std to reduce.
    """
    avg = einops.reduce(temps, "(w 7) -> w", "mean")
    std = einops.reduce(temps, "(w 7) -> w", t.std)
    return (temps - einops.repeat(avg, "w -> (w 7)")) / einops.repeat(std, "w -> (w 7)")
(C1) normalize a matrix
Here, we're asking you to normalize the rows of a matrix so that each row has L2 norm (the square root of the sum of squared values) equal to 1. Note - L2 norm and standard deviation are not the same thing: the standard deviation first subtracts the mean and divides the sum of squares by the number of elements (or by n - 1) before taking the square root, whereas the L2 norm does neither. We recommend you try and use the torch function t.norm directly rather than einops for this task.
Two useful things we should highlight here:
- Most PyTorch functions like t.norm(tensor, ...) which operate on a single tensor are also tensor methods, i.e. they can be used as tensor.norm(...) with the same arguments.
- Most PyTorch dimensionality-reducing functions have the keepdim argument, which is False by default but will cause your output tensor to keep dummy dimensions if set to True. For example, if tensor has shape (3, 4) then tensor.sum(dim=1) has shape (3,) but tensor.sum(dim=1, keepdim=True) has shape (3, 1).
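To see why keepdim matters here, compare the two ways of computing row norms. This NumPy sketch uses the test matrix from the exercise, with np.linalg.norm's keepdims=True playing the role of keepdim=True.

```python
import numpy as np

matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)

# Shape (3,): dividing by this would broadcast along the wrong axis
# (column j would be divided by row j's norm), silently, since the matrix is square
norms_flat = np.linalg.norm(matrix, axis=1)
# Shape (3, 1): broadcasts across the columns of each row, as intended
norms = np.linalg.norm(matrix, axis=1, keepdims=True)
normalized = matrix / norms

# L2 norm = square root of the sum of squares; every row of `normalized` now has norm 1
row_norms = np.sqrt((normalized ** 2).sum(axis=1))
```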
def normalize_rows(matrix: Tensor) -> Tensor:
    """Normalize each row of the given 2D matrix.
    matrix: a 2D tensor of shape (m, n).
    Returns: a tensor of the same shape where each row is divided by its l2 norm.
    """
    raise NotImplementedError()
matrix = t.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).float()
expected = t.tensor([[0.267, 0.535, 0.802], [0.456, 0.570, 0.684], [0.503, 0.574, 0.646]])
assert_all_close(normalize_rows(matrix), expected)
Solution
def normalize_rows(matrix: Tensor) -> Tensor:
    """Normalize each row of the given 2D matrix.
    matrix: a 2D tensor of shape (m, n).
    Returns: a tensor of the same shape where each row is divided by its l2 norm.
    """
    row_norms = matrix.norm(dim=1, keepdim=True)  # shape (m, 1), so it broadcasts across each row
    return matrix / row_norms
(C2) pairwise cosine similarity
Now, you should compute a matrix of shape (m, m) where out[i, j] is the cosine similarity between the i-th and j-th rows of matrix.
The cosine similarity between two vectors is given by summing the elementwise products of the normalized vectors. We haven't covered einsum yet, but you should be able to get this working using normal elementwise multiplication and summing (or even do this in one step - can you see how?).
def cos_sim_matrix(matrix: Tensor) -> Tensor:
    """Return the cosine similarity matrix for each pair of rows of the given matrix.
    matrix: shape (m, n)
    """
    raise NotImplementedError()
matrix = t.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).float()
expected = t.tensor([[1.0, 0.975, 0.959], [0.975, 1.0, 0.998], [0.959, 0.998, 1.0]])
assert_all_close(cos_sim_matrix(matrix), expected)
Solution
def cos_sim_matrix(matrix: Tensor) -> Tensor:
    """Return the cosine similarity matrix for each pair of rows of the given matrix.
    matrix: shape (m, n)
    """
    matrix_normalized = normalize_rows(matrix)
    return matrix_normalized @ matrix_normalized.T
This solution works because entry [i, j] of matrix_normalized @ matrix_normalized.T multiplies row i of matrix_normalized elementwise with column j of matrix_normalized.T (i.e. row j of matrix_normalized) and sums the products. In other words, it computes the dot product of the two normalized rows, which is their cosine similarity.
For a more verbose solution using explicit broadcasting, you could also do this:
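def cos_sim_matrix(matrix: Tensor) -> Tensor:
    matrix_normalized = normalize_rows(matrix)
    left_matrix = matrix_normalized.unsqueeze(-1)  # shape (m, n, 1)
    right_matrix = matrix_normalized.T.unsqueeze(0)  # shape (1, n, m)
    products = left_matrix * right_matrix  # shape (m, n, m): products[i, k, j] = row_i[k] * row_j[k]
    return products.sum(dim=1)  # shape (m, m): summing over k gives the dot products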
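Since the output is a sum of products over the shared n axis, it can also be written as a single einsum call. Here is a NumPy sketch with np.einsum on the test matrix from the exercise.

```python
import numpy as np

matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
unit = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)  # normalize each row

# "in,jn->ij": n appears in both inputs but not the output, so it is summed,
# giving cos_sim[i, j] = sum_n unit[i, n] * unit[j, n], the dot product of rows i and j
cos_sim = np.einsum("in,jn->ij", unit, unit)
```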
(D) sample distribution
Here we're having you do something a bit more practical and less artificial. You're given a probability distribution (i.e. a tensor of probabilities that sum to 1) and asked to sample from it.
Hint - you can use the torch functions t.rand and t.cumsum to do this without any explicit loops.
def sample_distribution(probs: Tensor, n: int) -> Tensor:
    """Return n random samples from probs, where probs is a normalized probability distribution.
    probs: shape (k,) where probs[i] is the probability of event i occurring.
    n: number of random samples
    Return: shape (n,) where out[i] is an integer indicating which event was sampled.
    Use t.rand and t.cumsum to do this without any explicit loops.
    """
    raise NotImplementedError()
n = 5_000_000
probs = t.tensor([0.05, 0.1, 0.1, 0.2, 0.15, 0.4])
freqs = t.bincount(sample_distribution(probs, n)) / n
assert_all_close(freqs, probs)
Help - I'm not sure how to use t.rand and t.cumsum in this way.
Suppose our probabilities were probs = [0.2, 0.3, 0.5]. Then the cumsum is probs_cumsum = [0.2, 0.5, 1.0]. If we generate a random value v uniformly between 0 and 1 and then sum the boolean tensor v > probs_cumsum (i.e. count how many cumsum entries v exceeds), the result will be 0 with probability 0.2, 1 with probability 0.3, and 2 with probability 0.5, which is exactly what we want.
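Here is a NumPy sketch of that idea for probs = [0.2, 0.3, 0.5]. It uses a fixed seed and NumPy's random generator in place of t.rand. The np.minimum clamp is an addition not present in the solution: it guards against floating-point rounding leaving the last cumsum entry just under 1, which would otherwise occasionally produce an out-of-range index.

```python
import numpy as np

rng = np.random.default_rng(0)
probs = np.array([0.2, 0.3, 0.5])
cumsum = np.cumsum(probs)  # [0.2, 0.5, 1.0]

n = 200_000
v = rng.random((n, 1))  # uniform samples in [0, 1), shape (n, 1)
# (n, 1) > (3,) broadcasts to (n, 3); summing counts how many cumsum entries v exceeds
samples = (v > cumsum).sum(axis=-1)
# Guard (not in the original solution): rounding can leave cumsum[-1] slightly below 1
samples = np.minimum(samples, len(probs) - 1)

freqs = np.bincount(samples, minlength=len(probs)) / n
```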
Solution
def sample_distribution(probs: Tensor, n: int) -> Tensor:
    """Return n random samples from probs, where probs is a normalized probability distribution.
    probs: shape (k,) where probs[i] is the probability of event i occurring.
    n: number of random samples
    Return: shape (n,) where out[i] is an integer indicating which event was sampled.
    Use t.rand and t.cumsum to do this without any explicit loops.
    """
    assert abs(probs.sum() - 1.0) < 0.001
    assert (probs >= 0).all()
    return (t.rand(n, 1) > t.cumsum(probs, dim=0)).sum(dim=-1)
(E) classifier accuracy
Here, we're asking you to compute the accuracy of a classifier. scores is a tensor of shape (batch, n_classes) where scores[b, i] is the score the classifier gave to class i for input b, and true_classes is a tensor of shape (batch,) where true_classes[b] is the true class for input b. We want you to return the fraction of inputs for which the class with the highest score is the true class.
You can use the torch function t.argmax, which works as follows: tensor.argmax(dim) returns a tensor containing the index of the maximum value along dimension dim (i.e. the output has the same shape as tensor, but with dimension dim removed).
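Here is a NumPy sketch on the example scores used in the test. It shows argmax removing the reduced dimension, and the resulting boolean comparison being turned into an accuracy.

```python
import numpy as np

scores = np.array([[0.75, 0.5, 0.25], [0.1, 0.5, 0.4], [0.1, 0.7, 0.2]])
true_classes = np.array([0, 1, 0])

preds = scores.argmax(axis=1)             # (3, 3) -> (3,): axis 1 is removed
correct = preds == true_classes           # boolean, shape (3,)
accuracy = correct.astype(float).mean()   # fraction of True entries
```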
def classifier_accuracy(scores: Tensor, true_classes: Tensor) -> Tensor:
    """Return the fraction of inputs for which the maximum score corresponds to the true class for that input.
    scores: shape (batch, n_classes). A higher score[b, i] means that the classifier thinks class i is more likely.
    true_classes: shape (batch, ). true_classes[b] is an integer from [0...n_classes).
    Use t.argmax.
    """
    raise NotImplementedError()
scores = t.tensor([[0.75, 0.5, 0.25], [0.1, 0.5, 0.4], [0.1, 0.7, 0.2]])
true_classes = t.tensor([0, 1, 0])
expected = 2.0 / 3.0
assert classifier_accuracy(scores, true_classes) == expected
print("Tests passed!")
Solution
def classifier_accuracy(scores: Tensor, true_classes: Tensor) -> Tensor:
    """Return the fraction of inputs for which the maximum score corresponds to the true class for that input.
    scores: shape (batch, n_classes). A higher score[b, i] means that the classifier thinks class i is more likely.
    true_classes: shape (batch, ). true_classes[b] is an integer from [0...n_classes).
    Use t.argmax.
    """
    assert true_classes.max() < scores.shape[1]
    return (scores.argmax(dim=1) == true_classes).float().mean()
(F1) total price indexing
The next few exercises involve indexing, often using the torch.gather function. You can read about it in the docs here.
If you find gather confusing, an alternative is the eindex library, which was designed to provide indexing features motivated by how einops works. You can read more about that here, and as a suggested bonus you can try implementing / rewriting the next few functions using eindex.
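Along dim 1, gather follows the rule out[i][j] = matrix[i][indexes[i][j]], as stated in the docstring of the gather_2d exercise. NumPy's np.take_along_axis computes the same thing when indexes and matrix have the same number of rows. This sketch compares it with an explicit loop on the gather_2d test data.

```python
import numpy as np

matrix = np.arange(15).reshape(3, 5)
indexes = np.array([[2, 4], [1, 3], [0, 2]])

# Gather along axis 1: out[i][j] = matrix[i][indexes[i][j]]
gathered = np.take_along_axis(matrix, indexes, axis=1)

# The same rule written as an explicit double loop
looped = np.array([[matrix[i, indexes[i, j]] for j in range(indexes.shape[1])]
                   for i in range(indexes.shape[0])])
```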
def total_price_indexing(prices: Tensor, items: Tensor) -> float:
    """Given prices for each kind of item and a tensor of items purchased, return the total price.
    prices: shape (k, ). prices[i] is the price of the ith item.
    items: shape (n, ). A 1D tensor where each value is an item index from [0..k).
    Use integer array indexing. The below document describes this for NumPy but it's the same in PyTorch:
    https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
    """
    raise NotImplementedError()
prices = t.tensor([0.5, 1, 1.5, 2, 2.5])
items = t.tensor([0, 0, 1, 1, 4, 3, 2])
assert total_price_indexing(prices, items) == 9.0
print("Tests passed!")
Solution
def total_price_indexing(prices: Tensor, items: Tensor) -> float:
    """Given prices for each kind of item and a tensor of items purchased, return the total price.
    prices: shape (k, ). prices[i] is the price of the ith item.
    items: shape (n, ). A 1D tensor where each value is an item index from [0..k).
    Use integer array indexing. The below document describes this for NumPy but it's the same in PyTorch:
    https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
    """
    assert items.max() < prices.shape[0]
    return prices[items].sum().item()
(F2) gather 2D
def gather_2d(matrix: Tensor, indexes: Tensor) -> Tensor:
    """Perform a gather operation along the second dimension.
    matrix: shape (m, n)
    indexes: shape (m, k)
    Return: shape (m, k). out[i][j] = matrix[i][indexes[i][j]]
    For this problem, the test already passes and it's your job to write at least three asserts relating the arguments and the output. This is a tricky function and worth spending some time to wrap your head around its behavior.
    See: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
    """
    # YOUR CODE HERE - add assert statement(s) here for `indexes` and `matrix`
    out = matrix.gather(1, indexes)
    # YOUR CODE HERE - add assert statement(s) here for `out`
    return out
matrix = t.arange(15).view(3, 5)
indexes = t.tensor([[4], [3], [2]])
expected = t.tensor([[4], [8], [12]])
assert_all_equal(gather_2d(matrix, indexes), expected)
indexes2 = t.tensor([[2, 4], [1, 3], [0, 2]])
expected2 = t.tensor([[2, 4], [6, 8], [10, 12]])
assert_all_equal(gather_2d(matrix, indexes2), expected2)
Solution
def gather_2d(matrix: Tensor, indexes: Tensor) -> Tensor:
    """Perform a gather operation along the second dimension.
    matrix: shape (m, n)
    indexes: shape (m, k)
    Return: shape (m, k). out[i][j] = matrix[i][indexes[i][j]]
    For this problem, the test already passes and it's your job to write at least three asserts relating the arguments and the output. This is a tricky function and worth spending some time to wrap your head around its behavior.
    See: https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=gather#torch.gather
    """
    assert matrix.ndim == indexes.ndim  # gather needs index and input to have the same number of dims
    assert indexes.shape[0] <= matrix.shape[0]  # index can't be larger than input along non-gathered dims
    out = matrix.gather(1, indexes)
    assert out.shape == indexes.shape  # the output always has the same shape as the index
    return out
(F3) total price gather
def total_price_gather(prices: Tensor, items: Tensor) -> float:
    """Compute the same as total_price_indexing, but use torch.gather."""
    assert items.max() < prices.shape[0]
    raise NotImplementedError()
prices = t.tensor([0.5, 1, 1.5, 2, 2.5])
items = t.tensor([0, 0, 1, 1, 4, 3, 2])
assert total_price_gather(prices, items) == 9.0
print("Tests passed!")
Solution
def total_price_gather(prices: Tensor, items: Tensor) -> float:
    """Compute the same as total_price_indexing, but use torch.gather."""
    assert items.max() < prices.shape[0]
    return prices.gather(0, items).sum().item()
(G) indexing
def integer_array_indexing(matrix: Tensor, coords: Tensor) -> Tensor:
    """Return the values at each coordinate using integer array indexing.
    For details on integer array indexing, see:
    https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
    matrix: shape (d_0, d_1, ..., d_{n-1})
    coords: shape (batch, n)
    Return: (batch, )
    """
    raise NotImplementedError()
mat_2d = t.arange(15).view(3, 5)
coords_2d = t.tensor([[0, 1], [0, 4], [1, 4]])
actual = integer_array_indexing(mat_2d, coords_2d)
assert_all_equal(actual, t.tensor([1, 4, 9]))
mat_3d = t.arange(2 * 3 * 4).view((2, 3, 4))
coords_3d = t.tensor([[0, 0, 0], [0, 1, 1], [0, 2, 2], [1, 0, 3], [1, 2, 0]])
actual = integer_array_indexing(mat_3d, coords_3d)
assert_all_equal(actual, t.tensor([0, 5, 10, 15, 20]))
Solution
def integer_array_indexing(matrix: Tensor, coords: Tensor) -> Tensor:
    """Return the values at each coordinate using integer array indexing.
    For details on integer array indexing, see:
    https://numpy.org/doc/stable/user/basics.indexing.html#integer-array-indexing
    matrix: shape (d_0, d_1, ..., d_{n-1})
    coords: shape (batch, n)
    Return: (batch, )
    """
    return matrix[tuple(coords.T)]
(H1) batched logsumexp
def batched_logsumexp(matrix: Tensor) -> Tensor:
    """For each row of the matrix, compute log(sum(exp(row))) in a numerically stable way.
    matrix: shape (batch, n)
    Return: (batch, ). For each i, out[i] = log(sum(exp(matrix[i]))).
    Do this without using PyTorch's logsumexp function.
    A couple useful blogs about this function:
    - https://leimao.github.io/blog/LogSumExp/
    - https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
    """
    raise NotImplementedError()
matrix = t.tensor([[-1000, -1000, -1000, -1000], [1000, 1000, 1000, 1000]])
expected = t.tensor([-1000 + math.log(4), 1000 + math.log(4)])
actual = batched_logsumexp(matrix)
assert_all_close(actual, expected)
matrix2 = t.randn((10, 20))
expected2 = t.logsumexp(matrix2, dim=-1)
actual2 = batched_logsumexp(matrix2)
assert_all_close(actual2, expected2)
Solution
def batched_logsumexp(matrix: Tensor) -> Tensor:
    """For each row of the matrix, compute log(sum(exp(row))) in a numerically stable way.
    matrix: shape (batch, n)
    Return: (batch, ). For each i, out[i] = log(sum(exp(matrix[i]))).
    Do this without using PyTorch's logsumexp function.
    A couple useful blogs about this function:
    - https://leimao.github.io/blog/LogSumExp/
    - https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
    """
    C = matrix.max(dim=-1).values
    exps = t.exp(matrix - einops.rearrange(C, "n -> n 1"))
    return C + t.log(t.sum(exps, dim=-1))
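The trick rests on the identity sum(exp(x)) = exp(c) * sum(exp(x - c)), so log(sum(exp(x))) = c + log(sum(exp(x - c))) for any constant c. Choosing c = max(x) makes the largest term exp(0) = 1, so nothing overflows. Here is a NumPy sketch on one row of the test input.

```python
import numpy as np

x = np.array([1000.0, 1000.0, 1000.0, 1000.0])

with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(x)))  # exp(1000) overflows to inf, so this is inf

# Shift by c = max(x): the largest exponent becomes exp(0) = 1
c = x.max()
stable = c + np.log(np.sum(np.exp(x - c)))  # 1000 + log(4)
```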
(H2) batched softmax
def batched_softmax(matrix: Tensor) -> Tensor:
    """For each row of the matrix, compute softmax(row).
    Do this without using PyTorch's softmax function.
    Instead, use the definition of softmax: https://en.wikipedia.org/wiki/Softmax_function
    matrix: shape (batch, n)
    Return: (batch, n). For each i, out[i] should sum to 1.
    """
    raise NotImplementedError()
matrix = t.arange(1, 6).view((1, 5)).float().log()
expected = t.arange(1, 6).view((1, 5)) / 15.0
actual = batched_softmax(matrix)
assert_all_close(actual, expected)
for i in [0.12, 3.4, -5, 6.7]:
    assert_all_close(actual, batched_softmax(matrix + i))  # check it's translation-invariant
matrix2 = t.rand((10, 20))
actual2 = batched_softmax(matrix2)
assert actual2.min() >= 0.0
assert actual2.max() <= 1.0
assert_all_equal(actual2.argsort(), matrix2.argsort())
assert_all_close(actual2.sum(dim=-1), t.ones(matrix2.shape[:-1]))
Solution
def batched_softmax(matrix: Tensor) -> Tensor:
    """For each row of the matrix, compute softmax(row).
    Do this without using PyTorch's softmax function.
    Instead, use the definition of softmax: https://en.wikipedia.org/wiki/Softmax_function
    matrix: shape (batch, n)
    Return: (batch, n). For each i, out[i] should sum to 1.
    """
    # Subtracting each row's max doesn't change the softmax, but stops exp from overflowing
    exp = (matrix - matrix.max(dim=-1, keepdim=True).values).exp()
    return exp / exp.sum(dim=-1, keepdim=True)
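The max-subtraction trick also applies to softmax, because subtracting a per-row constant from the inputs doesn't change the output (the translation-invariance test above checks exactly this). Here is a NumPy sketch with two hypothetical helpers, softmax_naive and softmax_stable, that shows the naive version breaking on large inputs.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def softmax_stable(x):
    # Shifting each row by its max leaves softmax unchanged but keeps exp finite
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

x = np.log(np.arange(1, 6, dtype=float))[None, :]  # softmax of this row is [1, 2, 3, 4, 5] / 15

with np.errstate(over="ignore", invalid="ignore"):
    big_naive = softmax_naive(x + 1000)  # exp overflows to inf, and inf / inf gives nan
big_stable = softmax_stable(x + 1000)
```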
(H3) batched logsoftmax
def batched_logsoftmax(matrix: Tensor) -> Tensor:
    """Compute log(softmax(row)) for each row of the matrix.
    matrix: shape (batch, n)
    Return: (batch, n).
    Do this without using PyTorch's logsoftmax function.
    For each row, subtract the maximum first to avoid overflow if the row contains large values.
    """
    raise NotImplementedError()
matrix = t.arange(1, 7).view((2, 3)).float()
start = 1000
matrix2 = t.arange(start + 1, start + 7).view((2, 3)).float()
expected = t.tensor([[-2.4076, -1.4076, -0.4076], [-2.4076, -1.4076, -0.4076]])
assert_all_close(batched_logsoftmax(matrix), expected)  # each row is [x, x + 1, x + 2], so the results agree
actual = batched_logsoftmax(matrix2)
assert_all_close(actual, expected)
Solution
def batched_logsoftmax(matrix: Tensor) -> Tensor:
    """Compute log(softmax(row)) for each row of the matrix.
    matrix: shape (batch, n)
    Return: (batch, n).
    Do this without using PyTorch's logsoftmax function.
    For each row, subtract the maximum first to avoid overflow if the row contains large values.
    """
    C = matrix.max(dim=1, keepdim=True).values
    return matrix - C - (matrix - C).exp().sum(dim=1, keepdim=True).log()
(H4) batched cross entropy loss
def batched_cross_entropy_loss(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Compute the cross entropy loss for each example in the batch.
    logits: shape (batch, classes). logits[i][j] is the unnormalized prediction for example i and class j.
    true_labels: shape (batch, ). true_labels[i] is an integer index representing the true class for example i.
    Return: shape (batch, ). out[i] is the loss for example i.
    Hint: convert the logits to log-probabilities using your batched_logsoftmax from above.
    Then the loss for an example is just the negative of the log-probability that the model assigned to the true class. Use torch.gather to perform the indexing.
    """
    assert logits.shape[0] == true_labels.shape[0]
    assert true_labels.max() < logits.shape[1]
    raise NotImplementedError()
logits = t.tensor([[float("-inf"), float("-inf"), 0], [1 / 3, 1 / 3, 1 / 3], [float("-inf"), 0, 0]])
true_labels = t.tensor([2, 0, 0])
expected = t.tensor([0.0, math.log(3), float("inf")])
actual = batched_cross_entropy_loss(logits, true_labels)
assert_all_close(actual, expected)
Solution
def batched_cross_entropy_loss(logits: Tensor, true_labels: Tensor) -> Tensor:
    """Compute the cross entropy loss for each example in the batch.
    logits: shape (batch, classes). logits[i][j] is the unnormalized prediction for example i and class j.
    true_labels: shape (batch, ). true_labels[i] is an integer index representing the true class for example i.
    Return: shape (batch, ). out[i] is the loss for example i.
    Hint: convert the logits to log-probabilities using your batched_logsoftmax from above.
    Then the loss for an example is just the negative of the log-probability that the model assigned to the true class. Use torch.gather to perform the indexing.
    """
    assert logits.shape[0] == true_labels.shape[0]
    assert true_labels.max() < logits.shape[1]
    logprobs = batched_logsoftmax(logits)
    indices = einops.rearrange(true_labels, "n -> n 1")
    pred_at_index = logprobs.gather(1, indices)
    return -einops.rearrange(pred_at_index, "n 1 -> n")
(I1) collect rows
def collect_rows(matrix: Tensor, row_indexes: Tensor) -> Tensor:
    """Return a 2D matrix whose rows are taken from the input matrix in order according to row_indexes.
    matrix: shape (m, n)
    row_indexes: shape (k,). Each value is an integer in [0..m).
    Return: shape (k, n). out[i] is matrix[row_indexes[i]].
    """
    raise NotImplementedError()
matrix = t.arange(15).view((5, 3))
row_indexes = t.tensor([0, 2, 1, 0])
actual = collect_rows(matrix, row_indexes)
expected = t.tensor([[0, 1, 2], [6, 7, 8], [3, 4, 5], [0, 1, 2]])
assert_all_equal(actual, expected)
Solution
def collect_rows(matrix: Tensor, row_indexes: Tensor) -> Tensor:
    """Return a 2D matrix whose rows are taken from the input matrix in order according to row_indexes.
    matrix: shape (m, n)
    row_indexes: shape (k,). Each value is an integer in [0..m).
    Return: shape (k, n). out[i] is matrix[row_indexes[i]].
    """
    assert row_indexes.max() < matrix.shape[0]
    return matrix[row_indexes]
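Indexing with an integer tensor (matrix[row_indexes]) is one of several equivalent ways to select rows, and repeated indices are allowed. The sketch below checks it against t.index_select along dim 0, and against gather from the previous exercise, where the index has to be expanded to the full output shape because gather picks one element per output position.

```python
import torch as t

matrix = t.arange(15).view((5, 3))
row_indexes = t.tensor([0, 2, 1, 0])

by_indexing = matrix[row_indexes]                          # integer-array indexing
by_index_select = t.index_select(matrix, 0, row_indexes)   # same selection along dim 0
# gather along dim 0: out[i][j] = matrix[index[i][j]][j], so every column uses row_indexes[i]
by_gather = matrix.gather(0, row_indexes[:, None].expand(-1, matrix.shape[1]))

print(by_indexing.shape)  # torch.Size([4, 3])
print(t.equal(by_indexing, by_index_select), t.equal(by_indexing, by_gather))  # True True
```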
(I2) collect columns
def collect_columns(matrix: Tensor, column_indexes: Tensor) -> Tensor:
    """Return a 2D matrix whose columns are taken from the input matrix in order according to column_indexes.
    matrix: shape (m, n)
    column_indexes: shape (k,). Each value is an integer in [0..n).
    Return: shape (m, k). out[:, i] is matrix[:, column_indexes[i]].
    """
    assert column_indexes.max() < matrix.shape[1]
    raise NotImplementedError()
matrix = t.arange(15).view((5, 3))
column_indexes = t.tensor([0, 2, 1, 0])
actual = collect_columns(matrix, column_indexes)
expected = t.tensor([[0, 2, 1, 0], [3, 5, 4, 3], [6, 8, 7, 6], [9, 11, 10, 9], [12, 14, 13, 12]])
assert_all_equal(actual, expected)
Solution
def collect_columns(matrix: Tensor, column_indexes: Tensor) -> Tensor:
    """Return a 2D matrix whose columns are taken from the input matrix in order according to column_indexes.
    matrix: shape (m, n)
    column_indexes: shape (k,). Each value is an integer in [0..n).
    Return: shape (m, k). out[:, i] is matrix[:, column_indexes[i]].
    """
    assert column_indexes.max() < matrix.shape[1]
    return matrix[:, column_indexes]
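Selecting columns is the same as selecting rows of the transpose and transposing back, which connects this exercise to the previous one. A short check, again using t.index_select (this time along dim 1) as a second reference:

```python
import torch as t

matrix = t.arange(15).view((5, 3))
column_indexes = t.tensor([0, 2, 1, 0])

by_indexing = matrix[:, column_indexes]
by_index_select = t.index_select(matrix, 1, column_indexes)
# Rows of matrix.T are columns of matrix
by_transpose = matrix.T[column_indexes].T

print(by_indexing.shape)  # torch.Size([5, 4])
print(t.equal(by_indexing, by_index_select), t.equal(by_indexing, by_transpose))  # True True
```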
Einsum
Einsum is a very useful function for performing linear operations, which you'll probably be using a lot during this programme.
Note - we'll be using the einops.einsum version of the function, which works differently to the more conventional torch.einsum:
- einops.einsum has the arrays as the first arguments, and uses spaces to separate dimensions in the string.
- torch.einsum has the string as its first argument, and doesn't use spaces to separate dimensions (each dim is represented by a single character).

For instance, torch.einsum("ij,i->j", A, b) is equivalent to einops.einsum(A, b, "i j, i -> j"). (Note, einops doesn't care whether there are spaces either side of "," and "->", so you don't need to match this syntax exactly.)
Although there are many different kinds of operations you can perform, they are all derived from three key rules:
- Repeating letters in different inputs means those values will be multiplied, and those products will be in the output.
  - For example, M = einops.einsum(A, B, "i j, i j -> i j") just corresponds to the elementwise product M = A * B (because $M_{ij} = A_{ij} B_{ij}$).
- Omitting a letter means that the axis will be summed over.
  - For example, if x is a 2D array with shape (I, J), then einops.einsum(x, "i j -> i") will be a 1D array of length I containing the row sums of x (we're summing over the j-index, i.e. along each row).
- We can return the unsummed axes in any order.
  - For example, einops.einsum(x, "i j k -> k j i") does the same thing as einops.rearrange(x, "i j k -> k j i").
Note - the einops creators supposedly have plans to support shape rearrangement, e.g. with operations like einops.einsum(x, y, "i j, j k l -> i (k l)") (i.e. combining the features of rearrange and einsum), so we can all look forward to that day!
Exercises - einsum
```yaml
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 15-20 minutes on these exercises collectively. If you think you get the general idea, then you can skip to the next section.
```
In the following exercises, you'll write simple functions using einsum which replicate the functionality of standard NumPy functions: trace, matrix multiplication, inner and outer products. We've also included some test functions which you should run.
Note - this version of einsum will require that you include ->, even if you're summing to a scalar (i.e. the right hand side of your string expression is empty).
def einsum_trace(mat: np.ndarray):
    """
    Returns the same as `np.trace`.
    """
    raise NotImplementedError()


def einsum_mv(mat: np.ndarray, vec: np.ndarray):
    """
    Returns the same as `np.matmul`, when `mat` is a 2D array and `vec` is 1D.
    """
    raise NotImplementedError()


def einsum_mm(mat1: np.ndarray, mat2: np.ndarray):
    """
    Returns the same as `np.matmul`, when `mat1` and `mat2` are both 2D arrays.
    """
    raise NotImplementedError()


def einsum_inner(vec1: np.ndarray, vec2: np.ndarray):
    """
    Returns the same as `np.inner`.
    """
    raise NotImplementedError()


def einsum_outer(vec1: np.ndarray, vec2: np.ndarray):
    """
    Returns the same as `np.outer`.
    """
    raise NotImplementedError()
tests.test_einsum_trace(einsum_trace)
tests.test_einsum_mv(einsum_mv)
tests.test_einsum_mm(einsum_mm)
tests.test_einsum_inner(einsum_inner)
tests.test_einsum_outer(einsum_outer)