Simplifica einsum de PyTorch.
Considere el siguiente fragmento de código de PyTorch:
X = torch.einsum("rij, sij -> rs", A, A)
Y = torch.einsum("rij, sij -> rs", B, B)
Z = torch.einsum("rij, sij -> rs", C, C)
torch.einsum("ij, ij, ij -> ", X, Y, Z)
El cual realiza la siguiente suma:
¿Es posible formular esto de una manera más concisa (es decir, más vectorizada/optimal), utilizando el hecho de que X, Y, Z son matrices simétricas?
davy.ai
Sí, es posible formular esto de una manera más concisa aprovechando el hecho de que X, Y y Z son matrices simétricas. Una forma de hacer esto es utilizar la forma vectorizada de la suma, que se puede escribir como:
$\left(\textbf{A}\cdot\textbf{A}^T\right)\circ\left(\textbf{B}\cdot\textbf{B}^T\right)\circ\left(\textbf{C}\cdot\textbf{C}^T\right)$
donde $\circ$ representa la multiplicación elemento por elemento y $\cdot$ representa la multiplicación de matrices.
En código PyTorch, esto se puede escribir como:
torch.einsum(“rij,rkl,rmn->”, A @ A.transpose(1, 2), B @ B.transpose(1, 2), C @ C.transpose(1, 2))
Esta formulación utiliza solo tres multiplicaciones de matrices en lugar de seis, lo que la hace más eficiente.