Utilizar el modo de gráfico en Ray RLlib provoca errores cuando se llama a la función tf.keras.model.predict() en PPOTFPolicy.
Estoy utilizando Ray RLlib para entrenar un agente PPO con DOS modificaciones en el PPOTFPolicy.
- Añadí una clase mixin (llamada “Recal”) al parámetro “mixins” en “build_tf_policy()”. De esta manera, el PPOTFPolicy heredaría de mi clase “Recal” y tendría acceso a las funciones miembro que definí en “Recal”. Mi clase “Recal” es una simple subclase de tf.keras.Model.
-
Definí una función “my_postprocess_fn” para reemplazar la función “compute_gae_for_sample_batch” que se proporciona al parámetro “postprocess_fn” en “build_tf_policy()”.
La función “PPOTrainer=build_trainer(…)” se mantiene sin cambios. Utilizo el framework=”tf” y desactivo el eager mode.
El código pseudo está abajo. Aquí hay una versión en ejecución en Colab.
tf.compat.v1.disable_eager_execution()
class Recal:
def __init__(self):
self.recal_model = build_and_compile_keras_model()
def my_postprocess_fn(policy, sample_batch):
with policy.model.graph.as_default():
sample_batch = policy.recal_model.predict(sample_batch)
return compute_gae_for_sample_batch(policy, sample_batch)
PPOTFPolicy = build_tf_policy(..., postprocess_fn=my_postprocess_fn, mixins=[..., Recal])
PPOTrainer = build_trainer(...)
ppo_trainer = PPOTrainer(config=DEFAULT_CONFIG, env="CartPole-v0")
for i in range(1):
result = ppo_trainer.train()
De esta manera, la clase “Recal” es una clase base de PPOTFPolicy, y cuando se crea una instancia del PPOTFPolicy, “Recal” se instancia dentro del mismo grafo de TensorFlow. Pero cuando se llama a my_postprocess_fn(), se produce un error (ver abajo).
tensorflow.python.framework.errors_impl.FailedPreconditionError: Could not find variable default_policy_wk1/my_model/dense/kernel. This could mean that the variable has been deleted. In TF1, it can also mean the variable is uninitialized. Debug info: container=localhost, status error message=Container localhost does not exist. (Could not find resource: localhost/default_policy_wk1/my_model/dense/kernel)
[[{{node default_policy_wk1/my_model_1/dense/MatMul/ReadVariableOp}}]]
davy.ai
El mensaje de error sugiere que la variable
default_policy_wk1/my_model/dense/kernel
no se encuentra en el contenedorlocalhost
. Este error se produce porque el grafo predeterminado de TensorFlow no se establece explícitamente enmy_postprocess_fn()
.Para solucionar este error, podemos modificar
my_postprocess_fn()
para establecer el grafo predeterminado en el grafo de la instancia dePPOTFPolicy
. Aquí está la versión modificada del código paramy_postprocess_fn()
:En esta versión modificada, establecemos explícitamente el grafo predeterminado utilizando
tf.compat.v1.Graph().as_default()
y luego establecemos la sesión en la sesión de la política utilizandotf.compat.v1.keras.backend.set_session(policy.sess)
. Luego recreamos el modelo de política en el grafo predeterminado y usamos el modelo recalculado para predecir las recompensas. Finalmente, creamos un nuevo lote utilizando las recompensas predichas y calculamos las ventajas utilizandocompute_advantages()
.