Obtener entrada de la capa completamente conectada del modelo ResNet durante el tiempo de ejecución.
Encontré una solución, la dejé como respuesta a esta pregunta a continuación.
Información sobre el proyecto: Tarea de clasificación con 2 clases.
Estoy tratando de obtener la salida de la capa completamente conectada de mi modelo para cada imagen que ingreso en el modelo durante la ejecución. Planeo usarlos después de que el modelo haya terminado de entrenar o probar todas las imágenes para visualizar con UMAP.
El modelo:
# Cargar resnet
def obtener_modelo():
modelo = torchvision.models.resnet50(pretrained=True)
num_ftrs = modelo.fc.in_features
modelo.fc = nn.Linear(num_ftrs, 2)
return modelo
La parte relevante del módulo pl:
class ModeloClasificador(pl.LightningModule):
def __init__(self, modelo):
super().__init__()
self.modelo = modelo
self.tasa_aprendizaje = 0.0001
def paso_entrenamiento(self, lote, lote_idx):
x = lote['imagen']
y = lote['objetivos']
x_hat = self.modelo(x)
salida = nn.CrossEntropyLoss()
perdida = salida(x_hat, y)
return perdida
def paso_prueba(self, lote, lote_idx):
x = lote['imagen']
y = lote['objetivos']
x_hat = self.modelo(x)
¿Es posible hacer esto agregando una lista vacía a la inicialización del módulo pl y luego agregar la salida después de que se ejecute x_hat = modelo(x)
?
¿Cómo sabría si después de que se ejecute x_hat = modelo(x)
, los out_features no se eliminan/descartan inmediatamente?
davy.ai
Sí, es posible agregar una lista vacía en el método init del módulo pl y luego agregar la salida después de que se ejecute
x_hat = model(x)
. Aquí tienes un ejemplo de cómo se puede hacer:__init__
declassifierModel
:test_step
, agrega la salida de la capa totalmente conectada a la listafc_outputs
:fc_outputs
contendrá la salida de la capa totalmente conectada para cada imagen. Luego puedes usar esta lista con UMAP para visualizar los datos.En cuanto a tu segunda pregunta, la salida de la capa totalmente conectada no debería ser eliminada/descartada inmediatamente después de que se ejecute
x_hat = model(x)
, siempre y cuando la agregues a una lista o la utilices de alguna otra manera. Sin embargo, si simplemente la imprimes o la asignas a una variable sin utilizarla posteriormente, puede ser recolectada por el recolector de basura y eliminada.