es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

¿Cuál es la alternativa de elementwise_grad de autograd en JAX?

Quiero resolver una ecuación diferencial de segundo orden con una red neuronal. Para la diferenciación automática estoy utilizando la biblioteca JAX. Para calcular la primera y segunda derivada de mi variable objetivo ‘u’, es decir, calcular du/dx y d2u/dx2, se ha utilizado elementwise_grad en un ejemplo. ¿Cuál es su alternativa en JAX?

Por ejemplo, la función de la red neuronal está evaluando ‘u’, que se define de la siguiente manera:

def u(params, inputs):
    for Weights, biases in params:
        outputs = np.dot(inputs, Weights) + biases
        inputs = sigmoid(outputs)    
    return outputs

u tiene dos argumentos: params es el conjunto de pesos y sesgos, e inputs es el rango x con respecto al cual se diferenciará ‘u’.

Supongamos que x tiene una longitud de 50, por lo que el tamaño de la salida u también será 50*1.

Ahora tengo que tomar la diferenciación de los 50 valores de u al mismo tiempo. ¿Qué funciones debo usar en JAX para calcular du/dx y d2u/dx2? La función grad no está funcionando.

dudx = grad(u,1)(x)
d2udx2 = grad(grad(u,1)(x))(x) 

Estas están dando algunos errores.

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Para calcular las derivadas de primer y segundo orden de la variable objetivo ‘u’ elemento por elemento en JAX, podemos utilizar las funciones jax.jacobian y jax.hessian respectivamente.

    Entonces, el código alternativo sería:

    dudx = jax.jacobian(u, argnums=1)(params, inputs)
    d2udx2 = jax.hessian(u, argnums=1)(params, inputs)
    

    Aquí, argnums=1 especifica que queremos tomar la derivada con respecto al segundo argumento de la función u, es decir, inputs.

    Con jax.jacobian, podemos obtener la matriz jacobiana de forma (50,50), que contiene las derivadas de primer orden para los 50 valores de u de una vez.

    De manera similar, jax.hessian calcula la matriz Hessiana de forma (50,50), que contiene las derivadas de segundo orden para los 50 valores de u de una vez.

    Utilizando estas funciones, podemos evitar los errores encontrados al utilizar la función grad.

Comments are closed.