NumPy 索引與切片:自由切割張量的方法

1. 為什麼索引/切片如此重要?



在深度學習中,張量經常被頻繁操作。

  • 從批次中 只取前幾個樣本
  • 從圖像中 挑選特定通道(R/G/B)
  • 在序列資料中 只切出部分時間步
  • 從標籤中 挑選特定類別

所有這些操作最終都歸結為 “索引(indexing) & 切片(slicing)”

而 PyTorch 的 Tensor 索引/切片語法與 NumPy 非常相似, 因此先熟悉 NumPy,對寫深度學習程式會大有幫助。


2. 基本索引:從 1 維開始

2.1 1 維陣列索引

import numpy as np

x = np.array([10, 20, 30, 40, 50])

print(x[0])  # 10
print(x[1])  # 20
print(x[4])  # 50
  • 索引從 0 開始
  • x[i] 取第 i 個元素

2.2 負數索引

若想從尾部開始計數,使用負數索引。

print(x[-1])  # 50 (最後一個)
print(x[-2])  # 40

PyTorch 也同樣適用。


3. 切片基礎:start:stop:step



切片寫法為 x[start:stop:step]

x = np.array([10, 20, 30, 40, 50])

print(x[1:4])    # [20 30 40], 1 以上 4 未滿
print(x[:3])     # [10 20 30], 從頭到 3 未滿
print(x[2:])     # [30 40 50], 2 之後到結束
print(x[:])      # 整體複製感覺

若加上 step,可設定間隔。

print(x[0:5:2])  # [10 30 50], 0~4 兩個間隔
print(x[::2])    # [10 30 50], 同上

在深度學習中,例如跳過每兩個時間步,或以固定間隔抽取樣本,都很實用。


4. 2 維以上陣列索引:行與列

從 2 維開始,實際上就是「矩陣/批次」,更貼近深度學習。

import numpy as np

X = np.array([[1, 2, 3],
              [4, 5, 6],
              [7, 8, 9]])  # shape: (3, 3)

4.1 行/列單一索引

print(X[0])      # 第一行: [1 2 3]
print(X[1])      # 第二行: [4 5 6]

print(X[0, 0])   # 1 行 1 列: 1
print(X[1, 2])   # 2 行 3 列: 6
  • X[i] → 第 i 行(1D 陣列)
  • X[i, j] → 第 i 行第 j 列的值(標量)

4.2 行切片

print(X[0:2])    # 0~1 行
# [[1 2 3]
#  [4 5 6]]

print(X[1:])     # 1 行之後到結束
# [[4 5 6]
#  [7 8 9]]

這是「批次中只取前/後部分樣本」的常見模式。

4.3 列切片

print(X[:, 0])   # 所有行, 0 列 → [1 4 7]
print(X[:, 1])   # 所有行, 1 列 → [2 5 8]
  • : 表示「該維度全部」
  • X[:, 0] 表示「行全都取,列只取第 0 列」

在深度學習中:

  • (batch_size, feature_dim) 陣列中 想取特定 feature:X[:, k]

5. 3 維以上:批次 × 通道 × 高 × 寬

以圖像資料為例。

# 假設 (batch, height, width)
images = np.random.randn(32, 28, 28)  # 32 張 28x28 圖像

5.1 取單一樣本

img0 = images[0]        # 第一張圖,shape: (28, 28)
img_last = images[-1]   # 最後一張

5.2 取部分批次

first_8 = images[:8]    # 前 8 張,shape: (8, 28, 28)

5.3 只裁剪圖像的一部分 (crop)

# 中央 20x20 區域
crop = images[:, 4:24, 4:24]  # shape: (32, 20, 20)

PyTorch 也類似:

# images_torch: (32, 1, 28, 28) 的張量
center_crop = images_torch[:, :, 4:24, 4:24]

索引/切片概念幾乎完全相同。


6. 切片通常是「視圖(view)」

