关于pytorch中scatter_add_函数的分析、理解与实现

2400阅读 0评论2021-10-15 专注的阿熊
分类:Python/Ruby

import torch

import numpy as np

from torch import Tensor

"""

@overload

def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...

@overload

def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ...

def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...

pytorch中的scatter_add函数的理解和简单测试:

# 参数:tensor,dim,index,tensor

# 返回:tensor

# 功能:将other_tensor的值累加到self_tensor的相应位置,用index_tensor对应位置的值替换掉self_tensor下标的dim

# 举例:

    self_tensor  = [[1, 2], [3, 4]] shape=(2,2)

    other_tensor = [[5, 6], [7, 8]] shape=(2,2)

    index_tensor = [[0, 0], [1, 1]] shape=(2,2)

    dim = 1

    以上三个tensorshape必须一致,下标为:[0,0] [0,1] [1,0] [1,1]

    dim=1,那么,self_tensor的第1维下标由index_tensor表示,[0,0] [0,0] [1,1] [1,1]

    :

        self_tensor[0,0] = 1 + 5 + 6 = 12

        self_tensor[0,1] = 2

        self_tensor[1,0] = 3

        self_tensor[1,1] = 4 + 7 + 8 = 19

"""

def scatter_add(input_tensor: torch.Tensor, dim: int, index: torch.Tensor, other: torch.Tensor) -> torch.Tensor:

    # tensor的维数是不确定的,因此无法用for循环的方式

    # 如果tensor2维,外汇跟单gendan5.com那么dim=01,两层for循环,用otherself进行填充

    # 如果tensor3维,那么dim=012,需要三层for循环来遍历other

    if input_tensor.dim() == 2:

        for i in range(index_tensor.size()[0]):

            for j in range(index_tensor.size()[1]):

                if dim == 0:  # self矩阵的第0维索引

                    self_tensor[index_tensor[i][j]][j] += other_tensor[i][j]

                elif dim == 1:  # self矩阵的第1维索引

                    self_tensor[i][index_tensor[i][j]] += other_tensor[i][j]

    elif input_tensor.dim() == 3:

        pass

    return self_tensor

if __name__ == '__main__':

    index_tensor = torch.tensor([[0, 0], [1, 1]])

    print('index_tensor: \n', index_tensor.dim())

    self_tensor = torch.arange(1, 5).view(2, 2)

    print('self_tensor: \n', self_tensor)

    other_tensor = torch.arange(5, 9).view(2, 2)

    print('other_tensor: \n', other_tensor)

    dim = 1

    for i in range(index_tensor.size()[0]):

        for j in range(index_tensor.size()[1]):

            replace_index = index_tensor[i][j]

            print(i, j, replace_index)

            if dim == 0:

                # self矩阵的第0维索引

                self_tensor[replace_index][j] += other_tensor[i][j]

            elif dim == 1:

                # self矩阵的第1维索引

                self_tensor[i][replace_index] += other_tensor[i][j]

    print(self_tensor)

    index_tensor = torch.tensor([[0, 1], [1, 1]])

    print('index_tensor: \n', index_tensor)

    self_tensor = torch.arange(0, 4).view(2, 2)

    print('self_tensor: \n', self_tensor)

    other_tensor = torch.arange(5, 9).view(2, 2)

    print('other_tensor: \n', other_tensor)

    self_tensor.scatter_add_(dim=0, index=index_tensor, src=other_tensor)

    print(self_tensor)

上一篇:简单实现登陆注册gui界面以及打包成exe文件
下一篇:python mysql学生成绩管理系统