Skip to content

NumPy : Suppression des dimensions de taille 1 de ndarray (np.squeeze)

Vous pouvez utiliser numpy.squeeze() pour supprimer toutes les dimensions de taille 1 du tableau NumPy ndarray. squeeze() est également fourni comme méthode de ndarray.

Cet article décrit le contenu suivant.

  • Utilisation de base de numpy.squeeze()
  • Spécifiez la dimension à supprimer :axis
  • numpy.ndarray.squeeze()

Utilisez numpy.reshape() pour convertir en n’importe quelle forme, et numpy.newaxis, numpy.expand_dims() pour ajouter une nouvelle dimension de taille 1. Voir l’article suivant pour plus de détails.

Utilisation de base de numpy.squeeze()

Spécifier numpy.ndarray comme premier argument de numpy.squeeze() renvoie numpy.ndarray avec toutes les dimensions de taille 1 supprimées.

import numpy as np

a = np.arange(6).reshape(1, 2, 1, 3, 1)
print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]

print(a.shape)
# (1, 2, 1, 3, 1)

a_s = np.squeeze(a)
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(a_s.shape)
# (2, 3)

numpy.squeeze() renvoie une vue du numpy.ndarray d’origine. L’objet d’origine et l’objet de vue partagent la mémoire, donc la modification d’un élément modifie l’autre.

print(np.shares_memory(a, a_s))
# True

Si vous voulez faire une copie, utilisez copy().

a_s_copy = np.squeeze(a).copy()

print(np.shares_memory(a, a_s_copy))
# False

Consultez l’article suivant pour les vues et les copies dans NumPy.

Spécifiez la dimension à supprimer :axis

Par défaut, toutes les cotes de taille 1 sont supprimées, comme dans l’exemple ci-dessus.

Vous pouvez spécifier l’index de la dimension à supprimer dans le deuxième axe d’argument de numpy.squeeze(). Les cotes qui ne correspondent pas à l’index spécifié ne sont pas supprimées.

print(a.shape)
# (1, 2, 1, 3, 1)
print(np.squeeze(a, 0))
# [[[[0]
#    [1]
#    [2]]]
# 
# 
#  [[[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, 0).shape)
# (2, 1, 3, 1)

Une erreur se produit si vous spécifiez une dimension dont la taille n’est pas 1 ou une dimension qui n’existe pas.

# print(np.squeeze(a, 1))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

# print(np.squeeze(a, 5))
# AxisError: axis 5 is out of bounds for array of dimension 5

L’axe peut également être spécifié comme une valeur négative. -1 correspond à la dernière dimension et peut être spécifié par la position depuis l’arrière.

print(np.squeeze(a, -1))
# [[[[0 1 2]]
# 
#   [[3 4 5]]]]

print(np.squeeze(a, -1).shape)
# (1, 2, 1, 3)

print(np.squeeze(a, -3))
# [[[[0]
#    [1]
#    [2]]
# 
#   [[3]
#    [4]
#    [5]]]]

print(np.squeeze(a, -3).shape)
# (1, 2, 3, 1)

Vous pouvez spécifier plusieurs dimensions avec des tuples. Une erreur se produit si une dimension dont la taille n’est pas 1 ou n’existe pas est incluse.

print(np.squeeze(a, (0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(np.squeeze(a, (0, -1)).shape)
# (2, 1, 3)

# print(np.squeeze(a, (0, 1)))
# ValueError: cannot select an axis to squeeze out which has size not equal to one

numpy.ndarray.squeeze()

squeeze() est également fourni comme méthode de numpy.ndarray.

L’utilisation est la même que numpy.squeeze(). Le premier argument est l’axe.

print(a.squeeze())
# [[0 1 2]
#  [3 4 5]]

print(a.squeeze().shape)
# (2, 3)

print(a.squeeze((0, -1)))
# [[[0 1 2]]
# 
#  [[3 4 5]]]

print(a.squeeze((0, -1)).shape)
# (2, 1, 3)

La méthode squeeze() renvoie également une vue comme numpy.squeeze(). L’objet d’origine reste le même.

a_s = a.squeeze()
print(a_s)
# [[0 1 2]
#  [3 4 5]]

print(np.shares_memory(a, a_s))
# True

print(a)
# [[[[[0]
#     [1]
#     [2]]]
# 
# 
#   [[[3]
#     [4]
#     [5]]]]]