es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Simplifica einsum de PyTorch.

Considere el siguiente fragmento de código de PyTorch:

X = torch.einsum("rij, sij -> rs", A, A)
Y = torch.einsum("rij, sij -> rs", B, B)
Z = torch.einsum("rij, sij -> rs", C, C)
torch.einsum("ij, ij, ij -> ", X, Y, Z)

El cual realiza la siguiente suma:

\sum_{r,s=1}^\sigma\left(\sum_{i,j=1}^\delta A^r_{ij}A^s_{ij} \right)\left(\sum_{k,l=1}^\delta B^r_{kl}B^s_{kl} \right)\left(\sum_{m,n=1}^\delta C^r_{mn}C^s_{mn} \right)

¿Es posible formular esto de una manera más concisa (es decir, más vectorizada/optimal), utilizando el hecho de que X, Y, Z son matrices simétricas?

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Sí, es posible formular esto de una manera más concisa aprovechando el hecho de que X, Y y Z son matrices simétricas. Una forma de hacer esto es utilizar la forma vectorizada de la suma, que se puede escribir como:

    $\left(\textbf{A}\cdot\textbf{A}^T\right)\circ\left(\textbf{B}\cdot\textbf{B}^T\right)\circ\left(\textbf{C}\cdot\textbf{C}^T\right)$

    donde $\circ$ representa la multiplicación elemento por elemento y $\cdot$ representa la multiplicación de matrices.

    En código PyTorch, esto se puede escribir como:

    torch.einsum(“rij,rkl,rmn->”, A @ A.transpose(1, 2), B @ B.transpose(1, 2), C @ C.transpose(1, 2))

    Esta formulación utiliza solo tres multiplicaciones de matrices en lugar de seis, lo que la hace más eficiente.

Comments are closed.