es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

PyTorch no puede hacer pickle a lambda.

Tengo un modelo que utiliza una LambdaLayer personalizada de la siguiente manera:

class LambdaLayer(LightningModule):
    def __init__(self, fun):
        super(LambdaLayer, self).__init__()
        self.fun = fun

    def forward(self, x):
        return self.fun(x)

class TorchCatEmbedding(LightningModule):
    def __init__(self, start, end):
        super(TorchCatEmbedding, self).__init__()
        self.lb = LambdaLayer(lambda x: x[:, start:end])
        self.embedding = torch.nn.Embedding(50, 5)

    def forward(self, inputs):
        o = self.lb(inputs).to(torch.int32)
        o = self.embedding(o)
        return o.squeeze()

El modelo funciona perfectamente en CPU o 1 GPU. Sin embargo, al ejecutarlo con PyTorch Lightning en 2+ GPUs, se produce el siguiente error:

AttributeError: No se puede serializar el objeto local 'TorchCatEmbedding.__init__.<locals>.<lambda>'

El propósito de utilizar una función lambda aquí es que, dado un tensor de inputs, quiero pasar solo inputs[:, start:end] a la capa de embedding.

Mis preguntas son:

  • ¿Existe una alternativa para utilizar una función lambda en este caso?
  • Si no es así, ¿qué se debe hacer para que la función lambda funcione en este contexto?
Tags:  , , ,

Answer

  1. Avatar for davy.ai

    La razón de este error es que LambdaLayer no es serializable por la librería pickle de Python, la cual es utilizada por PyTorch Lightning para serializar y transferir objetos a múltiples GPUs. Una alternativa para utilizar lambda en este caso es reemplazarlo por un torch.nn.Module personalizado que realice la misma operación.

    Aquí tienes un ejemplo de cómo hacerlo:

    class SliceLayer(torch.nn.Module):
        def __init__(self, start, end):
            super(SliceLayer, self).__init__()
            self.start = start
            self.end = end
    
        def forward(self, x):
            return x[:, self.start:self.end]
    
    class TorchCatEmbedding(LightningModule):
        def __init__(self, start, end):
            super(TorchCatEmbedding, self).__init__()
            self.lb = SliceLayer(start, end)
            self.embedding = torch.nn.Embedding(50, 5)
    
        def forward(self, inputs):
            o = self.lb(inputs).to(torch.int32)
            o = self.embedding(o)
            return o.squeeze()
    

    Utilizando este enfoque, la clase SliceLayer es serializable por pickle y el modelo debería poder ejecutarse en múltiples GPUs sin errores.

Comments are closed.