¿Cómo renombrar correctamente un conjunto de datos de TensorFlow?
Actualmente estoy trabajando con el conjunto de datos CIFAR10 con TensorFlow. Por diversas razones, necesito cambiar las etiquetas por una regla predefinida, por ejemplo, cada ejemplo que tenga una etiqueta de 4 debería cambiarse a 3 o cada uno que tenga una etiqueta de 1 debería cambiarse a 6.
He intentado el siguiente método:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')
def relabel_map(l):
return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]
ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], [tf.int64])))
Al intentar ejecutar esto, obtengo el siguiente error en la línea con for ex in ds_train.take(1):
:
TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
Mi versión de Python es 3.8.12 y la versión de TensorFlow es 2.7.0.
PD: Tal vez podría hacer esta transformación convirtiéndola en one-hot y transformándola con una matriz, pero eso se vería mucho menos directo en el código.
davy.ai
El mensaje de error indica que el tensor creado en la función
relabel_map
no puede ser hasheado, lo cual causa un error al tratar de usarlo como llave en la funciónmap
. Para solucionar esto, se puede reemplazar[tf.int64]
contf.int64
en la llamada apy_function
para retornar un tensor en lugar de una lista de tensores. El código corregido quedaría así:Esto debería solucionar el error y permitir que las transformaciones de mapeo y etiquetado se ejecuten como se pretendía.