¿Se puede utilizar SigmoidFocalCrossEntropy en Tensorflow (tf-addons) para la clasificación multiclase? (¿Cuál es la manera correcta?)
La Pérdida Focal proporcionada en Tensorflow se utiliza para el desequilibrio de clases. Para la clasificación binaria, hay muchos códigos disponibles, pero para la clasificación multiclase, hay muy poca ayuda. Ejecuté el código con variables objetivo codificadas en “One Hot Encoding” de 250 clases y me dio resultados sin errores.
y = pd.get_dummies(df['target']) # Clases objetivo codificadas en "One hot encoding"
model.compile(
optimizer="adam", loss=tfa.losses.SigmoidFocalCrossEntropy(), metrics=metric
)
Solo quiero saber quién escribió este código o alguien que tenga suficiente conocimiento de este código, ¿se puede usar para la clasificación multiclase? Si no, ¿cómo es posible que no me haya dado errores, en cambio, mejores resultados que la Entropía Cruzada? Además, en otras implementaciones como esta, el valor de alpha
debe proporcionarse para cada clase, pero en las implementaciones de Tensorflow solo se proporciona un valor.
¿Cuál es la forma correcta de usar esto?
davy.ai
La función
tfa.losses.SigmoidFocalCrossEntropy
en TensorFlow Addons está diseñada para manejar la clasificación multiclase además de la clasificación binaria. Como has mencionado, puede ser utilizada con variables objetivo codificadas en one-hot para la clasificación multiclase. Es posible que esta implementación maneje mejor el problema del desequilibrio de clases que la función de pérdidaCrossEntropy
estándar, lo que podría explicar por qué estás obteniendo mejores resultados.En cuanto al valor
alpha
, la implementación en TensorFlow Addons utiliza un esquema de ponderación por defecto en el cual cada clase recibe igual peso. Sin embargo, puedes especificar un esquema de ponderación personalizado pasando el argumentofrom_logits=True
y luego especificando un valor personalizado paraalpha
como un tensor de forma(num_classes,)
.Por lo tanto, la manera correcta de utilizar
tfa.losses.SigmoidFocalCrossEntropy
para la clasificación multiclase es pasar los logits de tu modelo y las variables objetivo codificadas en one-hot para cada muestra, y compilar el modelo utilizando esta función de pérdida. Si deseas utilizar un esquema de ponderación personalizado, debes establecerfrom_logits=True
y pasar un tensor de valores alpha personalizados.