es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

¿Cómo utilizar correctamente ImageDataGenerator en Keras?

Últimamente, he estado jugando con la aumentación de datos en Keras y estoy utilizando el generador básico de imágenes (ImageDataGenerator). Aprendí de la manera difícil que es en realidad un generador y no un iterador (porque type(train_aug_ds) da como resultado <class 'keras.preprocessing.image.directoryiterator'="">, pensé que era un iterador). También revisé algunos blogs para su uso, pero no responden todas mis preguntas.

Por lo tanto, cargué mis datos de la siguiente manera:

train_aug = ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    height_shift_range=0.1,
    width_shift_range=0.1,
    brightness_range=(0.5,1.5),
    zoom_range = [1, 1.5],
)
train_aug_ds = train_aug.flow_from_directory(
    directory='./train',
    target_size=tamaño_de_imagen,
    batch_size=tamaño_de_lote,
)

Y para entrenar mi modelo hice lo siguiente:

model.fit(
    train_aug_ds,
    epochs=150,
    validation_data=(valid_aug_ds,),
)

Y funcionó. Estoy un poco confundido por cómo funciona, porque train_aug_ds es un generador, por lo que debería dar un conjunto de datos infinitamente grande. Y la documentación dice:

Cuando se pasa un conjunto de datos que se repite infinitamente, debe especificar el argumento de steps_per_epoch.

Lo cual no hice todavía, pero funciona. ¿Lo infiere de alguna manera el número de pasos? Además, ¿usa solo datos aumentados o también utiliza imágenes no aumentadas en el lote?

Entonces, básicamente, mi pregunta es cómo utilizar correctamente este generador con la función fit para tener todos los datos en mi conjunto de entrenamiento, incluyendo imágenes originales no aumentadas y aumentadas, y para pasar varias veces/pasos a través de él (actualmente parece que hace solo un paso por epoch).

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    El ImageDataGenerator() en Keras es una herramienta poderosa para aumentar datos que te permite generar un gran número de ejemplos de entrenamiento mediante la ampliación de las imágenes existentes. En este caso, has creado una instancia de ImageDataGenerator() con un conjunto de métodos de aumento predefinidos como reescalado, volteado horizontal, cambio de altura y anchura, intervalo de brillo y intervalo de zoom. Entonces, has utilizado el método flow_from_directory() para cargar las imágenes por lotes desde el directorio, y el objeto devuelto train_aug_ds es en realidad una instancia de la clase DirectoryIterator, que es un generador de Python.

    La función fit() en Keras espera una instancia de la clase Sequence o Iterator para pasar los datos al modelo, y la clase DirectoryIterator es una subclase de la clase Sequence, por lo que puedes usarla directamente para entrenar el modelo. Pero debes saber que este generador genera un flujo infinito de datos, por lo que debes establecer el número apropiado de pasos por época, es decir, el número de lotes que deben procesarse por época. Por defecto, el número de pasos se calcula como el número total de imágenes dividido por el tamaño del lote. Pero cuando se pasa un conjunto de datos que se repite infinitamente, debes especificar explícitamente el argumento steps_per_epoch.

    Para responder tu pregunta, el método flow_from_directory() utiliza por defecto tanto las imágenes originales sin aumentar como las imágenes aumentadas en cada lote, con una mezcla de aproximadamente 80:20 (original: aumentada) de imágenes. Esto se conoce como mezcla de aumento y ayuda a evitar el sobreajuste y mejorar la generalización. Pero puedes controlar la proporción de mezcla estableciendo los argumentos save_to_dir y save_prefix en el método flow_from_directory() para guardar las imágenes aumentadas en el disco y generar manualmente lotes utilizando los datos y etiquetas de imagen reales.

    Para repetir los datos varias veces/pasos, puedes establecer el argumento steps_per_epoch en un valor mayor que 1, lo que repetirá los datos el número especificado de veces por época. Por ejemplo:

    model.fit(
        train_aug_ds,
        epochs=150,
        steps_per_epoch=100, # establecido en un valor mayor que 1
        validation_data=(valid_aug_ds,),
    )
    

    Esto repetirá todos los datos de entrenamiento (originales y aumentados) 100 veces por época (un total de 100*lotes_por_epoca imágenes).

Comments are closed.