es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Cómo utilizar LSTM de Keras en datos de entrada de alta dimensión (video)

Tengo datos x (muestra, fotogramas, tamañoX, tamañoY, canalRGB) de dimensiones (90, 10, 480, 640, 3) que representan un conjunto de datos de videos. Estoy intentando aplicar LSTM en Keras para clasificarlos. Según mi comprensión, en input_shape del LSTM, pasamos timesteps y features, que si entiendo correctamente, deberían ser timesteps=10 y features=3, es decir, 480, 640, 3.

Los datos de salida y tienen dimensiones (90, 1) que representan un valor de 1 o 0. A continuación, se muestra el modelo que estoy entrenando:

import tensorflow as tf
import numpy as np
from tensorflow import keras

model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(units=32, input_shape=(10, 3), activation="relu"))
model.add(tf.keras.layers.Dense(8, activation="relu"))
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))

model.compile(loss='binary_crossentropy', optimizer='rmsprop')

model.fit(x, y, epochs=500)

Pero está arrojando un error en model.fit:

Error al comprobar la entrada: se esperaba que lstm_16_input tuviera 3 dimensiones, pero se obtuvo una matriz de forma (90, 10, 480, 640, 3)
Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    El error sugiere que los datos de entrada x tienen la forma incorrecta. Actualmente, tienen 5 dimensiones (muestras, fotogramas, tamañoX, tamañoY, canalRGB). Sin embargo, la capa LSTM espera una entrada tridimensional de (muestras, pasos de tiempo, características).

    Para resolver el problema, debemos remodelar los datos de entrada x a una forma tridimensional de (90, 10, 480*640*3). Esto se puede lograr utilizando la función reshape() de la siguiente manera:

    x_reshape = x.reshape(90, 10, -1)
    

    La función reshape() calcula automáticamente la dimensión restante como -1 dados las otras dimensiones.

    Luego podemos utilizar x_reshape como los datos de entrada en la función fit() de la siguiente manera:

    model.fit(x_reshape, y, epochs=500)
    

    Esto debería resolver el error y permitir entrenar el modelo con los datos de entrada.

Comments are closed.