Cómo usar tf.gather_nd para un tensor multidimensional
No entiendo completamente cómo debo usar tf.gather_nd() para seleccionar elementos a lo largo de un eje si tengo un tensor multidimensional. Tomemos un pequeño ejemplo (si obtengo una respuesta para este ejemplo simple, también resolverá mi problema original más complejo). Digamos que tengo una imagen RGB y estoy tratando de seleccionar el valor de píxel más pequeño a lo largo de los canales (última dimensión si el orden de los datos es (B, H, W, C)). Sé que esto se puede hacer con tf.reduce_min(x, axis=-1) pero me gustaría saber si también es posible hacer lo mismo con tf.argmin() y tf.gather_nd()?
from skimage import data
import tensorflow as tf
import numpy as np
# Cargar imagen RGB de skimage, convertirla a float32 y colocarla en el orden (B, H, W, C)
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)
# Tomar el valor de píxel mínimo de cada canal de una manera
min_along_channels_1 = tf.reduce_min(image, axis=-1)
# Tomar el valor de píxel mínimo de cada canal de otra manera
# El objetivo es que min_along_channels_1 sea igual a min_along_channels_2
idxs = tf.argmin(image, axis=-1)
min_along_channels_2 = tf.gather_nd(image, idxs) # ¡Esta línea provoca un error! :(
davy.ai
Para usar
tf.gather_nd()
para recoger elementos a lo largo de un eje específico, primero necesitas obtener los índices de los elementos que deseas recoger usandotf.argmin()
. En el ejemplo dado, el objetivo es recoger el valor mínimo de píxeles a lo largo de los canales (última dimensión) de una imagen RGB. Esto se puede lograr usando el siguiente código:En el código anterior,
tf.expand_dims(idxs, axis=-1)
se utiliza para expandir el tensor de índices para tener una dimensión extra (es decir, de forma(B,H,W)
a(B,H,W,1)
), que es requerido portf.gather_nd()
. Luego se utiliza la función para recoger los elementos indexados poridxs
a lo largo de la última dimensión. Finalmente,tf.squeeze()
se utiliza para eliminar la dimensión extra añadida portf.expand_dims()
.Después de ejecutar el código, tanto
min_along_channels_1
comomin_along_channels_2
deberían contener el mismo valor mínimo de píxel a lo largo de los canales.