¿Cuál es la alternativa de elementwise_grad de autograd en JAX?
Quiero resolver una ecuación diferencial de segundo orden con una red neuronal. Para la diferenciación automática estoy utilizando la biblioteca JAX. Para calcular la primera y segunda derivada de mi variable objetivo ‘u’, es decir, calcular du/dx y d2u/dx2, se ha utilizado elementwise_grad en un ejemplo. ¿Cuál es su alternativa en JAX?
Por ejemplo, la función de la red neuronal está evaluando ‘u’, que se define de la siguiente manera:
def u(params, inputs):
for Weights, biases in params:
outputs = np.dot(inputs, Weights) + biases
inputs = sigmoid(outputs)
return outputs
u tiene dos argumentos: params es el conjunto de pesos y sesgos, e inputs es el rango x con respecto al cual se diferenciará ‘u’.
Supongamos que x tiene una longitud de 50, por lo que el tamaño de la salida u también será 50*1.
Ahora tengo que tomar la diferenciación de los 50 valores de u al mismo tiempo. ¿Qué funciones debo usar en JAX para calcular du/dx y d2u/dx2? La función grad no está funcionando.
dudx = grad(u,1)(x)
d2udx2 = grad(grad(u,1)(x))(x)
Estas están dando algunos errores.
davy.ai
Para calcular las derivadas de primer y segundo orden de la variable objetivo ‘u’ elemento por elemento en JAX, podemos utilizar las funciones
jax.jacobian
yjax.hessian
respectivamente.Entonces, el código alternativo sería:
Aquí,
argnums=1
especifica que queremos tomar la derivada con respecto al segundo argumento de la funciónu
, es decir,inputs
.Con
jax.jacobian
, podemos obtener la matriz jacobiana de forma(50,50)
, que contiene las derivadas de primer orden para los 50 valores deu
de una vez.De manera similar,
jax.hessian
calcula la matriz Hessiana de forma(50,50)
, que contiene las derivadas de segundo orden para los 50 valores deu
de una vez.Utilizando estas funciones, podemos evitar los errores encontrados al utilizar la función
grad
.