Basic usage¶
SciKeras
is designed to maximize interoperability between sklearn
and Keras/TensorFlow
. The aim is to keep 99% of the flexibility of Keras
while being able to leverage most features of sklearn
. Below, we show the basic usage of SciKeras
and how it can be combined with sklearn
.
This notebook shows you how to use the basic functionality of SciKeras
.
Table of contents¶
1. Setup¶
[1]:
try:
import scikeras
except ImportError:
!python -m pip install scikeras
Silence TensorFlow logging to keep output succinct.
[2]:
import warnings
from tensorflow import get_logger
get_logger().setLevel('ERROR')
warnings.filterwarnings("ignore", message="Setting the random state for TF")
[3]:
import numpy as np
from scikeras.wrappers import KerasClassifier, KerasRegressor
from tensorflow import keras
2. Training a classifier and making predictions¶
2.1 A toy binary classification task¶
We load a toy classification task from sklearn
.
[4]:
import numpy as np
from sklearn.datasets import make_classification
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X.shape, y.shape, y.mean()
[4]:
((1000, 20), (1000,), 0.5)
2.2 Definition of the Keras classification Model¶
We define a vanilla neural network with.
Because we are dealing with 2 classes, the output layer can be constructed in two different ways:
Single unit with a
"sigmoid"
nonlinearity. The loss must be"binary_crossentropy"
.Two units (one for each class) and a
"softmax"
nonlinearity. The loss must be"sparse_categorical_crossentropy"
.
In this example, we choose the first option, which is what you would usually do for binary classification. The second option is usually reserved for when you have >2 classes.
[5]:
from tensorflow import keras
def get_clf(meta, hidden_layer_sizes, dropout):
n_features_in_ = meta["n_features_in_"]
n_classes_ = meta["n_classes_"]
model = keras.models.Sequential()
model.add(keras.layers.Input(shape=(n_features_in_,)))
for hidden_layer_size in hidden_layer_sizes:
model.add(keras.layers.Dense(hidden_layer_size, activation="relu"))
model.add(keras.layers.Dropout(dropout))
model.add(keras.layers.Dense(1, activation="sigmoid"))
return model
2.3 Defining and training the neural net classifier¶
We use KerasClassifier
because we’re dealing with a classifcation task. The first argument should be a callable returning a Keras.Model
, in this case, get_clf
. As additional arguments, we pass the number of loss function (required) and the optimizer, but the later is optional. We must also pass all of the arguments to get_clf
as keyword arguments to KerasClassifier
if they don’t have a default value in get_clf
. Note that if you do not pass an argument to
KerasClassifier
, it will not be avilable for hyperparameter tuning. Finally, we also pass random_state=0
for reproducible results.
[6]:
from scikeras.wrappers import KerasClassifier
clf = KerasClassifier(
model=get_clf,
loss="binary_crossentropy",
hidden_layer_sizes=(100,),
dropout=0.5,
)
As in sklearn
, we call fit
passing the input data X
and the targets y
.
[7]:
clf.fit(X, y);
32/32 [==============================] - 1s 2ms/step - loss: 0.7296
Also, as in sklearn
, you may call predict
or predict_proba
on the fitted model.
2.4 Making predictions, classification¶
[8]:
y_pred = clf.predict(X[:5])
y_pred
1/1 [==============================] - 0s 82ms/step
[8]:
array([1, 0, 0, 0, 0])
[9]:
y_proba = clf.predict_proba(X[:5])
y_proba
1/1 [==============================] - 0s 28ms/step
[9]:
array([[0.48719656, 0.51280344],
[0.79750854, 0.20249146],
[0.857792 , 0.14220795],
[0.82747555, 0.17252442],
[0.8910192 , 0.10898075]], dtype=float32)
3 Training a regressor¶
3.1 A toy regression task¶
[10]:
from sklearn.datasets import make_regression
X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)
X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()
[10]:
((1000, 20), (1000,), -649.0148244404172, 615.4505181286091)
3.2 Definition of the Keras regression Model¶
Again, define a vanilla neural network. The main difference is that the output layer always has a single unit and does not apply any nonlinearity.
[11]:
def get_reg(meta, hidden_layer_sizes, dropout):
n_features_in_ = meta["n_features_in_"]
model = keras.models.Sequential()
model.add(keras.layers.Input(shape=(n_features_in_,)))
for hidden_layer_size in hidden_layer_sizes:
model.add(keras.layers.Dense(hidden_layer_size, activation="relu"))
model.add(keras.layers.Dropout(dropout))
model.add(keras.layers.Dense(1))
return model
3.3 Defining and training the neural net regressor¶
Training a regressor has nearly the same data flow as training a classifier. The differences include using KerasRegressor
instead of KerasClassifier
and adding KerasRegressor.r_squared
as a metric. Most of the Scikit-learn regressors use the coefficient of determination or R^2 as a metric function, which measures correlation between the true labels and predicted labels.
[12]:
from scikeras.wrappers import KerasRegressor
reg = KerasRegressor(
model=get_reg,
loss="mse",
metrics=[KerasRegressor.r_squared],
hidden_layer_sizes=(100,),
dropout=0.5,
)
[13]:
reg.fit(X_regr, y_regr);
32/32 [==============================] - 1s 2ms/step - loss: 45153.5938 - r_squared: -0.0519
3.4 Making predictions, regression¶
You may call predict
or predict_proba
on the fitted model. For regressions, both methods return the same value.
[14]:
y_pred = reg.predict(X_regr[:5])
y_pred
1/1 [==============================] - 0s 48ms/step
[14]:
array([ 0.05222758, 0.4514417 , -0.48679325, -0.08914306, -0.08563218],
dtype=float32)
4. Saving and loading a model¶
Save and load either the whole model by using pickle, or use Keras’ specialized save methods on the KerasClassifier.model_
or KerasRegressor.model_
attribute that is created after fitting. You will want to use Keras’ model saving utilities if any of the following apply:
You wish to save only the weights or only the training configuration of your model.
You wish to share your model with collaborators. Pickle is a relatively unsafe protocol and it is not recommended to share or load pickle objects publically.
You care about performance, especially if doing in-memory serialization.
For more information, see Keras’ saving documentation.
4.1 Saving the whole model¶
[15]:
import pickle
bytes_model = pickle.dumps(reg)
new_reg = pickle.loads(bytes_model)
new_reg.predict(X_regr[:5]) # model is still trained
1/1 [==============================] - 0s 49ms/step
[15]:
array([ 0.05222758, 0.4514417 , -0.48679325, -0.08914306, -0.08563218],
dtype=float32)
4.2 Saving using Keras’ saving methods¶
This efficiently and safely saves the model to disk, including trained weights. You should use this method if you plan on sharing your saved models.
[16]:
# Save to disk
pred_old = reg.predict(X_regr)
reg.model_.save("/tmp/my_model") # saves just the Keras model
32/32 [==============================] - 0s 1ms/step
[17]:
# Load the model back into memory
new_reg_model = keras.models.load_model("/tmp/my_model")
# Now we need to instantiate a new SciKeras object
# since we only saved the Keras model
reg_new = KerasRegressor(new_reg_model)
# use initialize to avoid re-fitting
reg_new.initialize(X_regr, y_regr)
pred_new = reg_new.predict(X_regr)
np.testing.assert_allclose(pred_old, pred_new)
32/32 [==============================] - 0s 1ms/step
5. Usage with an sklearn Pipeline¶
It is possible to put the KerasClassifier
inside an sklearn Pipeline
, as you would with any sklearn
classifier.
[18]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
pipe = Pipeline([
('scale', StandardScaler()),
('clf', clf),
])
y_proba = pipe.fit(X, y).predict(X)
32/32 [==============================] - 1s 2ms/step - loss: 0.7097
32/32 [==============================] - 0s 1ms/step
To save the whole pipeline, including the Keras model, use pickle
.
6. Callbacks¶
Adding a new callback to the model is straightforward. Below we define a threashold callback to avoid training past a certain accuracy. This a rudimentary for of early stopping.
[19]:
class MaxValLoss(keras.callbacks.Callback):
def __init__(self, monitor: str, threashold: float):
self.monitor = monitor
self.threashold = threashold
def on_epoch_end(self, epoch, logs=None):
if logs[self.monitor] > self.threashold:
print("Threashold reached; stopping training")
self.model.stop_training = True
Define a test dataset:
[20]:
from sklearn.datasets import make_moons
X, y = make_moons(n_samples=100, noise=0.2, random_state=0)
And try fitting it with and without the callback:
[21]:
kwargs = dict(
model=get_clf,
loss="binary_crossentropy",
dropout=0.5,
hidden_layer_sizes=(100,),
metrics=["binary_accuracy"],
fit__validation_split=0.2,
epochs=20,
verbose=False,
random_state=0
)
# First test without the callback
clf = KerasClassifier(**kwargs)
clf.fit(X, y)
print(f"Trained {len(clf.history_['loss'])} epochs")
print(f"Final accuracy: {clf.history_['val_binary_accuracy'][-1]}") # get last value of last fit/partial_fit call
Trained 20 epochs
Final accuracy: 1.0
And with:
[22]:
# Test with the callback
cb = MaxValLoss(monitor="val_binary_accuracy", threashold=0.75)
clf = KerasClassifier(
**kwargs,
callbacks=[cb]
)
clf.fit(X, y)
print(f"Trained {len(clf.history_['loss'])} epochs")
print(f"Final accuracy: {clf.history_['val_binary_accuracy'][-1]}") # get last value of last fit/partial_fit call
Threashold reached; stopping training
Trained 2 epochs
Final accuracy: 0.949999988079071
For information on how to write custom callbacks, have a look at the Advanced Usage notebook.
7. Usage with sklearn GridSearchCV¶
7.1 Special prefixes¶
SciKeras allows to direct access to all parameters passed to the wrapper constructors, including deeply nested routed parameters. This allows tunning of paramters like hidden_layer_sizes
as well as optimizer__learning_rate
.
This is exactly the same logic that allows to access estimator parameters in sklearn Pipeline
s and FeatureUnion
s.
This feature is useful in several ways. For one, it allows to set those parameters in the model definition. Furthermore, it allows you to set parameters in an sklearn GridSearchCV
as shown below.
To differentiate paramters like callbacks
which are accepted by both tf.keras.Model.fit
and tf.keras.Model.predict
you can add a fit__
or predict__
routing suffix respectively. Similar, the model__
prefix may be used to specify that a paramter is destined only for get_clf
/get_reg
(or whatever callable you pass as your model
argument).
For more information on parameter routing with special prefixes, see the Advanced Usage Docs
7.2 Performing a grid search¶
Below we show how to perform a grid search over the learning rate (optimizer__lr
), the model’s number of hidden layers (model__hidden_layer_sizes
), the model’s dropout rate (model__dropout
).
[23]:
from sklearn.model_selection import GridSearchCV
clf = KerasClassifier(
model=get_clf,
loss="binary_crossentropy",
optimizer="adam",
optimizer__lr=0.1,
model__hidden_layer_sizes=(100,),
model__dropout=0.5,
verbose=False,
)
Note: We set the verbosity level to zero (verbose=False
) to prevent too much print output from being shown.
[24]:
params = {
'optimizer__lr': [0.05, 0.1],
'model__hidden_layer_sizes': [(100, ), (50, 50, )],
'model__dropout': [0, 0.5],
}
gs = GridSearchCV(clf, params, scoring='accuracy', n_jobs=-1, verbose=True)
gs.fit(X, y)
print(gs.best_score_, gs.best_params_)
Fitting 5 folds for each of 8 candidates, totalling 40 fits
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
WARNING:tensorflow:5 out of the last 13 calls to <function Model.make_train_function.<locals>.train_function at 0x7f21e077eaf0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f21e3880a60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
WARNING:tensorflow:5 out of the last 13 calls to <function Model.make_train_function.<locals>.train_function at 0x7f952059daf0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f952069ca60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
WARNING:tensorflow:5 out of the last 13 calls to <function Model.make_train_function.<locals>.train_function at 0x7f21e89f4ca0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f21e06b3160> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
WARNING:tensorflow:5 out of the last 13 calls to <function Model.make_train_function.<locals>.train_function at 0x7f952280dca0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f95204cb160> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
/home/runner/work/scikeras/scikeras/.venv/lib/python3.8/site-packages/keras/optimizer_v2/adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
super(Adam, self).__init__(name, **kwargs)
0.85 {'model__dropout': 0, 'model__hidden_layer_sizes': (100,), 'optimizer__lr': 0.05}
Of course, we could further nest the KerasClassifier
within an sklearn.pipeline.Pipeline
, in which case we just prefix the parameter by the name of the net (e.g. clf__model__hidden_layer_sizes
).