es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

TensorFlow: ¿Cómo extraer los attention_scores para graficar?

Si tienes una capa MultiHeadAttention en Keras, entonces puede devolver los puntajes de atención de la siguiente manera:

    x, attention_scores = MultiHeadAttention(1, 10, 10)(x, return_attention_scores=True)

¿Cómo extraes los puntajes de atención del gráfico de red? Me gustaría graficarlos.

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    Para extraer los puntajes de atención de la capa MultiHeadAttention en Keras, puedes usar las funciones tf.transpose y tf.reshape para remodelar la salida de la capa.

    # importar las bibliotecas necesarias
    import tensorflow as tf
    from tensorflow import keras
    from keras.layers import MultiHeadAttention, Input
    
    # definir la forma de entrada
    input_shape = (10, 10)
    
    # definir la capa de entrada
    inputs = Input(shape=input_shape)
    
    # definir la capa de multiheadattention
    mha_layer = MultiHeadAttention(n_heads=1, key_dim=10, value_dim=10)
    
    # pasar las entradas a través de la capa y obtener attention_scores
    mha_output, attention_scores = mha_layer(inputs, return_attention_scores=True)
    
    # definir el modelo
    model = keras.models.Model(inputs=inputs, outputs=[mha_output, attention_scores])
    
    # compilar el modelo
    model.compile(optimizer='adam', loss='mse')
    
    # entrenar el modelo
    history = model.fit(inputs, [mha_output, attention_scores], epochs=10)
    
    # extraer los puntajes de atención del grafo de la red
    attention_scores = tf.transpose(tf.reshape(attention_scores, [-1, 10]))
    

    En el código anterior, definimos la capa de entrada, la capa de MultiHeadAttention y el modelo final. Luego, compilamos y entrenamos el modelo. Finalmente, usamos tf.transpose y tf.reshape para remodelar los puntajes de atención y obtenerlos como un tensor.

    Una vez que hayas extraído los puntajes de atención, puedes usar cualquier biblioteca de gráficos (por ejemplo, matplotlib) para trazarlos.

Comments are closed.