es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Python y pytorch, funciones de pérdida para segmentación sin clases OOP.

Comencé a estudiar el aprendizaje automático. Necesito escribir 3 funciones de pérdida para la NeuroNet de segmentación de PyTorch. BCEloss, DiceLoss y FocalLoss, sin utilizar nn.Modules.
Escribí 2 funciones, pero dudo si están correctas o no. Y no sé cómo escribir una función para FocalLoss en absoluto.

def dice_loss(y_pred, y_real):
    eps = 1e-8
    y_pred = torch.sigmoid(y_pred)    
    num = 2.*torch.sum((y_real*y_pred), (1,2,3)) + eps
    den = torch.sum( y_pred, (1,2,3)) + torch.sum(y_real, (1,2,3)) + eps 
    res = 1 - torch.mean(num/den)
    return res

def bce_loss(y_pred, y_real):
    y_pred = F.sigmoid(y_pred)
    y_pred = y_pred.clamp(min=0)
    loss = (y_pred - y_pred * y_real + torch.log(1+torch.exp(-torch.abs(y_pred)))).mean()

    return loss
Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Funciones de Pérdida Propuestas para Redes Neuronales de Segmentación utilizando PyTorch

    Introducción

    Como experto en tecnología de la información, especializado en Python, PyTorch, Segmentación de Imágenes y Funciones de Pérdida, puedo proporcionar sugerencias para funciones de pérdida para Redes Neuronales de Segmentación en PyTorch. En este contexto, me gustaría proponer las siguientes 3 funciones de pérdida:

    1. BCEloss (Función de Pérdida de Entropía Cruzada Binaria): Esta función de pérdida compara la segmentación predicha con la segmentación real. La función de pérdida calcula la probabilidad de similitud entre los valores de píxeles de la segmentación predicha y la segmentación real. El valor de la función de pérdida varía de 0 a infinito, donde un valor más bajo indica un mejor ajuste.
    2. DiceLoss: Esta función de pérdida se calcula utilizando el coeficiente de Sorenson-Dice, que proporciona la similitud de los píxeles superpuestos entre la segmentación predicha y la segmentación real. Para esta función de pérdida, un valor más alto indica un mejor ajuste, donde el valor de la función de pérdida puede variar de 0 a 1.

    3. Focal Loss: Esta función de pérdida se utiliza para abordar el problema del desequilibrio de clases. Le da más énfasis a los ejemplos clasificados incorrectamente durante el entrenamiento. En esta función de pérdida, el enfoque está en los píxeles clasificados incorrectamente asignando más peso a los ejemplos más difíciles.

    Funciones de Pérdida

    A continuación se presentan las funciones sugeridas para el problema planteado:

    # Pérdida de Dice
    def dice_loss(y_pred, y_real):
        eps = 1e-8
        y_pred = torch.sigmoid(y_pred)    
        num = 2.*torch.sum((y_real*y_pred), (1,2,3)) + eps
        den = torch.sum( y_pred, (1,2,3)) + torch.sum(y_real, (1,2,3)) + eps 
        res = 1 - torch.mean(num/den)
        return res
    
    # Pérdida de Entropía Cruzada Binaria
    def bce_loss(y_pred, y_real):
        y_pred = torch.sigmoid(y_pred)
        y_pred = y_pred.clamp(min=0)
        loss = (y_pred - y_pred * y_real + torch.log(1+torch.exp(-torch.abs(y_pred)))).mean()
        return loss
    
    # Pérdida Focal
    def focal_loss(y_pred, y_real, alpha = 0.25, gamma = 2):
      y_pred = torch.sigmoid(y_pred)
      ce_loss = F.binary_cross_entropy(y_pred, y_real, reduction='none')
      p_t = y_real*y_pred + (1-y_real)*(1-y_pred)
      alpha_t = alpha*y_real + (1-alpha)*(1-y_real)
      focal_weight = alpha_t*(1-p_t)**gamma
      loss = (focal_weight*ce_loss).mean()
      return loss
    
    

    Espero que esto ayude a lograr la salida deseada.

Comments are closed.