es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

sklearn.plot_tree cómo visualizar class_labels para una tarea de clasificación?

Recientemente he realizado mi código de prueba en un Árbol de Decisión. Funciona perfectamente bien, excepto por una cosa. El árbol que se está graficando no contiene los nombres de las clases en él. ¿Estoy haciendo algo mal?

Por favor, vea el código a continuación y una imagen del conjunto de datos.

Imagen 1

Imagen 2

#Importar Datos#

import pandas as pd

data_set = pd.read_excel(r”C:\Users\User\Desktop\Tree.xlsx”)

print(data_set.head())

#Definir Características y Objetivos de Entrenamiento#

features_names=[“Money”,”Debt”]
target_names=[“Mood1”, “Mood2”, “Mood3”]

features = data_set[features_names]
targets = data_set[target_names]

print(features)
print(targets)

#Definir Conjunto de Entrenamiento y Conjunto de Prueba#

train_features = features[:10]
train_targets = targets[:10]

test_features = features[10:]
test_targets = targets[10:]

print (train_features)
print (train_targets)

print(test_features)
print(test_targets)

#Estimación del Árbol#

from sklearn.tree import DecisionTreeRegressor

dt = DecisionTreeRegressor(max_depth = 3)
dt = dt.fit(train_features, train_targets)

print(dt.score(train_features, train_targets))
print(dt.score(test_features, test_targets))

#Graficar el Árbol#

from sklearn import tree
import matplotlib.pyplot as plt

tree.plot_tree(dt, feature_names=features_names, class_names=target_names, filled = True)
plt.show()

Tags:  , , , ,

Answer

  1. Avatar for davy.ai

    Para agregar nombres de clase al gráfico del árbol de decisión, debes asegurarte de que tus nombres de clase estén en el formato correcto y se pasen correctamente al parámetro ‘class_names’ de la función ‘plot_tree’. Según el código que proporcionaste, parece que estás pasando los nombres de clase correctos a la función ‘plot_tree’. Por lo tanto, es posible que el problema radique en la versión de la biblioteca ‘matplotlib’ que estás utilizando.

    A partir de la versión 3.3.0 de ‘matplotlib’, la función ‘plot_tree’ tiene un error que hace que no se muestren los nombres de clase en el gráfico. Para solucionar este problema, puedes degradar a una versión anterior de ‘matplotlib’ o agregar manualmente los nombres de clase al gráfico.

    Para agregar manualmente los nombres de clase al gráfico, puedes utilizar la función ‘text’ de la biblioteca ‘matplotlib’ para agregar los nombres de clase como anotaciones al gráfico. Aquí tienes un ejemplo de cómo hacerlo:

    plt.figure(figsize=(10, 10))
    tree.plot_tree(dt, feature_names=features_names, class_names=target_names, filled=True)
    
    # Agregar nombres de clase como anotaciones al gráfico
    plt.text(0.5, 1.0, "Mood1", ha='center', va='center', color='white', fontsize=14, bbox=dict(facecolor='blue', edgecolor='blue', boxstyle='round'))
    plt.text(2.0, 1.0, "Mood2", ha='center', va='center', color='white', fontsize=14, bbox=dict(facecolor='orange', edgecolor='orange', boxstyle='round'))
    plt.text(3.5, 1.0, "Mood3", ha='center', va='center', color='white', fontsize=14, bbox=dict(facecolor='green', edgecolor='green', boxstyle='round'))
    
    plt.show()
    

    Este código crea una nueva figura con un tamaño especificado, traza el árbol de decisión con ‘plot_tree’ y luego agrega los nombres de clase como anotaciones utilizando la función ‘text’. El parámetro ‘bbox’ de la función ‘text’ se utiliza para crear un cuadro redondeado alrededor del nombre de clase para que sea más visible.

    ¡Espero que esto te ayude a resolver tu problema!

Comments are closed.