es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Utilizar el modo de gráfico en Ray RLlib provoca errores cuando se llama a la función tf.keras.model.predict() en PPOTFPolicy.

Estoy utilizando Ray RLlib para entrenar un agente PPO con DOS modificaciones en el PPOTFPolicy.

  • Añadí una clase mixin (llamada “Recal”) al parámetro “mixins” en “build_tf_policy()”. De esta manera, el PPOTFPolicy heredaría de mi clase “Recal” y tendría acceso a las funciones miembro que definí en “Recal”. Mi clase “Recal” es una simple subclase de tf.keras.Model.

  • Definí una función “my_postprocess_fn” para reemplazar la función “compute_gae_for_sample_batch” que se proporciona al parámetro “postprocess_fn” en “build_tf_policy()”.

La función “PPOTrainer=build_trainer(…)” se mantiene sin cambios. Utilizo el framework=”tf” y desactivo el eager mode.

El código pseudo está abajo. Aquí hay una versión en ejecución en Colab.

tf.compat.v1.disable_eager_execution()

class Recal:
    def __init__(self):
        self.recal_model = build_and_compile_keras_model()

def my_postprocess_fn(policy, sample_batch):
    with policy.model.graph.as_default():
        sample_batch = policy.recal_model.predict(sample_batch)
    return compute_gae_for_sample_batch(policy, sample_batch)

PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")

for i in range(1):
    result = ppo_trainer.train()

De esta manera, la clase “Recal” es una clase base de PPOTFPolicy, y cuando se crea una instancia del PPOTFPolicy, “Recal” se instancia dentro del mismo grafo de TensorFlow. Pero cuando se llama a my_postprocess_fn(), se produce un error (ver abajo).

tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
[[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
Tags:  , , ,

Answer

  1. Avatar for davy.ai

    El mensaje de error sugiere que la variable default_policy_wk1/my_model/dense/kernel no se encuentra en el contenedor localhost. Este error se produce porque el grafo predeterminado de TensorFlow no se establece explícitamente en my_postprocess_fn().

    Para solucionar este error, podemos modificar my_postprocess_fn() para establecer el grafo predeterminado en el grafo de la instancia de PPOTFPolicy. Aquí está la versión modificada del código para my_postprocess_fn():

    def my_postprocess_fn(policy, sample_batch, other_agent=None, episode=None):
        with policy.model._build_thread_local:
            with tf.compat.v1.Graph().as_default():
                # establecer el grafo predeterminado en el grafo de la política
                tf.compat.v1.keras.backend.set_session(policy.sess)
                # recrear el modelo de política en el grafo predeterminado
                policy._initialize(policy.config["model"])
                # predecir utilizando el modelo recalculado
                recal_output = policy.model.recal_model.predict(sample_batch)
                # crear un nuevo lote para calcular GAE
                batch = policy._lazy_tensor_dict({
                    SampleBatch.CUR_OBS: sample_batch[SampleBatch.CUR_OBS],
                    SampleBatch.ACTIONS: sample_batch[SampleBatch.ACTIONS],
                    SampleBatch.REWARDS: recal_output,
                    SampleBatch.NEXT_OBS: sample_batch[SampleBatch.NEXT_OBS],
                    SampleBatch.DONES: sample_batch[SampleBatch.DONES],
                    SampleBatch.VF_PREDS: sample_batch[SampleBatch.VF_PREDS],
                    BEHAVIOUR_LOGITS: sample_batch[BEHAVIOUR_LOGITS],
                })
                # calcular GAE para el nuevo lote
                return compute_advantages(policy, batch, None)
    

    En esta versión modificada, establecemos explícitamente el grafo predeterminado utilizando tf.compat.v1.Graph().as_default() y luego establecemos la sesión en la sesión de la política utilizando tf.compat.v1.keras.backend.set_session(policy.sess). Luego recreamos el modelo de política en el grafo predeterminado y usamos el modelo recalculado para predecir las recompensas. Finalmente, creamos un nuevo lote utilizando las recompensas predichas y calculamos las ventajas utilizando compute_advantages().

Comments are closed.