es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

¿Cómo puedo concatenar tensores o listas de PyTorch en una configuración distribuida de múltiples nodos?

Estoy tratando de implementar algo como esto para 2 nodos (cada nodo con 2 GPUs):

Proceso paralelo iniciado con torch.distributed.init_process_group ()

Todas las GPUs trabajan en paralelo y generan listas como:

[20, 0, 1, 17] para GPU0 del nodo A
[1, 2, 3, 4] para GPU1 del nodo A
[5, 6, 7, 8] para GPU0 del nodo B
[0, 2, 4, 6] para GPU1 del nodo B

Intenté

torch.distributed.reduce()

para obtener una suma de estos 4:

[26, 10, 15, 35]

Pero lo que quiero es una versión concatenada como esta

[[20, 0, 1, 17], [1, 2, 3, 4], [5, 6, 7, 8], [0, 2, 4, 6]]

O

[20, 0, 1, 17, 1, 2, 3, 4, 5, 6, 7, 8, 0, 2, 4, 6]

también está bien para mí.

¿Es posible lograr esto con torch.distributed?

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    Desafortunadamente, torch.distributed.reduce() no es adecuado para concatenar estas listas. En su lugar, puede usar torch.distributed.gather() para recopilar las listas de todos los nodos, luego concatenarlas manualmente utilizando las funciones de concatenación incorporadas de Python. Aquí hay una implementación de ejemplo:

    import torch.distributed as dist
    
    # Cada nodo genera una lista de valores
    local_list = [20, 0, 1, 17] if dist.get_rank() == 0 else [1, 2, 3, 4]
    local_list += [5, 6, 7, 8] if dist.get_rank() == 1 else [0, 2, 4, 6]
    
    # Recopila todas las listas en el nodo raíz (rango 0)
    all_lists = [torch.zeros_like(torch.tensor(local_list)) for i in range(dist.get_world_size())]
    dist.gather(tensor=torch.tensor(local_list), gather_list=all_lists, dst=0)
    
    # Concatena las listas en el nodo raíz
    if dist.get_rank() == 0:
        concatenated_lists = []
        for l in all_lists:
            concatenated_lists.append(l.tolist())
        concatenated_lists = sum(concatenated_lists, [])  # Aplana la lista de listas en una sola lista
        print(concatenated_lists)
    

    En este ejemplo, usamos torch.distributed.gather() para recopilar todas las listas en el nodo raíz (rango 0), luego aplanamos la lista de listas en una sola lista utilizando la función sum() de Python. La variable concatenated_lists resultante contiene la versión concatenada deseada de las listas originales.

Comments are closed.