es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

livelossplot keras plot por época

Estoy tratando de trazar una función de costo y una función de precisión en Keras.
Sin embargo, cuando quiero jugar con mis épocas, por ejemplo, 100 veces, obtendré 100 gráficos
en Python, que tengo que eliminar antes de que se presente el gráfico de la época n + 1.
Me gustaría tener una figura donde pueda ver todas las épocas de las 100 veces una vez.
mi código:

import pandas as pd
forecast_demo = pd.read_csv('forcastdemo.csv')
forecast_demo.head()

X_train = forecast_demo[['index', 'quarter']]
y_train = forecast_demo.revenue
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras import models

from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import Input, Dense, BatchNormalization
from IPython.core.display import Image

inputs = Input(shape=(2,))
x = BatchNormalization()(inputs)
x = Dense(512, activation='relu')(x)
x = Dense(128, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(32, activation='relu')(x)
outputs = Dense(1)(x)

model = Model(inputs, outputs)
model.summary()
keras.utils.plot_model(model, to_file='images/oefening2.png', show_shapes=True)
Image('images/oefening2.png')

model.compile(
    loss=keras.losses.mean_squared_error, # de keras.losses.mean_absolute_percentage_error, keras.losses.mean_absolute_error (para regresión)
    optimizer=keras.optimizers.Adam(),
    metrics=keras.metrics.mean_absolute_percentage_error # de keras.metrics.mean_absolute_error, keras.metrics.mean_squared_error
)

from livelossplot import PlotLossesKeras

Aquí obtendré 100 gráficos.

history = model.fit(X_train, y_train,
                    batch_size=20, 
                    epochs=100,
                    callbacks=[PlotLossesKeras()],
                    verbose=False)

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    Para graficar la función de costo y la función de precisión para todas las 100 épocas en una sola figura, puedes modificar la función de retroalimentación PlotLossesKeras() utilizada en el método model.fit(). En lugar de crear un nuevo gráfico para cada época, puedes actualizar el gráfico existente con los nuevos valores para cada época.

    Para hacer esto, puedes crear una instancia de la función de retroalimentación PlotLossesKeras() fuera del método model.fit(), y pasar esta instancia como una retroalimentación al método model.fit(). Luego, puedes pasar el parámetro logs al método on_epoch_end() de la instancia de PlotLossesKeras(), lo cual actualizará el gráfico existente con los nuevos valores para cada época.

    Aquí tienes un ejemplo:

    import pandas as pd
    import tensorflow as tf
    from tensorflow.keras import Model
    from tensorflow.keras.layers import Input, Dense, BatchNormalization
    from tensorflow.keras.utils import plot_model
    from livelossplot import PlotLossesKeras
    
    # Cargar datos
    forecast_demo = pd.read_csv('forcastdemo.csv')
    X_train = forecast_demo[['index', 'quarter']]
    y_train = forecast_demo.revenue
    
    # Definir modelo
    inputs = Input(shape=(2,))
    x = BatchNormalization()(inputs)
    x = Dense(512, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    x = Dense(64, activation='relu')(x)
    x = Dense(32, activation='relu')(x)
    outputs = Dense(1)(x)
    
    model = Model(inputs, outputs)
    model.compile(
        loss=tf.keras.losses.mean_squared_error,
        optimizer=tf.keras.optimizers.Adam(),
        metrics=[tf.keras.metrics.mean_absolute_percentage_error]
    )
    
    # Graficar pérdidas
    plot_losses = PlotLossesKeras()
    
    # Entrenar modelo
    history = model.fit(
        X_train,
        y_train,
        batch_size=20,
        epochs=100,
        verbose=False,
        callbacks=[plot_losses],
    )
    
    # Mostrar gráfico final
    plot_losses.draw()
    

    En este ejemplo, creamos una instancia de la función de retroalimentación PlotLossesKeras() y la pasamos como una retroalimentación al método model.fit(). Luego, dentro del método model.fit(), entrenamos el modelo durante 100 épocas y pasamos el parámetro logs al método on_epoch_end() de la instancia plot_losses. Esto actualiza el gráfico existente con los nuevos valores para cada época.

    Finalmente, llamamos al método draw() de la instancia plot_losses para mostrar el gráfico final. Este gráfico mostrará la función de costo y la función de precisión para todas las 100 épocas en una sola figura.

Comments are closed.