Migrating from tf.keras.wrappers.scikit_learn
¶
Why switch to SciKeras¶
SciKeras has several advantages over tf.keras.wrappers.scikit_learn
:
Full compatibility with the Scikit-Learn API, including grid searches, ensembles, transformers, etc.
Support for Functional and Subclassed Keras Models.
Support for pre-trained models.
Support for dynamically set Keras parameters depending on inputs (e.g. input shape).
Support for hyperparameter tuning of optimizers and losses.
Support for multi-input and multi-ouput Keras models.
Functional random_state for reproducible training.
Many more that you will discover as you use SciKeras!
Changes to your code¶
SciKeras is largely backwards compatible with the existing wrappers. For most cases, you can just change your import statement from:
- from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
+ from scikeras.wrappers import KerasClassifier, KerasRegressor
SciKeras does however have some backward incompatible changes:
One-hot encoding of targets for categorical crossentropy losses¶
SciKeras will not longer implicitly inspect your Model’s loss function to determine if it needs to one-hot encode your target to match the loss function. Instead, you must explicitly pass your loss function to the constructor:
clf = KerasClassifier(loss="categorical_crossentropy")
Variable keyword arguments in fit and predict¶
In a future release of SciKeras, variable keyword arguments (commonly referred to as
**kwargs
) will be removed from fit and predict. To future
proof your code, you should instead declare these parameters in your constructor:
- clf = KerasClassifier(...)
- clf.fit(..., batch_size=32)
+ clf = KerasClassifier(..., batch_size=32)
+ clf.fit(...)
Or to declare separate values for fit
and predict
:
clf = KerasClassifier(fit__batch_size=32, predict__batch_size=10000)
Renaming of build_fn
to model
¶
SciKeras renamed the constructor argument build_fn
to model
. In a future release,
passing build_fn
as a _keyword_ argument will raise a TypeError
. Passing it as a positional
argument remains unchanged. You can make the following change to future proof your code:
- clf = KerasClassifier(build_fn=...)
+ clf = KerasClassifier(model=...)
Default arguments in build_fn/model¶
SciKeras will no longer introspect your callable model for user defined parameters
(the behavior for parameters like optimizer
is unchanged).
You must now “declare” them as keyword arguments to the constructor if you want them to be
tunable parameters (i.e. settable via set_params
):
- def get_model(my_param=123):
+ def get_model(my_param): # You can optionally remove the default here
...
return model
- clf = KerasClassifier(get_model)
+ clf = KerasClassifier(get_model, my_param=123) # option 1
+ clf = KerasClassifier(get_model, model__my_param=123) # option 2
That said, if you do not need them to work with set_params
(which is only really
necessary if you are doing hyperparameter tuning), you do not need to make any changes.