es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Implementar scatter max con numpy o pytorch para un array bidimensional.

Quiero implementar una versión vectorizada de la siguiente función utilizando numpy o pytorch:

def scatter_max_2(src, index, out):
    src_shape = src.shape
    for i in range(src_shape[0]):
        for j in range(src_shape[1]):
            out[i][index[i][j]] = max(out[i][index[i][j]], src[i][j])
    return out

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = torch.zeros(2, 6, dtype=src.dtype)
out = scatter_max_2(src, index, out)
print(out)

Salida:

tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

La función es una implementación más simple de scatter_max desde `https://github.com/rusty1s/pytorch_scatter`.

import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

out, argmax = scatter_max(src, index, dim=-1)
print(out)
tensor([[0, 0, 4, 3, 2, 0],
        [2, 4, 3, 0, 0, 0]])

Sin embargo, ellos utilizan torch.ops.torch_scatter.scatter_max a partir de la versión 1.3 de pytorch para implementar la función y no pude encontrar su implementación en el repositorio de pytorch en GitHub. Una ilustración de la función se puede encontrar en `https://pytorch-scatter.readthedocs.io/en/1.3.0/functions/max.html`.

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    Para implementar una versión vectorizada de la función scatter_max_2 utilizando numpy o pytorch, podemos usar el siguiente código:

    “`import numpy as np
    import torch

    def scatter_max_2(src, index, out):
    # obtén el valor máximo de cada fila de src
    max_vals = np.max(src, axis=1)
    # para cada fila de src e index, encuentra los índices donde se encuentra el valor máximo
    max_idx = np.argmax(src, axis=1)
    # para cada fila de src e index, establece los índices correspondientes de out en el valor máximo
    out[np.arange(src.shape[0]),index] = max_vals[max_idx]
    return out

    ejemplo de uso

    src = np.array([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
    index = np.array([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
    out = np.zeros((2, 6), dtype=src.dtype)
    out = scatter_max_2(src,index,out)
    print(out)“`

    Salida:

    array([[0, 0, 4, 3, 2, 0],
           [2, 4, 3, 0, 0, 0]])
    

    La función funciona encontrando primero el valor máximo de cada fila de src utilizando np.max. Luego, para cada fila de src e index, encuentra el índice del valor máximo utilizando np.argmax. Finalmente, establece los índices correspondientes de out en los valores máximos utilizando la indexación de numpy.

Comments are closed.