Chapter 4: Array reshaping

Squeezing an array

Josiah Wang

Yes, much better! Now, here is a new mission for you. Get to it, trainee!

Mission 5: Removing singleton dimensions

I was just doing a machine learning experiment, and I somehow ended up with a (3, 1, 2) array. Is that second axis really necessary? Can you help me squeeze that pesky axis out and make it (3, 2), please?

>>> import numpy as np
>>> x = np.array([[[1, 2]], [[3, 4]], [[5, 6]]])
>>> print(x.shape)
(3, 1, 2)
>>> print(x) # Note all those extra brackets!
[[[1 2]]

 [[3 4]]

 [[5 6]]]
>>> simpler_x = ????
>>> assert np.all(simpler_x == np.array([[1, 2],
...                               [3, 4],
...                               [5, 6]]))
>>> assert simpler_x.shape == (3, 2)

The .squeeze() method removes all singleton dimensions (axes of length 1) from an array.

simpler_x = x.squeeze()

# Function version
simpler_x = np.squeeze(x)
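As a small sketch, the snippet below compares .squeeze() with two other ways of reaching the same (3, 2) array for this particular input: indexing the singleton axis with 0, and calling .reshape(3, 2). The variable names squeezed, indexed and reshaped are only illustrative. It also uses np.shares_memory to check that the squeezed result is a view of x rather than a copy.

```python
import numpy as np

# The (3, 1, 2) array from the mission
x = np.array([[[1, 2]], [[3, 4]], [[5, 6]]])

squeezed = x.squeeze()       # drops every axis of length 1
indexed = x[:, 0, :]         # picks the only entry along axis 1
reshaped = x.reshape(3, 2)   # states the target shape explicitly

print(squeezed.shape, indexed.shape, reshaped.shape)  # (3, 2) (3, 2) (3, 2)
print(np.array_equal(squeezed, indexed) and np.array_equal(squeezed, reshaped))  # True

# squeeze() returns a view, so the result shares memory with x
print(np.shares_memory(x, squeezed))  # True
```

Indexing and reshaping both require you to know where the singleton axis is or what the final shape should be, whereas .squeeze() works this out from the array's shape on its own.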

If an array has several singleton axes, e.g. (3, 1, 1, 2), and you only want to remove one of them, pass the index of that axis through the axis argument when you squeeze.

>>> x = np.array([[[0], [0], [0]]])
>>> print(x.shape)
(1, 3, 1)
>>> print(x)
[[[0]
  [0]
  [0]]]
>>> y = x.squeeze(axis=0)
>>> print(y.shape)
(3, 1)
>>> print(y)
[[0]
 [0]
 [0]]
>>> y = x.squeeze(axis=2)
>>> print(y.shape)
(1, 3)
>>> print(y)
[[0 0 0]]
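To tie this back to the (3, 1, 1, 2) shape mentioned above, here is a short sketch using an illustrative np.arange array of that shape. It compares squeezing every singleton axis, squeezing one chosen axis, and squeezing a tuple of axes, and shows that asking to squeeze an axis whose length is not 1 raises a ValueError.

```python
import numpy as np

# A (3, 1, 1, 2) array with two singleton axes (axes 1 and 2)
x = np.arange(6).reshape(3, 1, 1, 2)

print(x.squeeze().shape)                 # (3, 2): every singleton axis removed
print(x.squeeze(axis=1).shape)           # (3, 1, 2): only axis 1 removed
print(x.squeeze(axis=2).shape)           # (3, 1, 2): only axis 2 removed
print(np.squeeze(x, axis=(1, 2)).shape)  # (3, 2): a tuple of axes also works

# Axis 0 has length 3, so it cannot be squeezed out
try:
    x.squeeze(axis=0)
except ValueError as err:
    print("ValueError:", err)
```

Naming the axis explicitly is a useful safeguard: if the data unexpectedly has a different shape, you get an error instead of silently losing an axis you meant to keep.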