Data Transformers¶
Keras support many types of input and output data formats, including:
Multiple inputs
Multiple outputs
Higher-dimensional tensors
In this notebook, we explore how to reconcile this functionality with the sklearn ecosystem via SciKeras data transformer interface.
Table of contents¶
1. Setup¶
[1]:
try:
import scikeras
except ImportError:
!python -m pip install scikeras
Silence TensorFlow warnings to keep output succint.
[2]:
import warnings
from tensorflow import get_logger
get_logger().setLevel('ERROR')
warnings.filterwarnings("ignore", message="Setting the random state for TF")
[3]:
import numpy as np
from scikeras.wrappers import KerasClassifier, KerasRegressor
from tensorflow import keras
2. Data transformer interface¶
SciKeras enables advanced Keras use cases by providing an interface to convert sklearn compliant data to whatever format your Keras model requires within SciKeras, right before passing said data to the Keras model.
This interface is implemented in the form of two sklearn transformers, one for the features (X
) and one for the target (y
). SciKeras loads these transformers via the target_encoder
and feature_encoder
methods.
By default, SciKeras implements target_encoder
for both KerasClassifier and KerasRegressor to facilitate common types of tasks in sklearn. The default implementations are scikeras.utils.transformers.ClassifierLabelEncoder
and scikeras.utils.transformers.RegressorTargetEncoder
for KerasClassifier and KerasRegressor respectively. Information on the types of tasks that these default transformers are able to perform can be found in the SciKeras
docs.
Below is an outline of the inner workings of the data transfomer interfaces to help understand when they are called:
[4]:
if False: # avoid executing pseudocode
from scikeras.utils.transformers import (
ClassifierLabelEncoder,
RegressorTargetEncoder,
)
class BaseWrapper:
def fit(self, X, y):
self.target_encoder_ = self.target_encoder
self.feature_encoder_ = self.feature_encoder
y = self.target_encoder_.fit_transform(y)
X = self.feature_encoder_.fit_transform(X)
self.model_.fit(X, y)
return self
def predict(self, X):
X = self.feature_encoder_.transform(X)
y_pred = self.model_.predict(X)
return self.target_encoder_.inverse_transform(y_pred)
class KerasClassifier(BaseWrapper):
@property
def target_encoder(self):
return ClassifierLabelEncoder(loss=self.loss)
def predict_proba(self, X):
X = self.feature_encoder_.transform(X)
y_pred = self.model_.predict(X)
return self.target_encoder_.inverse_transform(y_pred, return_proba=True)
class KerasRegressor(BaseWrapper):
@property
def target_encoder(self):
return RegressorTargetEncoder()
To substitute your own data transformation routine, you must subclass the wrappers and override one of the encoder defining functions. You will have access to all attributes of the wrappers, and you can pass these to your transformer, like we do above with loss
.
[5]:
from sklearn.base import BaseEstimator, TransformerMixin
[6]:
if False: # avoid executing pseudocode
class MultiOutputTransformer(BaseEstimator, TransformerMixin):
...
class MultiOutputClassifier(KerasClassifier):
@property
def target_encoder(self):
return MultiOutputTransformer(...)
2.1 get_metadata method¶
SciKeras recognized an optional get_metadata
on the transformers. get_metadata
is expected to return a dicionary of with key strings and arbitrary values. SciKeras will set add these items to the wrappers namespace and make them available to your model building function via the meta
keyword argument:
[7]:
if False: # avoid executing pseudocode
class MultiOutputTransformer(BaseEstimator, TransformerMixin):
def get_metadata(self):
return {"my_param_": "foobarbaz"}
class MultiOutputClassifier(KerasClassifier):
@property
def target_encoder(self):
return MultiOutputTransformer(...)
def get_model(meta):
print(f"Got: {meta['my_param_']}")
clf = MultiOutputClassifier(model=get_model)
clf.fit(X, y) # Got: foobarbaz
print(clf.my_param_) # foobarbaz
3. Multiple outputs¶
Keras makes it straight forward to define models with multiple outputs, that is a Model with multiple sets of fully-connected heads at the end of the network. This functionality is only available in the Functional Model and subclassed Model definition modes, and is not available when using Sequential.
In practice, the main thing about Keras models with multiple outputs that you need to know as a SciKeras user is that Keras expects X
or y
to be a list of arrays/tensors, with one array/tensor for each input/output.
Note that “multiple outputs” in Keras has a slightly different meaning than “multiple outputs” in sklearn. Many tasks that would be considered “multiple output” tasks in sklearn can be mapped to a single “output” in Keras with multiple units. This notebook specifically focuses on the cases that require multiple distinct Keras outputs.
3.1 Define Keras Model¶
Here we define a simple perceptron that has two outputs, corresponding to one binary classification taks and one multiclass classification task. For example, one output might be “image has car” (binary) and the other might be “color of car in image” (multiclass).
[8]:
def get_clf_model(meta):
inp = keras.layers.Input(shape=(meta["n_features_in_"]))
x1 = keras.layers.Dense(100, activation="relu")(inp)
out_bin = keras.layers.Dense(1, activation="sigmoid")(x1)
out_cat = keras.layers.Dense(meta["n_classes_"][1], activation="softmax")(x1)
model = keras.Model(inputs=inp, outputs=[out_bin, out_cat])
model.compile(
loss=["binary_crossentropy", "sparse_categorical_crossentropy"]
)
return model
Let’s test that this model works with the kind of inputs and outputs we expect.
[9]:
X = np.random.random(size=(100, 10))
y_bin = np.random.randint(0, 2, size=(100,))
y_cat = np.random.randint(0, 5, size=(100, ))
y = [y_bin, y_cat]
# build mock meta
meta = {
"n_features_in_": 10,
"n_classes_": [2, 5] # note that we made this a list, one for each output
}
model = get_clf_model(meta=meta)
model.fit(X, y, verbose=0)
y_pred = model.predict(X)
[10]:
print(y_pred[0][:2, :])
[[0.39894974]
[0.46238422]]
[11]:
print(y_pred[1][:2, :])
[[0.18794206 0.1951279 0.1590518 0.2910546 0.16682352]
[0.17859802 0.18505742 0.1755632 0.28292775 0.17785361]]
As you can see, our predict
output is also a list of arrays, except it contains probabilities instead of the class predictions.
Our data transormer’s job will be to convert from a single numpy array (which is what the sklearn ecosystem works with) to the list of arrays and then back. Additionally, for classifiers, we will want to be able to convert probabilities to class predictions.
We will structure our data on the sklearn side by column-stacking our list of arrays. This works well in this case since we have the same number of datapoints in each array.
3.2 Define output data transformer¶
Let’s go ahead and protoype this data transformer:
[12]:
from typing import List
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelEncoder
class MultiOutputTransformer(BaseEstimator, TransformerMixin):
def fit(self, y):
y_bin, y_cat = y[:, 0], y[:, 1]
# Create internal encoders to ensure labels are 0, 1, 2...
self.bin_encoder_ = LabelEncoder()
self.cat_encoder_ = LabelEncoder()
# Fit them to the input data
self.bin_encoder_.fit(y_bin)
self.cat_encoder_.fit(y_cat)
# Save the number of classes
self.n_classes_ = [
self.bin_encoder_.classes_.size,
self.cat_encoder_.classes_.size,
]
# Save number of expected outputs in the Keras model
# SciKeras will automatically use this to do error-checking
self.n_outputs_expected_ = 2
return self
def transform(self, y: np.ndarray) -> List[np.ndarray]:
y_bin, y_cat = y[:, 0], y[:, 1]
# Apply transformers to input array
y_bin = self.bin_encoder_.transform(y_bin)
y_cat = self.cat_encoder_.transform(y_cat)
# Split the data into a list
return [y_bin, y_cat]
def inverse_transform(self, y: List[np.ndarray], return_proba: bool = False) -> np.ndarray:
y_pred_proba = y # rename for clarity, what Keras gives us are probs
if return_proba:
return np.column_stack(y_pred_proba, axis=1)
# Get class predictions from probabilities
y_pred_bin = (y_pred_proba[0] > 0.5).astype(int).reshape(-1, )
y_pred_cat = np.argmax(y_pred_proba[1], axis=1)
# Pass back through LabelEncoder
y_pred_bin = self.bin_encoder_.inverse_transform(y_pred_bin)
y_pred_cat = self.cat_encoder_.inverse_transform(y_pred_cat)
return np.column_stack([y_pred_bin, y_pred_cat])
def get_metadata(self):
return {
"n_classes_": self.n_classes_,
"n_outputs_expected_": self.n_outputs_expected_,
}
Note that in addition to the usual transform
and inverse_transform
methods, we implement the get_metadata
method to return the n_classes_
attribute.
Lets test our transformer with the same dataset we previoulsy used to test our model:
[13]:
tf = MultiOutputTransformer()
y_sklearn = np.column_stack(y)
y_keras = tf.fit_transform(y_sklearn)
print("`y`, as will be passed to Keras:")
print([y_keras[0][:4], y_keras[1][:4]])
`y`, as will be passed to Keras:
[array([0, 0, 1, 1]), array([2, 3, 2, 2])]
[14]:
y_pred_sklearn = tf.inverse_transform(y_pred)
print("`y_pred`, as will be returned to sklearn:")
y_pred_sklearn[:5]
`y_pred`, as will be returned to sklearn:
[14]:
array([[0, 3],
[0, 3],
[0, 3],
[0, 3],
[0, 3]])
[15]:
print(f"metadata = {tf.get_metadata()}")
metadata = {'n_classes_': [2, 5], 'n_outputs_expected_': 2}
Since this looks good, we move on to integrating our transformer into our classifier.
[16]:
from sklearn.metrics import accuracy_score
class MultiOutputClassifier(KerasClassifier):
@property
def target_encoder(self):
return MultiOutputTransformer()
@staticmethod
def scorer(y_true, y_pred, **kwargs):
y_bin, y_cat = y_true[:, 0], y_true[:, 1]
y_pred_bin, y_pred_cat = y_pred[:, 0], y_pred[:, 1]
# Keras by default uses the mean of losses of each outputs, so here we do the same
return np.mean([accuracy_score(y_bin, y_pred_bin), accuracy_score(y_cat, y_pred_cat)])
3.3 Test classifier¶
[17]:
from sklearn.preprocessing import StandardScaler
# Use labels as features, just to make sure we can learn correctly
X = y_sklearn
X = StandardScaler().fit_transform(X)
[18]:
clf = MultiOutputClassifier(model=get_clf_model, verbose=0, random_state=0)
clf.fit(X, y_sklearn).score(X, y_sklearn)
[18]:
0.43
4. Multiple inputs¶
The process for multiple inputs is similar, but instead of overriding the transformer in target_encoder
we override feature_encoder
.
[19]:
if False:
from sklearn.base import BaseEstimator, TransformerMixin
class MultiInputTransformer(BaseEstimator, TransformerMixin):
...
class MultiInputClassifier(KerasClassifier):
@property
def feature_encoder(self):
return MultiInputTransformer(...)
4.1 Define Keras Model¶
Let’s define a Keras regression Model with 2 inputs:
[20]:
def get_reg_model():
inp1 = keras.layers.Input(shape=(1, ))
inp2 = keras.layers.Input(shape=(1, ))
x1 = keras.layers.Dense(100, activation="relu")(inp1)
x2 = keras.layers.Dense(50, activation="relu")(inp2)
concat = keras.layers.Concatenate(axis=-1)([x1, x2])
out = keras.layers.Dense(1)(concat)
model = keras.Model(inputs=[inp1, inp2], outputs=out)
model.compile(loss="mse")
return model
And test it with a small mock dataset:
[21]:
X = np.random.random(size=(100, 2))
y = np.sum(X, axis=1)
X = np.split(X, 2, axis=1)
model = get_reg_model()
model.fit(X, y, verbose=0)
y_pred = model.predict(X).squeeze()
[22]:
from sklearn.metrics import r2_score
r2_score(y, y_pred)
[22]:
-6.493108064582687
Having verified that our model builds without errors and accepts the inputs types we expect, we move onto integrating a transformer into our SciKeras model.
4.2 Define data transformer¶
Just like for overriding target_encoder
, we just need to define a sklearn transformer and drop it into our SciKeras wrapper. Since we hardcoded the input shapes into our model and do not rely on any transformer-generated metadata, we can simply use sklearn.preprocessing.FunctionTransformer
:
[23]:
from sklearn.preprocessing import FunctionTransformer
class MultiInputRegressor(KerasRegressor):
@property
def feature_encoder(self):
return FunctionTransformer(
func=lambda X: [X[:, 0], X[:, 1]],
)
Note that we did not implement inverse_transform
(that is, we did not pass an inverse_func
argument to FunctionTransformer
) because features are never converted back to their original form.
4.3 Test regressor¶
[24]:
reg = MultiInputRegressor(model=get_reg_model, verbose=0, random_state=0)
X_sklearn = np.column_stack(X)
reg.fit(X_sklearn, y).score(X_sklearn, y)
[24]:
-4.394967296913346
5. Multidimensional inputs with MNIST dataset¶
In this example, we look at how we can use SciKeras to process the MNIST dataset. The dataset is composed of 60,000 images of digits, each of which is a 2D 28x28 image.
The dataset and Keras Model architecture used come from a Keras example. It may be beneficial to understand the Keras model by reviewing that example first.
[25]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train.shape
[25]:
(60000, 28, 28)
The outputs (labels) are numbers 0-9:
[26]:
print(y_train.shape)
print(np.unique(y_train))
(60000,)
[0 1 2 3 4 5 6 7 8 9]
First, we will “flatten” the data into an array of shape (n_samples, 28*28)
(i.e. a 2D array). This will allow us to use sklearn ecosystem utilities, for example, sklearn.preprocessing.MinMaxScaler
.
[27]:
from sklearn.preprocessing import MinMaxScaler
n_samples_train = x_train.shape[0]
n_samples_test = x_test.shape[0]
x_train = x_train.reshape((n_samples_train, -1))
x_test = x_test.reshape((n_samples_test, -1))
x_train = MinMaxScaler().fit_transform(x_train)
x_test = MinMaxScaler().fit_transform(x_test)
[28]:
print(x_train.shape[1:]) # 784 = 28*28
(784,)
[29]:
print(np.min(x_train), np.max(x_train)) # scaled 0-1
0.0 1.0
Of course, in this case, we could have just as easily used numpy functions to scale our data, but we use MinMaxScaler
to demonstrate use of the sklearn ecosystem.
5.1 Define Keras Model¶
Next we will define our Keras model (adapted from keras.io):
[30]:
num_classes = 10
input_shape = (28, 28, 1)
def get_model(meta):
model = keras.Sequential(
[
keras.Input(input_shape),
keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Flatten(),
keras.layers.Dropout(0.5),
keras.layers.Dense(num_classes, activation="softmax"),
]
)
model.compile(
loss="sparse_categorical_crossentropy"
)
return model
Now let’s define a transformer that we will use to reshape our input from the sklearn shape ((n_samples, 784)
) to the Keras shape (which we will be (n_samples, 28, 28, 1)
).
[31]:
class MultiDimensionalClassifier(KerasClassifier):
@property
def feature_encoder(self):
return FunctionTransformer(
func=lambda X: X.reshape(X.shape[0], *input_shape),
)
[32]:
clf = MultiDimensionalClassifier(
model=get_model,
epochs=10,
batch_size=128,
validation_split=0.1,
random_state=0,
)
5.2 Test¶
Train and score the model (this takes some time)
[33]:
clf.fit(x_train, y_train)
Epoch 1/10
422/422 [==============================] - 27s 63ms/step - loss: 0.3302 - val_loss: 0.0817
Epoch 2/10
422/422 [==============================] - 20s 48ms/step - loss: 0.1088 - val_loss: 0.0615
Epoch 3/10
422/422 [==============================] - 20s 48ms/step - loss: 0.0831 - val_loss: 0.0458
Epoch 4/10
422/422 [==============================] - 20s 47ms/step - loss: 0.0694 - val_loss: 0.0428
Epoch 5/10
422/422 [==============================] - 20s 47ms/step - loss: 0.0615 - val_loss: 0.0396
Epoch 6/10
422/422 [==============================] - 20s 47ms/step - loss: 0.0578 - val_loss: 0.0357
Epoch 7/10
422/422 [==============================] - 20s 48ms/step - loss: 0.0518 - val_loss: 0.0405
Epoch 8/10
422/422 [==============================] - 20s 47ms/step - loss: 0.0481 - val_loss: 0.0355
Epoch 9/10
422/422 [==============================] - 20s 47ms/step - loss: 0.0454 - val_loss: 0.0342
Epoch 10/10
422/422 [==============================] - 20s 47ms/step - loss: 0.0429 - val_loss: 0.0334
[33]:
MultiDimensionalClassifier(
model=<function get_model at 0x7f7b00db0310>
build_fn=None
warm_start=False
random_state=0
optimizer=rmsprop
loss=None
metrics=None
batch_size=128
validation_batch_size=None
verbose=1
callbacks=None
validation_split=0.1
shuffle=True
run_eagerly=False
epochs=10
class_weight=None
)
[34]:
score = clf.score(x_test, y_test)
print(f"Test score (accuracy): {score:.2f}")
79/79 [==============================] - 1s 11ms/step
Test score (accuracy): 0.99