NumPy 索引与切片:自由裁剪张量的方法

1. 为什么索引/切片如此重要?



在深度学习中,你会频繁地处理张量。

  • 从批次中 仅取前几个样本
  • 在图像中 仅挑选特定通道(R/G/B)
  • 在序列数据中 仅裁剪部分时间步
  • 在标签中 仅挑选特定类别

所有这些操作最终都归结为 “索引(indexing) & 切片(slicing)”

而 PyTorch 的 Tensor 索引/切片语法与 NumPy 非常相似, 因此在 NumPy 上熟练掌握后,编写深度学习代码会更得心应手。


2. 基础索引:从 1 维开始

2.1 1 维数组索引

import numpy as np

x = np.array([10, 20, 30, 40, 50])

print(x[0])  # 10
print(x[1])  # 20
print(x[4])  # 50
  • 索引从 0 开始
  • x[i] 是第 i 个元素

2.2 负数索引

当你想从末尾开始计数时,可以使用负数索引。

print(x[-1])  # 50 (最后一个)
print(x[-2])  # 40

PyTorch 也同样支持。


3. 切片基础:start:stop:step



切片使用 x[start:stop:step] 的形式。

x = np.array([10, 20, 30, 40, 50])

print(x[1:4])    # [20 30 40], 1 以上 4 未满
print(x[:3])     # [10 20 30], 从开始到 3 未满
print(x[2:])     # [30 40 50], 从 2 到末尾
print(x[:])      # 整体复制的感觉

如果加上 step,可以设置间隔。

print(x[0:5:2])  # [10 30 50], 0~4 每隔 2
print(x[::2])    # [10 30 50], 同上

在深度学习中,例如跳过每 2 步时间步,或按固定间隔抽样时非常有用。


4. 2 维及以上数组索引:行与列

从 2 维开始,基本上就是 “矩阵/批次”,更贴近深度学习的语境。

import numpy as np

X = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])  # shape: (3, 3)

4.1 行/列单一索引

print(X[0])      # 第一行: [1 2 3]
print(X[1])      # 第二行: [4 5 6]

print(X[0, 0])   # 第 1 行第 1 列: 1
print(X[1, 2])   # 第 2 行第 3 列: 6
  • X[i] → 第 i 行(1D 数组)
  • X[i, j] → 第 i 行第 j 列的值(标量)

4.2 行切片

print(X[0:2])    # 0~1 行
# [[1 2 3]
#  [4 5 6]]

print(X[1:])     # 从第 1 行到末尾
# [[4 5 6]
#  [7 8 9]]

这通常对应 批次中前/后部分样本裁剪 的模式。

4.3 列切片

print(X[:, 0])   # 所有行,第 0 列 → [1 4 7]
print(X[:, 1])   # 所有行,第 1 列 → [2 5 8]
  • : 表示 “该维度的全部”
  • X[:, 0] 表示 “行全,列 0”

在深度学习中:

  • 对于 (batch_size, feature_dim) 数组 想取特定特征时:X[:, k]

5. 3 维及以上:批次 × 通道 × 高 × 宽

以图像数据为例。

# 假设 (batch, height, width)
images = np.random.randn(32, 28, 28)  # 32 张 28x28 图像

5.1 取单个样本

img0 = images[0]        # 第一张图,shape: (28, 28)
img_last = images[-1]   # 最后一张

5.2 取部分批次

first_8 = images[:8]    # 前 8 张,shape: (8, 28, 28)

5.3 裁剪图像区域(crop)

# 中心 20x20 区域
crop = images[:, 4:24, 4:24]  # shape: (32, 20, 20)

PyTorch 也类似:

# images_torch: (32, 1, 28, 28) 的张量
center_crop = images_torch[:, :, 4:24, 4:24]

可以看到 索引/切片概念几乎完全相同


6. 切片通常是 “视图(view)”

一个重要点:

