es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

¿Cómo renombrar correctamente un conjunto de datos de TensorFlow?

Actualmente estoy trabajando con el conjunto de datos CIFAR10 con TensorFlow. Por diversas razones, necesito cambiar las etiquetas por una regla predefinida, por ejemplo, cada ejemplo que tenga una etiqueta de 4 debería cambiarse a 3 o cada uno que tenga una etiqueta de 1 debería cambiarse a 6.

He intentado el siguiente método:

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')

def relabel_map(l):
    return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]

ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], [tf.int64])))

Al intentar ejecutar esto, obtengo el siguiente error en la línea con for ex in ds_train.take(1)::

TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.

Mi versión de Python es 3.8.12 y la versión de TensorFlow es 2.7.0.

PD: Tal vez podría hacer esta transformación convirtiéndola en one-hot y transformándola con una matriz, pero eso se vería mucho menos directo en el código.

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    El mensaje de error indica que el tensor creado en la función relabel_map no puede ser hasheado, lo cual causa un error al tratar de usarlo como llave en la función map. Para solucionar esto, se puede reemplazar [tf.int64] con tf.int64 en la llamada a py_function para retornar un tensor en lugar de una lista de tensores. El código corregido quedaría así:

    import matplotlib.pyplot as plt
    import numpy as np
    import tensorflow as tf
    import tensorflow_datasets as tfds
    
    builder = tfds.image.Cifar10()
    builder.download_and_prepare()
    ds_train: tf.data.Dataset = builder.as_dataset(split='train')
    
    def relabel_map(l):
        return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]
    
    ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], tf.int64)))
    
    for ex in ds_train.take(1):
        plt.imshow(np.array(ex[0], dtype=np.uint8))
        plt.show()
        print(ex[1])
    

    Esto debería solucionar el error y permitir que las transformaciones de mapeo y etiquetado se ejecuten como se pretendía.

Comments are closed.