es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

BatchDataset muestra imágenes y etiquetas

Tengo un conjunto de datos de lotes de Entrenamiento y Validación:

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode = 'categorical', # se utiliza para clasificación multiclase. Son etiquetas codificadas en one-hot para cada clase
    validation_split = 0.2,     # porcentaje del conjunto de datos que se considerará para validación
    subset = 'training',        # este subconjunto se utiliza para entrenamiento
    seed = 1337,                # se establece la semilla para que los resultados sean reproducibles
    image_size = img_size,      # tamaño de las imágenes de entrada
    batch_size = batch_size,    # esto debe coincidir con el tamaño del lote del modelo
)

valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_path,
    label_mode ='categorical',
    validation_split = 0.2,
    subset = 'validation',      # este subconjunto se utiliza para validación
    seed = 1337,
    image_size = img_size,
    batch_size = batch_size,
)

Estaba intentando mostrar 9 imágenes para mostrar cómo se ven, lo cual logré, pero no puedo mostrar su respectiva etiqueta.

Aquí está el código:

class_names = train_ds.class_names


plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.axis("off")

Lo cual muestra esto:

enter image description here

Si intento obtener las etiquetas agregando: plt.title(class_names[labels[i]])

Obtengo el siguiente error: TypeError: solo se pueden convertir matrices escalares enteras en un índice escalar

He intentado soluciones de otras publicaciones como la siguiente: plt.title(class_names[labels[i][0]]), pero sin éxito.

Cuando imprimo labels[i] obtengo la codificación one-hot de las etiquetas… tal vez ahí está el problema?

enter image description here

El código resultante:

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[np.argmax(labels[i], axis=None, out=None)])
    plt.axis("off")
Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Para trazar las etiquetas respectivas de las imágenes en el conjunto de datos, primero debes convertir las etiquetas codificadas en one-hot a enteros. Puedes usar np.argmax() para obtener el índice del elemento máximo en el arreglo de etiquetas, que corresponde a la etiqueta de clase. Aquí está el código modificado:

    plt.figure(figsize=(10, 10))
    for images, labels in train_ds.take(1):
      for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[np.argmax(labels[i])])
        plt.axis("off")
    

    Aquí, np.argmax(labels[i]) obtiene el índice del valor máximo en el arreglo de etiquetas para la i-ésima imagen, que corresponde a la etiqueta de clase. Luego, class_names[np.argmax(labels[i])] obtiene el nombre de clase correspondiente de la lista class_names.

    El resultado debería ser 9 imágenes con sus respectivas etiquetas mostradas.

Comments are closed.