es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Reemplace BatchNorm por Conditional BatchNorm en torchvision ResNet.

Tengo una instancia de torchvision.models.ResNet y tengo mi clase CondBatchNorm2d que es un módulo similar a BatchNorm2d, pero el método forward acepta una entrada adicional y que no proviene de la capa anterior ya que es una entrada de toda la red:

def forward(self, x, y=None):
...

Sé cómo substituir cada instancia de BatchNorm2d con una instancia de CondBatchNorm2d, pero no estoy seguro de cómo escribir mi propio método forward para incluir la nueva entrada en las capas intermedias de CondBatchNorm2d. ¿Debo iterar en el método forward en los hijos de ResNet o hay una forma más adecuada de hacerlo?

Tags:  , ,

Answer

  1. Avatar for davy.ai

    Para incluir la nueva entrada y para las capas intermedias CondBatchNorm2d en la instancia torchvision.models.ResNet, puedes usar una combinación de iteración y recursión.

    Primero, puedes iterar sobre los hijos de la instancia ResNet utilizando un bucle for. Luego, puedes comprobar si cada hijo es una instancia de nn.Sequential o nn.Module. Si es una instancia de nn.Sequential, puedes llamar recursivamente a la misma función en él. Si es una instancia de nn.Module, puedes comprobar si es una instancia de nn.BatchNorm2d. Si lo es, puedes reemplazarlo por una instancia de CondBatchNorm2d y pasar la entrada adicional y a su método forward. De lo contrario, puedes llamar recursivamente a la misma función en él.

    Aquí tienes un ejemplo de implementación:

    def replace_bn_with_condbn(model, y=None):
        for name, child in model.named_children():
            if isinstance(child, nn.Sequential):
                replace_bn_with_condbn(child, y)
            elif isinstance(child, nn.Module):
                if isinstance(child, nn.BatchNorm2d):
                    setattr(model, name, CondBatchNorm2d(child.num_features, y))
                else:
                    replace_bn_with_condbn(child, y)
    

    Puedes llamar a esta función en tu instancia de ResNet con la entrada adicional y de la siguiente manera:

    resnet = torchvision.models.ResNet(...)
    replace_bn_with_condbn(resnet, y)
    

    Esto reemplazará todas las instancias de BatchNorm2d por CondBatchNorm2d y pasará la entrada adicional y a sus métodos forward.

Comments are closed.