Arreglar un parámetro en un estimador de scikit-learn.
Necesito arreglar el valor de un parámetro de un estimador de scikit-learn. Todavía necesito poder cambiar todos los demás parámetros del estimador, y utilizar el estimador dentro de herramientas de scikit-learn como Pipelines y GridSearchCV.
Intenté definir una nueva clase heredando de un estimador de scikit-learn. Por ejemplo, aquí estoy intentando crear una nueva clase que arregle n_estimators=5
de un RandomForestClassifier.
class FiveTreesClassifier(RandomForestClassifier):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.n_estimators = 5
fivetrees = FiveTreesClassifier()
randomforest = RandomForestClassifier(n_estimators=5)
# Esto pasa.
assert fivetrees.n_estimators == randomforest.n_estimators
# Esto falla: los parámetros de fivetrees son un diccionario vacío.
assert fivetrees.get_params() == randomforest.get_params()
El hecho de que get_params()
no sea confiable significa que no puedo usar el nuevo estimador dentro de Pipelines y GridSearchCV (como se explica aquí).
Estoy utilizando scikit-learn 0.24.2, pero creo que en realidad sería lo mismo con versiones más nuevas.
Preferiría respuestas que me permitan definir una nueva clase mientras arreglo el valor de un hiperparámetro. También aceptaría respuestas que utilicen otras técnicas. También agradecería explicaciones detalladas de por qué debería / no debería hacer esto.
davy.ai
Un enfoque para fijar el valor de un parámetro de un estimator de scikit-learn y al mismo tiempo seguir siendo capaz de cambiar todos los demás parámetros y usar el estimator dentro de Pipelines y GridSearchCV es usar la función
partial
del módulofunctools
en Python.Aquí hay un ejemplo usando
partial
para fijarn_estimators=5
de unRandomForestClassifier
:En este ejemplo,
partial
se utiliza para crear una nueva función que se comporta exactamente comoRandomForestClassifier
, pero conn_estimators
fijo en 5. Esta nueva función luego se utiliza para definir un Pipeline y para crear un objeto GridSearchCV. El resto del código es igual que el habitual cuando se trabaja con Pipelines y GridSearchCV.Con este enfoque, no estamos creando una nueva clase, lo que tiene algunas ventajas: no tenemos que preocuparnos por el método
get_params()
, podemos usar cualquier estimator de scikit-learn como base sin tener que crear una subclase para cada uno, y podemos fijar fácilmente cualquier hiperparámetro de un estimator de manera similar.Sin embargo, una desventaja de este enfoque es que el estimator modificado no es una subclase del estimator original (aunque todavía es una instancia de la misma clase), por lo que algunos métodos de scikit-learn que dependen de la herencia pueden comportarse de manera diferente (aunque esto es poco probable que sea un problema en la mayoría de los casos).