skpro.model_selection.GridSearchCV
- class skpro.model_selection.GridSearchCV(estimator, cv, param_grid, scoring=None, n_jobs=None, refit=True, verbose=0, return_n_best_estimators=1, pre_dispatch='2*n_jobs', backend='loky', error_score=nan)
Perform grid-search cross-validation to find optimal model parameters.
For each parameter combination in the grid, the estimator is fit on the training part and evaluated on the test part of every split produced by cv. The combination with the best mean test score is then selected.
Grid-search cross-validation is performed based on a cross-validation iterator encoding the cross-validation scheme, the parameter grid to search over, and (optionally) the evaluation metric for comparing model performance. As in scikit-learn, tuning works through the common hyper-parameter interface, which allows the same estimator to be fit and evaluated repeatedly with different hyper-parameters.
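The fit, evaluate, and select loop described above can be sketched in dependency-free Python. This is a minimal illustration under stated assumptions, not skpro's implementation: the toy estimator ShrunkMean, the helpers kfold_indices, mae and grid_search, and the use of mean absolute error in place of a probabilistic metric such as CRPS are all hypothetical stand-ins. Only the control flow (every candidate scored on every split, lowest mean score wins, optional refit on all data) is meant to mirror the description.

```python
from itertools import product
from statistics import mean


def kfold_indices(n_samples, n_splits):
    """Yield (train, test) index lists for unshuffled, contiguous K-fold splits."""
    # the first n_samples % n_splits folds get one extra sample
    sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
             for i in range(n_splits)]
    start = 0
    for size in sizes:
        test = list(range(start, start + size))
        train = [i for i in range(n_samples) if not start <= i < start + size]
        yield train, test
        start += size


class ShrunkMean:
    """Hypothetical toy estimator: predicts (1 - shrinkage) * mean(training labels)."""

    def __init__(self, shrinkage=0.0):
        self.shrinkage = shrinkage

    def fit(self, y):
        self.mean_ = (1 - self.shrinkage) * mean(y)
        return self

    def predict(self, n):
        return [self.mean_] * n


def mae(y_true, y_pred):
    """Mean absolute error; lower is better, as required of the scoring metric."""
    return mean(abs(a - b) for a, b in zip(y_true, y_pred))


def grid_search(y, param_grid, n_splits=3, refit=True):
    """Score every parameter combination on every split; the lowest mean score wins."""
    names = sorted(param_grid)
    results = []
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        fold_scores = []
        for train, test in kfold_indices(len(y), n_splits):
            est = ShrunkMean(**params).fit([y[i] for i in train])
            fold_scores.append(mae([y[i] for i in test], est.predict(len(test))))
        results.append((mean(fold_scores), params))
    best_score, best_params = min(results, key=lambda r: r[0])
    # refit=True analogue: refit on all of y with the best parameters
    best_estimator = ShrunkMean(**best_params).fit(y) if refit else None
    return best_params, best_score, best_estimator
```

For y = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] and param_grid = {"shrinkage": [0.0, 0.5, 1.0]}, this sketch selects shrinkage=0.0.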
- Parameters:
- estimator : estimator object
The estimator should implement the skpro estimator interface. Either the estimator must contain a “score” function, or a scoring function must be passed.
- cv : cross-validation generator or an iterable
e.g., KFold(n_splits=3)
- param_grid : dict or list of dictionaries
Model tuning parameters of the estimator to evaluate.
- scoring : skpro metric (BaseMetric), str, or callable, optional (default=None)
Scoring metric to use in tuning the estimator.
Descendants of skpro metric objects (BaseMetric) can be searched with the registry.all_estimators search utility, for instance via all_estimators("metric", as_dataframe=True).
If callable, must have signature (y_true: pd.DataFrame, y_pred: BaseDistribution) -> float, assuming y_true and y_pred are of the same length, with lower being better. Metrics in skpro.metrics are all of this form.
If str, uses registry.resolve_alias to resolve to one of the above. Valid strings are valid registry.craft specs, which include string representations of any BaseMetric object, e.g., “MeanSquaredError()”, and keys of registry.ALIAS_DICT referring to metrics.
If None, defaults to CRPS().
- n_jobs : int, optional (default=None)
Number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.
- refit : bool, optional (default=True)
If True, refit the estimator with the best parameters on the entire data passed to fit. If False, no refitting takes place and the estimator cannot be used to predict; this is useful for tuning the hyperparameters and then using the estimator as a parameter estimator, e.g., via get_fitted_params or PluginParamsestimator.
- verbose : int, optional (default=0)
- return_n_best_estimators : int, default=1
Number of best estimators to return. The n best estimators are assigned to n_best_estimators_.
- pre_dispatch : str, optional (default='2*n_jobs')
- error_score : numeric value or the str “raise”, optional (default=np.nan)
The test score returned when an estimator fails to be fitted.
- return_train_score : bool, optional (default=False)
- backend : {“dask”, “loky”, “multiprocessing”, “threading”}, by default “loky”
Backend used to run the evaluation loop in parallel. Valid options are:
  - “None”: executes the loop sequentially, as a simple list comprehension
  - “loky”, “multiprocessing” and “threading”: uses joblib.Parallel loops
  - “joblib”: custom and 3rd party joblib backends, e.g., spark
  - “dask”: uses dask, requires the dask package in the environment
Recommendation: Use “dask” or “loky” for parallel evaluation. “threading” is unlikely to see speed-ups due to the GIL, and the serialization backend (cloudpickle) used by “dask” and “loky” is generally more robust than the standard pickle library used in “multiprocessing”.
- error_score : “raise” or numeric, default=np.nan
Value to assign to the score if an exception occurs in estimator fitting. If set to “raise”, the exception is raised. If a numeric value is given, a FitFailedWarning is issued.
- backend_params : dict, optional
Additional parameters passed to the backend as config. Directly passed to utils.parallel.parallelize. Valid keys depend on the value of backend:
  - “None”: no additional parameters, backend_params is ignored
  - “loky”, “multiprocessing” and “threading”: default joblib backends. Any valid keys for joblib.Parallel can be passed here, e.g., n_jobs, with the exception of backend, which is directly controlled by backend. If n_jobs is not passed, it will default to -1; other parameters will default to joblib defaults.
  - “joblib”: custom and 3rd party joblib backends, e.g., spark. Any valid keys for joblib.Parallel can be passed here, e.g., n_jobs; backend must be passed as a key of backend_params in this case. If n_jobs is not passed, it will default to -1; other parameters will default to joblib defaults.
  - “dask”: any valid keys for dask.compute can be passed, e.g., scheduler
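The error_score behaviour can be illustrated with a small wrapper. This is a sketch under assumptions: fit_and_score and the local FitFailedWarning class are hypothetical (scikit-learn defines its own sklearn.exceptions.FitFailedWarning), and skpro's actual handling may differ in detail.

```python
import math
import warnings


class FitFailedWarning(UserWarning):
    """Hypothetical local stand-in for the FitFailedWarning named above."""


def fit_and_score(fit, score, error_score=math.nan):
    """Call fit(), then return score(); apply error_score if fit() raises."""
    try:
        fit()
    except Exception as exc:
        if isinstance(error_score, str) and error_score == "raise":
            raise  # re-raise the original exception unchanged
        warnings.warn(
            f"Estimator fit failed; score set to {error_score}. Details: {exc!r}",
            FitFailedWarning,
        )
        return error_score
    return score()
```

With the default error_score=np.nan, a failing candidate receives a NaN score instead of aborting the whole search.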
Examples
>>> import pandas as pd
>>> from sklearn.datasets import load_diabetes
>>> from sklearn.linear_model import LinearRegression
>>> from sklearn.model_selection import KFold, ShuffleSplit, train_test_split
>>> from skpro.metrics import CRPS
>>> from skpro.model_selection import GridSearchCV
>>> from skpro.regression.residual import ResidualDouble
>>> X, y = load_diabetes(return_X_y=True, as_frame=True)
>>> y = pd.DataFrame(y)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y)
>>> cv = KFold(n_splits=3)
>>> estimator = ResidualDouble(LinearRegression())
>>> param_grid = {"estimator__fit_intercept" : [True, False]}
>>> gscv = GridSearchCV(
...     estimator=estimator,
...     param_grid=param_grid,
...     cv=cv,
...     scoring=CRPS(),
... )
>>> gscv.fit(X_train, y_train)
GridSearchCV(...)
>>> y_pred = gscv.predict(X_test)
>>> y_pred_proba = gscv.predict_proba(X_test)
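The param_grid in the example expands into a list of candidate settings. The helper below is a hypothetical sketch of that expansion for a dict or a list of dicts, assumed here to follow the convention of scikit-learn's sklearn.model_selection.ParameterGrid (Cartesian product within each dict, keys sorted, dicts concatenated); in practice ParameterGrid itself can be used.

```python
from itertools import product


def expand_param_grid(param_grid):
    """Hypothetical helper: expand a dict, or list of dicts, into candidate settings.

    Assumed to mirror sklearn.model_selection.ParameterGrid: Cartesian product of
    the value lists within each dict (keys sorted), dicts concatenated in order.
    """
    grids = [param_grid] if isinstance(param_grid, dict) else list(param_grid)
    settings = []
    for grid in grids:
        keys = sorted(grid)
        for values in product(*(grid[k] for k in keys)):
            settings.append(dict(zip(keys, values)))
    return settings
```

For the example's grid this gives two candidates. Assuming one fit per candidate and split, KFold(n_splits=3) then means 2 * 3 = 6 cross-validation fits, plus one refit of the best candidate since refit=True by default.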
- Attributes:
- best_index_ : int
- best_score_ : float
Score of the best model
- best_params_ : dict
Best parameter values across the parameter grid
- best_estimator_ : estimator
Fitted estimator with the best parameters
- cv_results_ : dict
Results from grid search cross validation
- n_splits_ : int
Number of splits in the data for cross validation
- refit_time_ : float
Time (seconds) to refit the best estimator
- scorer_ : function
Function used to score model
- n_best_estimators_ : list of tuples (“rank”, <estimator>)
The “rank” is in relation to best_estimator_
- n_best_scores_ : list of float
The scores of n_best_estimators_, sorted from best to worst
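How the best_* and n_best_* attributes relate can be sketched as follows. This is an illustrative assumption, not the skpro code: rank_candidates is hypothetical, works on (params, mean_score) pairs rather than fitted estimators, ranks lower scores first (matching the lower-is-better scoring convention), places NaN scores from failed fits last, and numbers ranks from 1.

```python
import math


def rank_candidates(results, n_best=1):
    """Hypothetical: rank (params, mean_score) pairs, lower mean_score is better.

    Returns best index, best params, best score, the n best (rank, params)
    tuples, and their scores sorted from best to worst.
    """
    def sort_key(i):
        score = results[i][1]
        # NaN scores (e.g. failed fits with error_score=nan) go last
        return (math.isnan(score), 0.0 if math.isnan(score) else score)

    order = sorted(range(len(results)), key=sort_key)
    best_index = order[0]
    best_params, best_score = results[best_index]
    top = order[:n_best]
    n_best_params = [(rank, results[i][0]) for rank, i in enumerate(top, start=1)]
    n_best_scores = [results[i][1] for i in top]
    return best_index, best_params, best_score, n_best_params, n_best_scores
```

The five return values loosely correspond to best_index_, best_params_, best_score_, n_best_estimators_ (with parameters in place of estimators) and n_best_scores_.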
Methods
- check_is_fitted(): Check if the estimator has been fitted.
- clone(): Obtain a clone of the object with same hyper-parameters.
- clone_tags(estimator[, tag_names]): Clone tags from another estimator as dynamic override.
- create_test_instance([parameter_set]): Construct Estimator instance if possible.
- create_test_instances_and_names([parameter_set]): Create list of all test instances and a list of names for them.
- fit(X, y[, C]): Fit regressor to training data.
- get_class_tag(tag_name[, tag_value_default]): Get a class tag's value.
- get_class_tags(): Get class tags from the class and all its parent classes.
- get_config(): Get config flags for self.
- get_fitted_params([deep]): Get fitted parameters.
- get_param_defaults(): Get object's parameter defaults.
- get_param_names(): Get object's parameter names.
- get_params([deep]): Get a dict of parameter values for this object.
- get_tag(tag_name[, tag_value_default, ...]): Get tag value from estimator class and dynamic tag overrides.
- get_tags(): Get tags from estimator class and dynamic tag overrides.
- get_test_params([parameter_set]): Return testing parameter settings for the estimator.
- is_composite(): Check if the object is composed of other BaseObjects.
- predict(X): Predict labels for data from features.
- predict_interval([X, coverage]): Compute/return interval predictions.
- predict_proba(X): Predict distribution over labels for data from features.
- predict_quantiles([X, alpha]): Compute/return quantile predictions.
- predict_var([X]): Compute/return variance predictions.
- reset(): Reset the object to a clean post-init state.
- set_config(**config_dict): Set config flags to given values.
- set_params(**params): Set the parameters of this object.
- set_random_state([random_state, deep, ...]): Set random_state pseudo-random seed parameters for self.
- set_tags(**tag_dict): Set dynamic tags to given values.
- classmethod get_test_params(parameter_set='default')
Return testing parameter settings for the estimator.
- Parameters:
- parameter_set : str, default=“default”
Name of the set of test parameters to return, for use in tests. If no special parameters are defined for a value, will return “default” set.
- Returns:
- params : dict or list of dict
- check_is_fitted()
Check if the estimator has been fitted.
Inspects object’s _is_fitted attribute that should initialize to False during object construction, and be set to True in calls to an object’s fit method.
- Raises:
- NotFittedError
If the estimator has not been fitted yet.
- clone()
Obtain a clone of the object with same hyper-parameters.
A clone is a different object without shared references, in post-init state. This function is equivalent to returning sklearn.clone of self.
- Raises:
- RuntimeError if the clone is non-conforming, due to faulty __init__.
Notes
If successful, equal in value to type(self)(**self.get_params(deep=False)).
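The equivalence in the note can be illustrated with a toy class. ToyEstimator and its clone method are hypothetical and simplified: values are deep-copied here, whereas sklearn.clone clones nested estimators recursively. The point is only that the clone equals type(self)(**self.get_params(deep=False)) in value and carries no fitted state.

```python
import copy


class ToyEstimator:
    """Hypothetical estimator following the __init__ / get_params convention."""

    def __init__(self, alpha=1.0, weights=None):
        self.alpha = alpha
        self.weights = weights

    def get_params(self, deep=False):
        # only the shallow case (deep=False) is needed for cloning here
        return {"alpha": self.alpha, "weights": self.weights}

    def fit(self):
        self.coef_ = 2 * self.alpha  # fitted attribute, name ends in "_"
        return self

    def clone(self):
        # same type and hyper-parameters, post-init state (no fitted "_" attributes);
        # values are deep-copied so the clone shares no references with self
        params = {k: copy.deepcopy(v) for k, v in self.get_params(deep=False).items()}
        return type(self)(**params)
```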
- clone_tags(estimator, tag_names=None)
Clone tags from another estimator as dynamic override.
- Parameters:
- estimator : estimator inheriting from BaseEstimator
- tag_names : str or list of str, default=None
Names of tags to clone. If None, then all tags in estimator are used as tag_names.
- Returns:
- self
Reference to self.
Notes
Changes object state by setting tag values in tag_set from estimator as dynamic tags in self.
- classmethod create_test_instance(parameter_set='default')
Construct Estimator instance if possible.
- Parameters:
- parameter_set : str, default=“default”
Name of the set of test parameters to return, for use in tests. If no special parameters are defined for a value, will return “default” set.
- Returns:
- instance : instance of the class with default parameters
Notes
get_test_params can return dict or list of dict. This function takes the first or single dict that get_test_params returns, and constructs the object with that.
- classmethod create_test_instances_and_names(parameter_set='default')
Create list of all test instances and a list of names for them.
- Parameters:
- parameter_set : str, default=“default”
Name of the set of test parameters to return, for use in tests. If no special parameters are defined for a value, will return “default” set.
- Returns:
- objs : list of instances of cls
i-th instance is cls(**cls.get_test_params()[i])
- names : list of str, same length as objs
i-th element is the name of the i-th instance of obj in tests. The naming convention is {cls.__name__}-{i} if there is more than one instance, otherwise {cls.__name__}.
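The naming convention can be sketched with a hypothetical helper, instances_and_names, that takes the dict or list of dicts returned by a get_test_params-like method. It assumes i counts from 0, matching the list indexing in cls(**cls.get_test_params()[i]).

```python
def instances_and_names(cls, param_sets):
    """Hypothetical helper: build test instances and their names.

    param_sets is a dict or a list of dicts, as get_test_params may return.
    """
    if isinstance(param_sets, dict):
        param_sets = [param_sets]
    objs = [cls(**params) for params in param_sets]
    if len(objs) == 1:
        names = [cls.__name__]
    else:
        names = [f"{cls.__name__}-{i}" for i in range(len(objs))]
    return objs, names
```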
- fit(X, y, C=None)
Fit regressor to training data.
- Writes to self:
Sets fitted model attributes ending in “_”.
Changes state to “fitted” = sets is_fitted flag to True.
- Parameters:
- X : pandas DataFrame
feature instances to fit regressor to
- y : pd.DataFrame, must be same length as X
labels to fit regressor to
- C : ignored, optional (default=None)
censoring information for survival analysis. All probabilistic regressors assume data to be uncensored.
- Returns:
- self : reference to self
- classmethod get_class_tag(tag_name, tag_value_default=None)
Get a class tag’s value.
Does not return information from dynamic tags (set via set_tags or clone_tags) that are defined on instances.
- Parameters:
- tag_name : str
Name of tag value.
- tag_value_default : any
Default/fallback value if tag is not found.
- Returns:
- tag_value
Value of the tag_name tag in self. If not found, returns tag_value_default.
- classmethod get_class_tags()
Get class tags from the class and all its parent classes.
Retrieves tag: value pairs from _tags class attribute. Does not return information from dynamic tags (set via set_tags or clone_tags) that are defined on instances.
- Returns:
- collected_tags : dict
Dictionary of class tag name: tag value pairs. Collected from _tags class attribute via nested inheritance.
- get_config()
Get config flags for self.
- Returns:
- config_dict : dict
Dictionary of config name : config value pairs. Collected from _config class attribute via nested inheritance and then any overrides and new configs from the _config_dynamic object attribute.
- get_fitted_params(deep=True)
Get fitted parameters.
- State required:
Requires state to be “fitted”.
- Parameters:
- deep : bool, default=True
Whether to return fitted parameters of components.
If True, will return a dict of parameter name : value for this object, including fitted parameters of fittable components (= BaseEstimator-valued parameters).
If False, will return a dict of parameter name : value for this object, but not include fitted parameters of components.
- Returns:
- fitted_params : dict with str-valued keys
Dictionary of fitted parameters, paramname : paramvalue key-value pairs include:
always: all fitted parameters of this object, as via get_param_names; values are fitted parameter values for that key, of this object
if deep=True, also contains key/value pairs of component parameters; parameters of components are indexed as [componentname]__[paramname]; all parameters of componentname appear as paramname with its value
if deep=True, also contains arbitrary levels of component recursion, e.g., [componentname]__[componentcomponentname]__[paramname], etc.
- classmethod get_param_defaults()
Get object’s parameter defaults.
- Returns:
- default_dict : dict[str, Any]
Keys are all parameters of cls that have a default defined in __init__; values are the defaults, as defined in __init__.
- classmethod get_param_names()
Get object’s parameter names.
- Returns:
- param_names : list[str]
Alphabetically sorted list of parameter names of cls.
- get_params(deep=True)
Get a dict of parameter values for this object.
- Parameters:
- deep : bool, default=True
Whether to return parameters of components.
If True, will return a dict of parameter name : value for this object, including parameters of components (= BaseObject-valued parameters).
If False, will return a dict of parameter name : value for this object, but not include parameters of components.
- Returns:
- params : dict with str-valued keys
Dictionary of parameters, paramname : paramvalue key-value pairs include:
always: all parameters of this object, as via get_param_names; values are parameter values for that key, of this object; values are always identical to values passed at construction
if deep=True, also contains key/value pairs of component parameters; parameters of components are indexed as [componentname]__[paramname]; all parameters of componentname appear as paramname with its value
if deep=True, also contains arbitrary levels of component recursion, e.g., [componentname]__[componentcomponentname]__[paramname], etc.
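The double-underscore indexing of component parameters can be sketched with a hypothetical ToyObject class. For brevity it stores its parameters in a dict, whereas real BaseObject subclasses read them from their __init__ arguments.

```python
class ToyObject:
    """Hypothetical object; parameters are kept in a dict for brevity."""

    def __init__(self, **params):
        self._params = params

    def get_params(self, deep=True):
        out = dict(self._params)
        if deep:
            for name, value in self._params.items():
                if isinstance(value, ToyObject):
                    # recursion yields [componentname]__[paramname] keys at any depth
                    for sub_name, sub_value in value.get_params(deep=True).items():
                        out[f"{name}__{sub_name}"] = sub_value
        return out
```

Nesting a regressor inside an estimator inside a tuner produces keys such as estimator__regressor__fit_intercept, which is the same key shape used in param_grid.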
- get_tag(tag_name, tag_value_default=None, raise_error=True)
Get tag value from estimator class and dynamic tag overrides.
- Parameters:
- tag_name : str
Name of tag to be retrieved
- tag_value_default : any type, optional; default=None
Default/fallback value if tag is not found
- raise_error : bool, default=True
Whether a ValueError is raised when the tag is not found
- Returns:
- tag_value : Any
Value of the tag_name tag in self. If not found, raises an error if raise_error is True, otherwise returns tag_value_default.
- Raises:
- ValueError if raise_error is True and tag_name is not in self.get_tags().keys()
- get_tags()
Get tags from estimator class and dynamic tag overrides.
- Returns:
- collected_tags : dict
Dictionary of tag name : tag value pairs. Collected from _tags class attribute via nested inheritance and then any overrides and new tags from _tags_dynamic object attribute.
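The tag resolution order for get_class_tags and get_tags can be sketched as follows. TagSketch, BaseToy, ChildToy and the tag names are hypothetical; the sketch assumes class _tags dicts are merged from the most generic base class to the most specific, and that dynamic tags set via set_tags override the merged class tags.

```python
class TagSketch:
    """Hypothetical mixin: class-level _tags merged along the MRO, plus dynamic overrides."""

    _tags = {}

    def __init__(self):
        self._tags_dynamic = {}

    @classmethod
    def get_class_tags(cls):
        collected = {}
        # walk from the most generic base to cls, so subclasses override parents
        for klass in reversed(cls.__mro__):
            collected.update(klass.__dict__.get("_tags", {}))
        return collected

    def get_tags(self):
        tags = self.get_class_tags()
        tags.update(self._tags_dynamic)  # dynamic tags take precedence
        return tags

    def set_tags(self, **tag_dict):
        self._tags_dynamic.update(tag_dict)
        return self


class BaseToy(TagSketch):
    _tags = {"tag_a": False, "tag_b": "base"}


class ChildToy(BaseToy):
    _tags = {"tag_a": True}
```

Dynamic tags change only the instance view returned by get_tags; get_class_tags is unaffected.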
- is_composite()
Check if the object is composed of other BaseObjects.
A composite object is an object which contains objects as parameters. Called on an instance, since this may differ by instance.
- Returns:
- composite : bool
Whether an object has any parameters whose values are BaseObjects.
- property is_fitted
Whether fit has been called.
Inspects object’s _is_fitted attribute that should initialize to False during object construction, and be set to True in calls to an object’s fit method.
- Returns:
- bool
Whether the estimator has been fit.
- predict(X)
Predict labels for data from features.
- State required:
Requires state to be “fitted”.
- Accesses in self:
Fitted model attributes ending in “_”
- Parameters:
- X : pandas DataFrame, must have same columns as X in fit
data to predict labels for
- Returns:
- y : pandas DataFrame, same length as X
labels predicted for X
- predict_interval(X=None, coverage=0.9)
Compute/return interval predictions.
If coverage is iterable, multiple intervals will be calculated.
- State required:
Requires state to be “fitted”.
- Accesses in self:
Fitted model attributes ending in “_”.
- Parameters:
- X : pandas DataFrame, must have same columns as X in fit
data to predict labels for
- coverage : float or list of float of unique values, optional (default=0.90)
nominal coverage(s) of predictive interval(s)
- Returns:
- pred_int : pd.DataFrame
Column has multi-index: first level is variable name from y in fit, second level coverage fractions for which intervals were computed, in the same order as in input coverage. Third level is string “lower” or “upper”, for lower/upper interval end. Row index is equal to row index of X. Entries are lower/upper bounds of interval predictions, for var in col index, at nominal coverage in second col index, lower/upper depending on third col index, for the row index. Lower/upper interval ends are equivalent to quantile predictions at alpha = 0.5 - c/2, 0.5 + c/2 for c in coverage.
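The relation between coverage and quantile levels can be checked with a small helper. coverage_to_alphas is hypothetical and only computes the levels; it does not call predict_quantiles.

```python
def coverage_to_alphas(coverage):
    """Hypothetical helper: (lower, upper) quantile levels for each nominal coverage c."""
    coverages = [coverage] if isinstance(coverage, (int, float)) else list(coverage)
    # a central interval with coverage c spans quantiles 0.5 - c/2 to 0.5 + c/2
    return [(0.5 - c / 2, 0.5 + c / 2) for c in coverages]
```

For the default coverage=0.9 this gives levels 0.05 and 0.95 (up to floating-point rounding), the same as the default alpha=[0.05, 0.95] of predict_quantiles.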
- predict_proba(X)
Predict distribution over labels for data from features.
- State required:
Requires state to be “fitted”.
- Accesses in self:
Fitted model attributes ending in “_”
- Parameters:
- X : pandas DataFrame, must have same columns as X in fit
data to predict labels for
- Returns:
- y : skpro BaseDistribution, same length as X
labels predicted for X
- predict_quantiles(X=None, alpha=None)
Compute/return quantile predictions.
If alpha is iterable, multiple quantiles will be calculated.
- State required:
Requires state to be “fitted”.
- Accesses in self:
Fitted model attributes ending in “_”.
- Parameters:
- X : pandas DataFrame, must have same columns as X in fit
data to predict labels for
- alpha : float or list of float of unique values, optional (default=[0.05, 0.95])
A probability or list of probabilities at which quantile predictions are computed.
- Returns:
- quantiles : pd.DataFrame
Column has multi-index: first level is variable name from y in fit, second level being the values of alpha passed to the function. Row index is equal to row index of X. Entries are quantile predictions, for var in col index, at quantile probability in second col index, for the row index.
- predict_var(X=None)
Compute/return variance predictions.
- State required:
Requires state to be “fitted”.
- Accesses in self:
Fitted model attributes ending in “_”.
- Parameters:
- X : pandas DataFrame, must have same columns as X in fit
data to predict labels for
- Returns:
- pred_var : pd.DataFrame
Column names are exactly those of y passed in fit. Row index is equal to row index of X. Entries are variance predictions, for var in col index. A variance prediction for a given variable and row index is a predicted variance for that variable and index, given observed data.
- reset()
Reset the object to a clean post-init state.
reset runs __init__ with the current values of hyper-parameters (result of get_params). This removes any object attributes, except:
hyper-parameters = arguments of __init__
object attributes containing double-underscores, i.e., the string “__”
Class and object methods, and class attributes are also unaffected.
- Returns:
- self
Instance of class reset to a clean post-init state but retaining the current hyper-parameter values.
Notes
Equivalent to sklearn.clone but overwrites self. After a self.reset() call, self is equal in value to type(self)(**self.get_params(deep=False)).
- set_config(**config_dict)
Set config flags to given values.
- Parameters:
- config_dict : dict
Dictionary of config name : config value pairs.
- Returns:
- self : reference to self.
Notes
Changes object state, copies configs in config_dict to self._config_dynamic.
- set_params(**params)
Set the parameters of this object.
The method works on simple estimators as well as on composite objects. Parameter key strings <component>__<parameter> can be used for composites, i.e., objects that contain other objects, to access <parameter> in the component <component>. The string <parameter>, without <component>__, can also be used if this makes the reference unambiguous, e.g., there are no two parameters of components with the name <parameter>.
- Parameters:
- **params : dict
BaseObject parameters, keys must be <component>__<parameter> strings. __ suffixes can alias full strings, if unique among get_params keys.
- Returns:
- self : reference to self (after parameters have been set)
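The suffix aliasing rule can be sketched as a key lookup. resolve_param_key and the parameter keys used with it are hypothetical; the sketch assumes an alias matches a full key when the full key ends with "__" followed by the alias, and that exactly one such key must exist.

```python
def resolve_param_key(key, valid_keys):
    """Hypothetical: map key to a full get_params key, allowing a unique "__" suffix alias."""
    if key in valid_keys:
        return key
    matches = [k for k in valid_keys if k.endswith("__" + key)]
    if len(matches) == 1:
        return matches[0]
    raise ValueError(f"Parameter {key!r} is unknown or ambiguous; candidates: {matches}")
```

Under these assumptions, a short key such as fit_intercept resolves only while no second component exposes a parameter of the same name.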
- set_random_state(random_state=None, deep=True, self_policy='copy')
Set random_state pseudo-random seed parameters for self.
Finds random_state named parameters via estimator.get_params, and sets them to integers derived from random_state via set_params. These integers are sampled from chain hashing via sample_dependent_seed, and guarantee pseudo-random independence of seeded random generators.
Applies to random_state parameters in estimator depending on self_policy, and remaining component estimators if and only if deep=True.
Note: calls set_params even if self does not have a random_state, or none of the components have a random_state parameter. Therefore, set_random_state will reset any scikit-base estimator, even those without a random_state parameter.
- Parameters:
- random_state : int, RandomState instance or None, default=None
Pseudo-random number generator to control the generation of the random integers. Pass int for reproducible output across multiple function calls.
- deep : bool, default=True
Whether to set the random state in sub-estimators. If False, will set only self’s random_state parameter, if exists. If True, will set random_state parameters in sub-estimators as well.
- self_policy : str, one of {“copy”, “keep”, “new”}, default=“copy”
  - “copy”: estimator.random_state is set to input random_state
  - “keep”: estimator.random_state is kept as is
  - “new”: estimator.random_state is set to a new random state, derived from input random_state, and in general different from it
- Returns:
- self : reference to self
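The effect of deep and self_policy can be sketched with a hypothetical helper that only computes the new values for random_state keys. derive_seed uses SHA-256 of the root seed and key as a stand-in for sample_dependent_seed, whose actual hashing scheme is not reproduced here. The sketch assumes an int random_state and does not call set_params, which the real method always does.

```python
import hashlib


def derive_seed(root_seed, key):
    """Hypothetical stand-in for sample_dependent_seed: reproducible 32-bit seed."""
    digest = hashlib.sha256(f"{root_seed}:{key}".encode()).digest()
    return int.from_bytes(digest[:4], "little")


def random_state_updates(param_keys, random_state, deep=True, self_policy="copy"):
    """Hypothetical: return {key: new_value} for the random_state keys that would be set."""
    updates = {}
    for key in param_keys:
        if key == "random_state":
            if self_policy == "copy":
                updates[key] = random_state
            elif self_policy == "new":
                updates[key] = derive_seed(random_state, key)
            # self_policy == "keep": self's own random_state is left unchanged
        elif deep and key.endswith("__random_state"):
            # components get seeds derived from, and in general different to, the input
            updates[key] = derive_seed(random_state, key)
    return updates
```

Because the derived seeds depend only on the input seed and the parameter key, repeated calls with the same int random_state give identical component seeds.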