切片结果通常是 原数组的“视图(view)”。 换句话说,它不是复制数据,而是 看原始数据的窗口

x = np.array([10, 20, 30, 40, 50])
y = x[1:4]   # 视图

print(y)     # [20 30 40]
y[0] = 999

print(y)     # [999  30  40]
print(x)     # [ 10 999  30  40  50]  ← 原始也被改动!

这种特性意味着:

  • 节省内存,速度更快
  • 但不小心会改动原始数据

如果想得到 完全独立的数组,请使用 copy()

x = np.array([10, 20, 30, 40, 50])
y = x[1:4].copy()

y[0] = 999
print(x)  # [10 20 30 40 50],原始保持不变

PyTorch 也有类似概念,掌握 “视图 vs 复制” 的感觉能让调试更顺畅。


7. 布尔索引:按条件挑选元素

布尔索引用于 只挑选满足条件的元素

import numpy as np

x = np.array([1, -2, 3, 0, -5, 6])

mask = x > 0
print(mask)      # [ True False  True False False  True]

pos = x[mask]
print(pos)       # [1 3 6]
  • x > 0 → 由 True/False 组成的数组
  • x[mask] → 仅取 True 的位置

组合示例:

X = np.array([[1, 2, 3],
              [4, 5, 6],
              [-1, -2, -3]])

pos = X[X > 0]
print(pos)  # [1 2 3 4 5 6]

在深度学习中常见的用法:

  • 只提取满足特定条件的样本(如标签为某值)
  • 在损失计算时,加上掩码 只平均部分值

PyTorch 也几乎一样:

import torch

x = torch.tensor([1, -2, 3, 0, -5, 6])
mask = x > 0
pos = x[mask]

8. 整数数组/列表索引(Fancy Indexing)

可以用整数索引的数组/列表一次性挑选多个位置。

x = np.array([10, 20, 30, 40, 50])

idx = [0, 2, 4]
print(x[idx])  # [10 30 50]

二维也可。

X = np.array([[1, 2],
              [3, 4],
              [5, 6]])  # shape: (3, 2)

rows = [0, 2]
print(X[rows])  
# [[1 2]
#  [5 6]]

在深度学习中,例如:

  • 用随机打乱的索引数组 抽取批次样本
  • 只聚合特定位置的标签/预测进行统计

PyTorch 也同样支持。


9. 常用索引模式汇总

从深度学习角度整理常见模式:

import numpy as np

# (batch, feature)
X = np.random.randn(32, 10)

# 1) 前 8 个样本
X_head = X[:8]              # (8, 10)

# 2) 特定特征(如第 3 列)
f3 = X[:, 3]                # (32,)

# 3) 偶数索引样本
X_even = X[::2]             # (16, 10)

# 4) 标签为 1 的样本
labels = np.random.randint(0, 3, size=(32,))
mask = labels == 1
X_cls1 = X[mask]            # 仅标签 1 的样本

# 5) 随机打乱后,前 24 作为 train,后 8 作为 val
indices = np.random.permutation(len(X))
train_idx = indices[:24]
val_idx = indices[24:]

X_train = X[train_idx]
X_val = X[val_idx]

这些模式在 PyTorch 张量中几乎可以直接复用。 最终 熟练掌握 NumPy 索引 就等同于 能自由“切、混、挑”张量 的能力。


10. 结语

总结本篇内容:

  • 索引:x[i]x[i, j]、负数索引
  • 切片:start:stop:step:、多维索引(X[:, 0]X[:8]X[:, 4:8]
  • 切片通常是 视图(view),若需独立副本请使用 copy()
  • 布尔索引:按条件过滤(x[x > 0]X[labels == 1]
  • 整数数组索引:一次性挑选(x[[0,2,4]]

掌握这些后,你就能用 PyTorch 张量 轻松完成批次裁剪 / 通道选择 / 掩码应用 / 样本混洗 等常见任务。

image