2023-10-05 16:53:51 +02:00
|
|
|
import inspect
|
2023-10-09 08:32:43 +02:00
|
|
|
|
2023-10-05 16:53:51 +02:00
|
|
|
from .base_estimator import BaseEstimator
|
|
|
|
|
|
|
|
|
|
|
|
def load_estimator(type, estimator):
    """Import ``<this package>.<type>.<estimator>`` and return its estimator class.

    The target module must define exactly one class that subclasses
    ``BaseEstimator``. Classes that the module merely imports from elsewhere
    are ignored, because only classes whose ``__module__`` equals the imported
    module's path are considered.

    Args:
        type: Name of the sub-package searched under this package (the second
            component of the dotted import path). NOTE: shadows the builtin
            ``type``; the name is kept so keyword callers keep working.
        estimator: Name of the module inside ``type`` to import.

    Returns:
        The single ``BaseEstimator`` subclass defined in that module.

    Raises:
        ImportError: (or ``ModuleNotFoundError``) if the module cannot be imported.
        AssertionError: if the module defines zero or more than one
            ``BaseEstimator`` subclass.
    """
    import importlib

    # Presumably this code lives in a package ``__init__`` so that ``__name__``
    # is the package name -- TODO confirm.
    module_path = f"{__name__}.{type}.{estimator}"
    # ``import_module`` returns the leaf module directly, replacing the
    # ``__import__(..., fromlist=[""])`` idiom.
    module = importlib.import_module(module_path)

    # Keep only classes defined in this very module (not ones it imported)
    # that inherit from BaseEstimator.
    candidates = [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if obj.__module__ == module_path and issubclass(obj, BaseEstimator)
    ]

    if len(candidates) != 1:
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility
        # with callers that relied on the previous assert.
        names = [cls.__name__ for cls in candidates]
        raise AssertionError(
            f"expected exactly one BaseEstimator subclass in {module_path!r}, "
            f"found {len(candidates)}: {names}"
        )

    return candidates[0]
|