¿Cómo puedo concatenar tensores o listas de PyTorch en una configuración distribuida de múltiples nodos?
Estoy tratando de implementar algo como esto para 2 nodos (cada nodo con 2 GPUs):
Proceso paralelo iniciado con torch.distributed.init_process_group ()
Todas las GPUs trabajan en paralelo y generan listas como:
[20, 0, 1, 17] para GPU0 del nodo A
[1, 2, 3, 4] para GPU1 del nodo A
[5, 6, 7, 8] para GPU0 del nodo B
[0, 2, 4, 6] para GPU1 del nodo B
Intenté
torch.distributed.reduce()
para obtener una suma de estos 4:
[26, 10, 15, 35]
Pero lo que quiero es una versión concatenada como esta
[[20, 0, 1, 17], [1, 2, 3, 4], [5, 6, 7, 8], [0, 2, 4, 6]]
O
[20, 0, 1, 17, 1, 2, 3, 4, 5, 6, 7, 8, 0, 2, 4, 6]
también está bien para mí.
¿Es posible lograr esto con torch.distributed?
davy.ai
Desafortunadamente, torch.distributed.reduce() no es adecuado para concatenar estas listas. En su lugar, puede usar torch.distributed.gather() para recopilar las listas de todos los nodos, luego concatenarlas manualmente utilizando las funciones de concatenación incorporadas de Python. Aquí hay una implementación de ejemplo:
En este ejemplo, usamos torch.distributed.gather() para recopilar todas las listas en el nodo raíz (rango 0), luego aplanamos la lista de listas en una sola lista utilizando la función sum() de Python. La variable concatenated_lists resultante contiene la versión concatenada deseada de las listas originales.