es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Definir una versión alternativa de una función de Python utilizando bibliotecas diferentes / variables globales.

Estoy utilizando tanto Python numpy como el reemplazo de numpy, jax.numpy, y a menudo necesito una función definida por el usuario para ambos. El problema es que jax.numpy es un reemplazo tan bueno que muy a menudo estoy escribiendo básicamente la misma función para ambas bibliotecas. Así que para dar un ejemplo muy simple, podría tener:

module_numpy.py:

import numpy as np

def get_loss(x):
    return np.square(x) + 1

module_jax.py:

import jax.numpy as jnp

def get_loss(x):
    return jnp.square(x) + 1

donde básicamente estoy escribiendo el mismo código dos veces, y el problema es que si hago un cambio en una de estas funciones, tengo que hacer el mismo cambio en ambas, y se vuelve difícil de mantener. Puede que vuelva a las funciones que he cambiado más tarde y no me dé cuenta de cuál es la más “correcta” o la más actualizada.

Pregunta: ¿Cómo puedo redefinir una biblioteca esencialmente reemplazando las variables globales bajo las cuales se define module_numpy.get_loss? Por ejemplo, he intentado:

module_jax_auto.py:

import jax.numpy as jnp
import module_numpy
import inspect

f = eval(inspect.source(module_numpy.get_loss).replace('np', 'jnp'))

pero Python se queja con un error de sintaxis. Aparentemente no puedo usar def dentro de eval, pero incluso si pudiera, no estoy seguro de que esta sea la mejor manera de solucionar el problema de mantener dos bibliotecas paralelas. Pero siempre quiero usar la versión de numpy para definir la versión de jax, y no al revés.

Vale la pena mencionar que hay algunos casos en los que necesito definir la versión de jax de la función de manera diferente, por lo que excluiré algunas funciones de ser “generadas automáticamente” a partir de la versión de numpy.

¿Cuál es el mejor enfoque aquí?

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Un enfoque sería definir las funciones en un módulo separado e importar ese módulo tanto en module_numpy como en module_jax. De esta manera, solo tienes que mantener las funciones en un solo lugar y cualquier cambio realizado en el módulo se verá reflejado tanto en numpy como en jax.numpy. Por ejemplo,

    module_functions.py:

    import numpy as np
    
    def get_loss(x):
        return np.square(x) + 1
    

    module_numpy.py:

    from module_functions import get_loss
    import numpy as np
    
    ...
    

    module_jax.py:

    from module_functions import get_loss
    import jax.numpy as jnp
    
    ...
    

    Este enfoque asegura que solo tengas que cambiar las funciones en module_functions.py y no en module_numpy.py y module_jax.py. También te permite excluir ciertas funciones del proceso de auto-generación si es necesario.

Comments are closed.