Definir una versión alternativa de una función de Python utilizando bibliotecas diferentes / variables globales.
Estoy utilizando tanto Python numpy
como el reemplazo de numpy
, jax.numpy
, y a menudo necesito una función definida por el usuario para ambos. El problema es que jax.numpy
es un reemplazo tan bueno que muy a menudo estoy escribiendo básicamente la misma función para ambas bibliotecas. Así que para dar un ejemplo muy simple, podría tener:
module_numpy.py
:
import numpy as np
def get_loss(x):
return np.square(x) + 1
module_jax.py
:
import jax.numpy as jnp
def get_loss(x):
return jnp.square(x) + 1
donde básicamente estoy escribiendo el mismo código dos veces, y el problema es que si hago un cambio en una de estas funciones, tengo que hacer el mismo cambio en ambas, y se vuelve difícil de mantener. Puede que vuelva a las funciones que he cambiado más tarde y no me dé cuenta de cuál es la más “correcta” o la más actualizada.
Pregunta: ¿Cómo puedo redefinir una biblioteca esencialmente reemplazando las variables globales bajo las cuales se define module_numpy.get_loss
? Por ejemplo, he intentado:
module_jax_auto.py
:
import jax.numpy as jnp
import module_numpy
import inspect
f = eval(inspect.source(module_numpy.get_loss).replace('np', 'jnp'))
pero Python se queja con un error de sintaxis. Aparentemente no puedo usar def
dentro de eval
, pero incluso si pudiera, no estoy seguro de que esta sea la mejor manera de solucionar el problema de mantener dos bibliotecas paralelas. Pero siempre quiero usar la versión de numpy
para definir la versión de jax
, y no al revés.
Vale la pena mencionar que hay algunos casos en los que necesito definir la versión de jax
de la función de manera diferente, por lo que excluiré algunas funciones de ser “generadas automáticamente” a partir de la versión de numpy
.
¿Cuál es el mejor enfoque aquí?
davy.ai
Un enfoque sería definir las funciones en un módulo separado e importar ese módulo tanto en
module_numpy
como enmodule_jax
. De esta manera, solo tienes que mantener las funciones en un solo lugar y cualquier cambio realizado en el módulo se verá reflejado tanto ennumpy
como enjax.numpy
. Por ejemplo,module_functions.py:
module_numpy.py:
module_jax.py:
Este enfoque asegura que solo tengas que cambiar las funciones en
module_functions.py
y no enmodule_numpy.py
ymodule_jax.py
. También te permite excluir ciertas funciones del proceso de auto-generación si es necesario.