Introduction to NumPy and Matplotlib
Chapter 4: Array reshaping
Squeezing an array
Yes, much better! Now, here is a new mission for you. Get to it, trainee!
Mission 5: Removing singleton dimensions
I was just doing a machine learning experiment, and I somehow ended up with a (3 \times 1 \times 2) array. Is that second axis really necessary? Can you help me squeeze that pesky axis out to make it (3 \times 2), please?
>>> x = np.array([[[1, 2]], [[3, 4]], [[5, 6]]])
>>> print(x.shape)
(3, 1, 2)
>>> print(x) # Note all those extra brackets!
[[[1 2]]

 [[3 4]]

 [[5 6]]]
>>> simpler_x = ????
>>> assert np.all(simpler_x == np.array([[1, 2],
...                                      [3, 4],
...                                      [5, 6]]))
>>> assert simpler_x.shape == (3, 2)
The .squeeze() method removes any singleton dimensions (axes of length 1).
# Method version
simpler_x = x.squeeze()
# Function version
simpler_x = np.squeeze(x)
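As a quick check of the mission, here is a minimal sketch that applies the method to the same array and confirms the result. The extra array w and the np.shares_memory check are illustrative additions, not part of the original exercise; they show that a bare squeeze drops every length-1 axis and that the result is a view of the original data rather than a copy.

```python
import numpy as np

x = np.array([[[1, 2]], [[3, 4]], [[5, 6]]])   # shape (3, 1, 2)
simpler_x = x.squeeze()

print(simpler_x.shape)                  # (3, 2)
print(np.shares_memory(x, simpler_x))   # True: the result is a view of x

# With no arguments, every length-1 axis is dropped at once.
# w is a hypothetical array used only for illustration.
w = np.zeros((3, 1, 1, 2))
print(w.squeeze().shape)                # (3, 2)
```

Because the result shares memory with x, writing into simpler_x also changes x; call .copy() if you need an independent array.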
If you have multiple singleton axes, e.g. (3, 1, 1, 2), and need to remove only one of them, pass the axis argument when you squeeze.
>>> x = np.array([[[0], [0], [0]]])
>>> print(x.shape)
(1, 3, 1)
>>> print(x)
[[[0]
  [0]
  [0]]]
>>> y = x.squeeze(axis=0)
>>> print(y.shape)
(3, 1)
>>> print(y)
[[0]
 [0]
 [0]]
>>> y = x.squeeze(axis=2)
>>> print(y.shape)
(1, 3)
>>> print(y)
[[0 0 0]]
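Two further behaviours of the axis argument are worth knowing, sketched here on an array of the same (1, 3, 1) shape. The tuple form and the error case go beyond the examples above: axis also accepts a tuple of axes, and asking to squeeze an axis whose length is not 1 raises a ValueError instead of silently doing nothing.

```python
import numpy as np

x = np.zeros((1, 3, 1))

# A tuple of axes removes several chosen singleton axes in one call.
both = x.squeeze(axis=(0, 2))
print(both.shape)                 # (3,)

# Axis 1 has length 3, so it cannot be squeezed out.
raised = False
try:
    x.squeeze(axis=1)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

Naming the axis explicitly is a useful safety net: if an upstream change makes that axis longer than 1, the code fails loudly rather than passing along an array of an unexpected shape.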