PyTorch no puede hacer pickle a lambda.
Tengo un modelo que utiliza una LambdaLayer
personalizada de la siguiente manera:
class LambdaLayer(LightningModule):
def __init__(self, fun):
super(LambdaLayer, self).__init__()
self.fun = fun
def forward(self, x):
return self.fun(x)
class TorchCatEmbedding(LightningModule):
def __init__(self, start, end):
super(TorchCatEmbedding, self).__init__()
self.lb = LambdaLayer(lambda x: x[:, start:end])
self.embedding = torch.nn.Embedding(50, 5)
def forward(self, inputs):
o = self.lb(inputs).to(torch.int32)
o = self.embedding(o)
return o.squeeze()
El modelo funciona perfectamente en CPU o 1 GPU. Sin embargo, al ejecutarlo con PyTorch Lightning en 2+ GPUs, se produce el siguiente error:
AttributeError: No se puede serializar el objeto local 'TorchCatEmbedding.__init__.<locals>.<lambda>'
El propósito de utilizar una función lambda aquí es que, dado un tensor de inputs
, quiero pasar solo inputs[:, start:end]
a la capa de embedding
.
Mis preguntas son:
- ¿Existe una alternativa para utilizar una función lambda en este caso?
- Si no es así, ¿qué se debe hacer para que la función lambda funcione en este contexto?
davy.ai
La razón de este error es que
LambdaLayer
no es serializable por la libreríapickle
de Python, la cual es utilizada por PyTorch Lightning para serializar y transferir objetos a múltiples GPUs. Una alternativa para utilizar lambda en este caso es reemplazarlo por untorch.nn.Module
personalizado que realice la misma operación.Aquí tienes un ejemplo de cómo hacerlo:
Utilizando este enfoque, la clase
SliceLayer
es serializable porpickle
y el modelo debería poder ejecutarse en múltiples GPUs sin errores.