es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

PyTorch cómo establecer el valor mínimo a cero a lo largo de cada fila de un tensor.

No conocemos la forma del tensor de entrada y no deberíamos usar ningún bucle. Solo operaciones de reducción e indexación. ¿Cómo establecemos el valor mínimo de cada fila en cero?

Por ejemplo:

entrada:

x = torch.tensor([[
[10, 20, 30]
[2, 5, 1]
]])

salida:

torch.tensor([
[0, 20, 30],
[2, 5, 0]
])

No pude resolverlo y no pude encontrar ninguna pregunta relacionada. Estoy atascado.

Tags:  , ,

Answer

  1. Avatar for davy.ai

    Una posible solución para establecer el valor mínimo de cada fila en cero sin usar bucles es primero encontrar los valores mínimos a lo largo de la dimensión de las filas utilizando la función torch.min() con el argumento “dim” establecido en 2. Luego, podemos usar la indexación avanzada para establecer los valores mínimos en cero. Aquí está el código:

    import torch
    
    # Tensor de entrada de ejemplo
    x = torch.tensor([
        [
            [10, 20, 30],
            [2, 5, 1]
        ]
    ])
    
    # Encontrar los valores mínimos a lo largo de la dimensión de las filas
    min_vals, _ = torch.min(x, dim=2, keepdim=True)
    
    # Crear una máscara booleana de dónde se encuentran los valores mínimos
    mask = x == min_vals
    
    # Establecer los valores mínimos en cero utilizando la indexación avanzada
    x[mask] = 0
    
    # Imprimir el tensor modificado
    print(x)
    

    Esto debería producir la salida deseada:

    tensor([
        [0, 20, 30],
        [2, 5, 0]
    ])
    

    Nota que el tensor de salida tiene la misma forma que el tensor de entrada, por lo que esta solución debería funcionar para tensores de entrada de cualquier forma siempre y cuando tengan al menos una dimensión de fila.

Comments are closed.