目录

Pytorch-张量的scatter_add_方法介绍

Pytorch 张量的scatter_add_方法介绍

torch.Tensor.scatter_add_ 是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量( src )中的值根据指定的索引( index )累加到目标张量( self )中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。

函数签名

Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
  1. dim (int) :指定沿着哪个维度进行索引和累加。
  2. index (LongTensor) :一个整数类型的张量,包含要累加的索引位置。 index 的形状应与 src 相同,除了指定的维度 dim
  3. src (Tensor) :源张量,包含要累加到目标张量的值。

功能

scatter_add_ 会根据 index 中的索引,将 src 中的值累加到目标张量 self 的指定位置。对于每个值,其目标位置由 index 指定,而其他维度的位置由其在 src 中的位置决定。

操作逻辑

对于一个三维张量, scatter_add_ 的更新规则如下:

self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

示例

以下是一个简单的二维张量示例:

Python复制

import torch

# 初始化目标张量
input_tensor = torch.zeros(3, 5)

# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)

# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)

# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)

print(input_tensor)

输出:

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 0.,  7.,  0.,  9.,  0.],
        [ 6.,  0.,  8.,  0., 10.]])

详细解析

  1. 目标张量input_tensor 是一个形状为 (3, 5) 的零张量。

  2. 源张量src 是一个形状为 (2, 5) 的张量,包含要累加的值。

  3. 索引张量index 是一个形状为 (2, 5) 的整数张量,指定 src 中的值应该累加到 input_tensor 的哪些位置。

  4. 累加操作

    • scatter_add_ 沿着维度 0 进行操作。

    • index 中的每个值指定了 src 中对应值的目标位置。

    • 例如:

      • index[0, 0] = 0 ,表示 src[0, 0] = 1 应该累加到 input_tensor[0, 0]
      • index[1, 1] = 0 ,表示 src[1, 1] = 7 应该累加到 input_tensor[0, 1]

注意事项

  1. 形状要求

    • indexsrc 的形状必须与目标张量 self 的形状兼容。
    • index.size(d) <= src.size(d) 对所有维度 d 成立。
    • index.size(d) <= self.size(d) 对所有维度 d != dim 成立。
  2. 非确定性行为

    • 在 CUDA 设备上, scatter_add_ 的行为可能是非确定性的。
  3. 反向传播

    • 反向传播仅在 src.shape == index.shape 时实现。
  4. 原地操作

    • scatter_add_ 是一个原地操作,会直接修改目标张量 self

总结

torch.Tensor.scatter_add_ 是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。