重要點:

切片結果通常是 原陣列的「視圖(view)」。 也就是說,並未複製資料,而是「看著」原始資料。

x = np.array([10, 20, 30, 40, 50])
y = x[1:4]   # 視圖

print(y)     # [20 30 40]
y[0] = 999

print(y)     # [999  30  40]
print(x)     # [ 10 999  30  40  50]  ← 原始也改變!

這個特性:

  • 節省記憶體、速度快
  • 但不小心會改變原始資料

若想得到獨立陣列,使用 copy()

x = np.array([10, 20, 30, 40, 50])
y = x[1:4].copy()

y[0] = 999
print(x)  # [10 20 30 40 50], 原始保持

PyTorch 也有類似概念,掌握「視圖 vs 複製」能讓除錯更順利。


7. 布林索引:條件選取

布林索引用於「只挑選滿足條件的元素」。

import numpy as np

x = np.array([1, -2, 3, 0, -5, 6])

mask = x > 0
print(mask)      # [ True False  True False False  True]

pos = x[mask]
print(pos)       # [1 3 6]
  • x > 0 → 由 True/False 組成的陣列
  • x[mask] → 只取 True 的位置

組合範例:

X = np.array([[1, 2, 3],
              [4, 5, 6],
              [-1, -2, -3]])

pos = X[X > 0]
print(pos)  # [1 2 3 4 5 6]

在深度學習中常見:

  • 只取符合條件的樣本(如標籤為特定值)
  • 計算損失時,加上遮罩只平均部分值

PyTorch 也同樣使用:

import torch

x = torch.tensor([1, -2, 3, 0, -5, 6])
mask = x > 0
pos = x[mask]

8. 整數陣列 / 列表索引(Fancy Indexing)

可以用整數索引陣列或列表一次取多個位置。

x = np.array([10, 20, 30, 40, 50])

idx = [0, 2, 4]
print(x[idx])  # [10 30 50]

2 維也可用。

X = np.array([[1, 2],
              [3, 4],
              [5, 6]])  # shape: (3, 2)

rows = [0, 2]
print(X[rows])  
# [[1 2]
#  [5 6]]

在深度學習中:

  • 隨機打亂的索引陣列取「批次樣本」
  • 只取特定位置的標籤/預測統計

PyTorch 也可同樣使用。


9. 常用索引模式總結

從深度學習角度整理常見模式:

import numpy as np

# (batch, feature)
X = np.random.randn(32, 10)

# 1) 前 8 個樣本
X_head = X[:8]              # (8, 10)

# 2) 特定 feature(例如第 3 列)
f3 = X[:, 3]                # (32,)

# 3) 偶數索引樣本
X_even = X[::2]             # (16, 10)

# 4) 標籤為 1 的樣本
labels = np.random.randint(0, 3, size=(32,))
mask = labels == 1
X_cls1 = X[mask]            # 標籤 1 的樣本

# 5) 隨機打亂後,前 24 個作訓練,後 8 個作驗證
indices = np.random.permutation(len(X))
train_idx = indices[:24]
val_idx = indices[24:]

X_train = X[train_idx]
X_val = X[val_idx]

這些模式在 PyTorch 張量中幾乎完全相同。

總結:熟悉 NumPy 索引即等於能自由「切、混、選」張量。


10. 結語

本文整理了:

  • 索引:x[i], x[i, j], 負數索引
  • 切片:start:stop:step, :, 多維索引(X[:, 0], X[:8], X[:, 4:8]
  • 切片通常是 視圖(view),如需獨立請用 copy()
  • 布林索引:條件過濾(x[x > 0], X[labels == 1]
  • 整數陣列索引:一次選取多個位置(x[[0,2,4]]

掌握這些技巧後,使用 PyTorch 張量進行「批次切割 / 通道選取 / 遮罩應用 / 樣本混洗」等工作會更順手。

image