es.davy.ai

Preguntas y respuestas de programación confiables

¿Tienes una pregunta?

Si tienes alguna pregunta, puedes hacerla a continuación o ingresar lo que estás buscando.

Arreglar un parámetro en un estimador de scikit-learn.

Necesito arreglar el valor de un parámetro de un estimador de scikit-learn. Todavía necesito poder cambiar todos los demás parámetros del estimador, y utilizar el estimador dentro de herramientas de scikit-learn como Pipelines y GridSearchCV.

Intenté definir una nueva clase heredando de un estimador de scikit-learn. Por ejemplo, aquí estoy intentando crear una nueva clase que arregle n_estimators=5 de un RandomForestClassifier.

class FiveTreesClassifier(RandomForestClassifier):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_estimators = 5

fivetrees = FiveTreesClassifier()
randomforest = RandomForestClassifier(n_estimators=5)

# Esto pasa.
assert fivetrees.n_estimators == randomforest.n_estimators

# Esto falla: los parámetros de fivetrees son un diccionario vacío.
assert fivetrees.get_params() == randomforest.get_params()

El hecho de que get_params() no sea confiable significa que no puedo usar el nuevo estimador dentro de Pipelines y GridSearchCV (como se explica aquí).

Estoy utilizando scikit-learn 0.24.2, pero creo que en realidad sería lo mismo con versiones más nuevas.

Preferiría respuestas que me permitan definir una nueva clase mientras arreglo el valor de un hiperparámetro. También aceptaría respuestas que utilicen otras técnicas. También agradecería explicaciones detalladas de por qué debería / no debería hacer esto.

Tags:  , , ,

Answer

  1. Avatar for davy.ai

    Un enfoque para fijar el valor de un parámetro de un estimator de scikit-learn y al mismo tiempo seguir siendo capaz de cambiar todos los demás parámetros y usar el estimator dentro de Pipelines y GridSearchCV es usar la función partial del módulo functools en Python.

    Aquí hay un ejemplo usando partial para fijar n_estimators=5 de un RandomForestClassifier:

    from functools import partial
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.model_selection import GridSearchCV
    
    # Establecer n_estimators=5 como valor predeterminado utilizando functools.partial
    RandomForestClassifier = partial(RandomForestClassifier, n_estimators=5)
    
    # Definir un Pipeline con el RandomForestClassifier modificado
    pipeline = Pipeline([
      ('clf', RandomForestClassifier())
    ])
    
    # Definir una cuadrícula de parámetros para GridSearchCV
    param_grid = {
      'clf__max_depth': [2, 4, 6, 8]
    }
    
    # Crear y ajustar un objeto GridSearchCV con el estimator modificado
    grid_search = GridSearchCV(pipeline, param_grid=param_grid, cv=5)
    grid_search.fit(X, y)
    

    En este ejemplo, partial se utiliza para crear una nueva función que se comporta exactamente como RandomForestClassifier, pero con n_estimators fijo en 5. Esta nueva función luego se utiliza para definir un Pipeline y para crear un objeto GridSearchCV. El resto del código es igual que el habitual cuando se trabaja con Pipelines y GridSearchCV.

    Con este enfoque, no estamos creando una nueva clase, lo que tiene algunas ventajas: no tenemos que preocuparnos por el método get_params(), podemos usar cualquier estimator de scikit-learn como base sin tener que crear una subclase para cada uno, y podemos fijar fácilmente cualquier hiperparámetro de un estimator de manera similar.

    Sin embargo, una desventaja de este enfoque es que el estimator modificado no es una subclase del estimator original (aunque todavía es una instancia de la misma clase), por lo que algunos métodos de scikit-learn que dependen de la herencia pueden comportarse de manera diferente (aunque esto es poco probable que sea un problema en la mayoría de los casos).

Comments are closed.