
Data Transformers

Keras supports many types of input and output data formats, including:

  • Multiple inputs

  • Multiple outputs

  • Higher-dimensional tensors

This notebook walks through examples of the different data transformations and how SciKeras bridges Keras and Scikit-learn. It may be helpful to have a general understanding of the data flow before tackling these examples; an overview is available in the data transformer docs.

1. Setup

[1]:
try:
    import scikeras
except ImportError:
    !python -m pip install scikeras

Silence TensorFlow warnings to keep the output succinct.

[2]:
import warnings
from tensorflow import get_logger
get_logger().setLevel('ERROR')
warnings.filterwarnings("ignore", message="Setting the random state for TF")
[3]:
import numpy as np
from scikeras.wrappers import KerasClassifier, KerasRegressor
from tensorflow import keras

2. Data transformer interface

SciKeras enables advanced Keras use cases by providing an interface that converts sklearn-compliant data into whatever format your Keras model requires, right before that data is passed to the Keras model.

This interface is implemented in the form of two sklearn transformers, one for the features (X) and one for the target (y). SciKeras loads these transformers via the target_encoder and feature_encoder methods.

By default, SciKeras implements target_encoder for both KerasClassifier and KerasRegressor to facilitate common types of tasks in sklearn. The default implementations are scikeras.utils.transformers.ClassifierLabelEncoder and scikeras.utils.transformers.RegressorTargetEncoder for KerasClassifier and KerasRegressor respectively. Information on the types of tasks that these default transformers are able to perform can be found in the SciKeras docs.

Below is an outline of the inner workings of the data transformer interfaces to help understand when they are called:

[4]:
if False:  # avoid executing pseudocode
    from scikeras.utils.transformers import (
        ClassifierLabelEncoder,
        RegressorTargetEncoder,
    )


    class BaseWrapper:
        def fit(self, X, y):
            self.target_encoder_ = self.target_encoder
            self.feature_encoder_ = self.feature_encoder
            y = self.target_encoder_.fit_transform(y)
            X = self.feature_encoder_.fit_transform(X)
            self.model_.fit(X, y)
            return self

        def predict(self, X):
            X = self.feature_encoder_.transform(X)
            y_pred = self.model_.predict(X)
            return self.target_encoder_.inverse_transform(y_pred)

    class KerasClassifier(BaseWrapper):

        @property
        def target_encoder(self):
            return ClassifierLabelEncoder(loss=self.loss)

        def predict_proba(self, X):
            X = self.feature_encoder_.transform(X)
            y_pred = self.model_.predict(X)
            return self.target_encoder_.inverse_transform(y_pred, return_proba=True)


    class KerasRegressor(BaseWrapper):

        @property
        def target_encoder(self):
            return RegressorTargetEncoder()

To substitute your own data transformation routine, you must subclass the wrappers and override one of the encoder-defining properties. You will have access to all attributes of the wrapper, and you can pass these to your transformer, like we do above with loss.

[5]:
from sklearn.base import BaseEstimator, TransformerMixin
[6]:
if False:  # avoid executing pseudocode

    class MultiOutputTransformer(BaseEstimator, TransformerMixin):
        ...


    class MultiOutputClassifier(KerasClassifier):

        @property
        def target_encoder(self):
            return MultiOutputTransformer(...)

2.1 get_metadata method

SciKeras recognizes an optional get_metadata method on the transformers. get_metadata is expected to return a dictionary with string keys and arbitrary values. SciKeras will add these items to the wrapper's namespace and make them available to your model-building function via the meta keyword argument:

[7]:
if False:  # avoid executing pseudocode

    class MultiOutputTransformer(BaseEstimator, TransformerMixin):
        def get_metadata(self):
            return {"my_param_": "foobarbaz"}


    class MultiOutputClassifier(KerasClassifier):

        @property
        def target_encoder(self):
            return MultiOutputTransformer(...)


    def get_model(meta):
        print(f"Got: {meta['my_param_']}")


    clf = MultiOutputClassifier(model=get_model)
    clf.fit(X, y)  # Got: foobarbaz
    print(clf.my_param_)  # foobarbaz

3. Multiple outputs


Keras makes it straightforward to define models with multiple outputs, that is, a Model with multiple sets of fully-connected heads at the end of the network. This functionality is only available in the Functional and subclassed Model definition styles, and is not available when using Sequential.

In practice, the main thing about Keras models with multiple outputs that you need to know as a SciKeras user is that Keras expects X or y to be a list of arrays/tensors, with one array/tensor for each input/output.

Note that “multiple outputs” in Keras has a slightly different meaning than “multiple outputs” in sklearn. Many tasks that would be considered “multiple output” tasks in sklearn can be mapped to a single “output” in Keras with multiple units. This notebook specifically focuses on the cases that require multiple distinct Keras outputs.
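
To make the distinction concrete, here is a hypothetical model (not used elsewhere in this notebook) showing that an sklearn-style multilabel target of shape (n_samples, 2) fits into a single Keras output with two units; the two-headed approach below is only needed when the outputs require different activations or losses.

def get_single_output_model(n_features: int) -> keras.Model:
    # A single Keras output with two sigmoid units covers an sklearn-style
    # multilabel target of shape (n_samples, 2) that shares one loss.
    inp = keras.layers.Input(shape=(n_features,))
    x = keras.layers.Dense(100, activation="relu")(inp)
    out = keras.layers.Dense(2, activation="sigmoid")(x)
    model = keras.Model(inputs=inp, outputs=out)
    model.compile(loss="binary_crossentropy")
    return model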

3.1 Define Keras Model

Here we define a simple perceptron that has two outputs, corresponding to one binary classification task and one multiclass classification task. For example, one output might be “image has car” (binary) and the other might be “color of car in image” (multiclass).

[8]:
def get_clf_model(meta):
    inp = keras.layers.Input(shape=(meta["n_features_in_"],))
    x1 = keras.layers.Dense(100, activation="relu")(inp)
    out_bin = keras.layers.Dense(1, activation="sigmoid")(x1)
    out_cat = keras.layers.Dense(meta["n_classes_"][1], activation="softmax")(x1)
    model = keras.Model(inputs=inp, outputs=[out_bin, out_cat])
    model.compile(
        loss=["binary_crossentropy", "sparse_categorical_crossentropy"]
    )
    return model

Let’s test that this model works with the kind of inputs and outputs we expect.

[9]:
X = np.random.random(size=(100, 10))
y_bin = np.random.randint(0, 2, size=(100,))
y_cat = np.random.randint(0, 5, size=(100, ))
y = [y_bin, y_cat]

# build mock meta
meta = {
    "n_features_in_": 10,
    "n_classes_": [2, 5]  # note that we made this a list, one for each output
}

model = get_clf_model(meta=meta)

model.fit(X, y, verbose=0)
y_pred = model.predict(X)
[10]:
print(y_pred[0][:2, :])
[[0.5036109 ]
 [0.49061587]]
[11]:
print(y_pred[1][:2, :])
[[0.16864142 0.21402779 0.22270538 0.2222759  0.17234947]
 [0.18863055 0.21719907 0.16864772 0.2111154  0.21440724]]

As you can see, our predict output is also a list of arrays, except it contains probabilities instead of the class predictions.

Our data transformer’s job will be to convert from a single numpy array (which is what the sklearn ecosystem works with) to the list of arrays and then back. Additionally, for classifiers, we will want to be able to convert probabilities to class predictions.

We will structure our data on the sklearn side by column-stacking our list of arrays. This works well in this case since we have the same number of datapoints in each array.
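
As a quick sanity check of this convention (a toy sketch, independent of the model above), column-stacking and column-slicing are exact inverses when the outputs have the same length:

y_bin_demo = np.array([0, 1, 1])
y_cat_demo = np.array([4, 2, 0])
y_stacked = np.column_stack([y_bin_demo, y_cat_demo])  # sklearn side: one (3, 2) array
y_split = [y_stacked[:, 0], y_stacked[:, 1]]            # Keras side: one array per output
assert np.array_equal(y_split[0], y_bin_demo)
assert np.array_equal(y_split[1], y_cat_demo)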

3.2 Define output data transformer

Let’s go ahead and prototype this data transformer:

[12]:
from typing import List

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import LabelEncoder


class MultiOutputTransformer(BaseEstimator, TransformerMixin):

    def fit(self, y):
        y_bin, y_cat = y[:, 0], y[:, 1]
        # Create internal encoders to ensure labels are 0, 1, 2...
        self.bin_encoder_ = LabelEncoder()
        self.cat_encoder_ = LabelEncoder()
        # Fit them to the input data
        self.bin_encoder_.fit(y_bin)
        self.cat_encoder_.fit(y_cat)
        # Save the number of classes
        self.n_classes_ = [
            self.bin_encoder_.classes_.size,
            self.cat_encoder_.classes_.size,
        ]
        # Save number of expected outputs in the Keras model
        # SciKeras will automatically use this to do error-checking
        self.n_outputs_expected_ = 2
        return self

    def transform(self, y: np.ndarray) -> List[np.ndarray]:
        y_bin, y_cat = y[:, 0], y[:, 1]
        # Apply transformers to input array
        y_bin = self.bin_encoder_.transform(y_bin)
        y_cat = self.cat_encoder_.transform(y_cat)
        # Split the data into a list
        return [y_bin, y_cat]

    def inverse_transform(self, y: List[np.ndarray], return_proba: bool = False) -> np.ndarray:
        y_pred_proba = y  # rename for clarity, what Keras gives us are probs
        if return_proba:
            return np.column_stack(y_pred_proba)
        # Get class predictions from probabilities
        y_pred_bin = (y_pred_proba[0] > 0.5).astype(int).reshape(-1, )
        y_pred_cat = np.argmax(y_pred_proba[1], axis=1)
        # Pass back through LabelEncoder
        y_pred_bin = self.bin_encoder_.inverse_transform(y_pred_bin)
        y_pred_cat = self.cat_encoder_.inverse_transform(y_pred_cat)
        return np.column_stack([y_pred_bin, y_pred_cat])

    def get_metadata(self):
        return {
            "n_classes_": self.n_classes_,
            "n_outputs_expected_": self.n_outputs_expected_,
        }

Note that in addition to the usual transform and inverse_transform methods, we implement the get_metadata method to return the n_classes_ attribute.

Let’s test our transformer with the same dataset we previously used to test our model:

[13]:
tf = MultiOutputTransformer()

y_sklearn = np.column_stack(y)

y_keras = tf.fit_transform(y_sklearn)
print("`y`, as will be passed to Keras:")
print([y_keras[0][:4], y_keras[1][:4]])
`y`, as will be passed to Keras:
[array([1, 0, 1, 0]), array([4, 2, 2, 1])]
[14]:
y_pred_sklearn = tf.inverse_transform(y_pred)
print("`y_pred`, as will be returned to sklearn:")
y_pred_sklearn[:5]
`y_pred`, as will be returned to sklearn:
[14]:
array([[1, 2],
       [0, 1],
       [0, 3],
       [1, 1],
       [0, 2]])
[15]:
print(f"metadata = {tf.get_metadata()}")
metadata = {'n_classes_': [2, 5], 'n_outputs_expected_': 2}

Since this looks good, we move on to integrating our transformer into our classifier.

[16]:
from sklearn.metrics import accuracy_score


class MultiOutputClassifier(KerasClassifier):

    @property
    def target_encoder(self):
        return MultiOutputTransformer()

    @staticmethod
    def scorer(y_true, y_pred, **kwargs):
        y_bin, y_cat = y_true[:, 0], y_true[:, 1]
        y_pred_bin, y_pred_cat = y_pred[:, 0], y_pred[:, 1]
        # Keras by default averages the losses of the individual outputs, so we do the same here
        return np.mean([accuracy_score(y_bin, y_pred_bin), accuracy_score(y_cat, y_pred_cat)])

3.3 Test classifier

[17]:
from sklearn.preprocessing import StandardScaler

# Use labels as features, just to make sure we can learn correctly
X = y_sklearn
X = StandardScaler().fit_transform(X)
[18]:
clf = MultiOutputClassifier(model=get_clf_model, verbose=0, random_state=0)

clf.fit(X, y_sklearn).score(X, y_sklearn)
[18]:
0.365

4. Multiple inputs

The process for multiple inputs is similar, but instead of overriding the transformer in target_encoder we override feature_encoder.

class MultiInputTransformer(BaseEstimator, TransformerMixin):
    ...


class MultiInputClassifier(KerasClassifier):

    @property
    def feature_encoder(self):
        return MultiInputTransformer(...)

4.1 Define Keras Model

Let’s define a Keras regression Model with 2 inputs:

[19]:
def get_reg_model():

    inp1 = keras.layers.Input(shape=(1, ))
    inp2 = keras.layers.Input(shape=(1, ))

    x1 = keras.layers.Dense(100, activation="relu")(inp1)
    x2 = keras.layers.Dense(50, activation="relu")(inp2)

    concat = keras.layers.Concatenate(axis=-1)([x1, x2])

    out = keras.layers.Dense(1)(concat)

    model = keras.Model(inputs=[inp1, inp2], outputs=out)
    model.compile(loss="mse")

    return model

And test it with a small mock dataset:

[20]:
X = np.random.random(size=(100, 2))
y = np.sum(X, axis=1)
X = np.split(X, 2, axis=1)

model = get_reg_model()

model.fit(X, y, verbose=0)
y_pred = model.predict(X).squeeze()
[21]:
from sklearn.metrics import r2_score

r2_score(y, y_pred)
[21]:
-5.3344010320539

Having verified that our model builds without errors and accepts the input types we expect, we move on to integrating a transformer into our SciKeras model.

4.2 Define data transformer

As with overriding target_encoder, we just need to define an sklearn transformer and drop it into our SciKeras wrapper. Since we hardcoded the input shapes into our model and do not rely on any transformer-generated metadata, we can simply use sklearn.preprocessing.FunctionTransformer:

[22]:
from sklearn.preprocessing import FunctionTransformer


class MultiInputRegressor(KerasRegressor):

    @property
    def feature_encoder(self):
        return FunctionTransformer(
            func=lambda X: [X[:, 0], X[:, 1]],
        )

Note that we did not implement inverse_transform (that is, we did not pass an inverse_func argument to FunctionTransformer) because features are never converted back to their original form.

4.3 Test regressor

[23]:
reg = MultiInputRegressor(model=get_reg_model, verbose=0, random_state=0)

X_sklearn = np.column_stack(X)

reg.fit(X_sklearn, y).score(X_sklearn, y)
[23]:
-3.2624139786426314

5. Multidimensional inputs with MNIST dataset

In this example, we look at how we can use SciKeras to process the MNIST dataset. The training set is composed of 60,000 images of digits, each of which is a 2D 28x28 image.

The dataset and Keras Model architecture used come from a Keras example. It may be beneficial to understand the Keras model by reviewing that example first.

[24]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train.shape
[24]:
(60000, 28, 28)

The outputs (labels) are numbers 0-9:

[25]:
print(y_train.shape)
print(np.unique(y_train))
(60000,)
[0 1 2 3 4 5 6 7 8 9]

First, we will “flatten” the data into an array of shape (n_samples, 28*28) (i.e. a 2D array). This will allow us to use sklearn ecosystem utilities, for example, sklearn.preprocessing.MinMaxScaler.

[26]:
from sklearn.preprocessing import MinMaxScaler

n_samples_train = x_train.shape[0]
n_samples_test = x_test.shape[0]

x_train = x_train.reshape((n_samples_train, -1))
x_test = x_test.reshape((n_samples_test, -1))
x_train = MinMaxScaler().fit_transform(x_train)
x_test = MinMaxScaler().fit_transform(x_test)

# reduce dataset size for faster training
n_samples = 1000
x_train, y_train, x_test, y_test = x_train[:n_samples], y_train[:n_samples], x_test[:n_samples], y_test[:n_samples]
[27]:
print(x_train.shape[1:])  # 784 = 28*28
(784,)
[28]:
print(np.min(x_train), np.max(x_train))  # scaled 0-1
0.0 1.0

Of course, in this case, we could have just as easily used numpy functions to scale our data, but we use MinMaxScaler to demonstrate use of the sklearn ecosystem.
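
For reference, a plain-numpy sketch of an equivalent step might look like the function below. It assumes the raw uint8 arrays straight from keras.datasets.mnist.load_data() (not the already-scaled arrays above) and rescales globally by the known 0-255 pixel range, whereas MinMaxScaler scales each pixel column by its own observed minimum and maximum.

def scale_mnist_numpy(x_raw: np.ndarray) -> np.ndarray:
    # Flatten to (n_samples, 784) and map the 0-255 pixel range onto 0-1.
    return x_raw.reshape((x_raw.shape[0], -1)).astype("float32") / 255.0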

5.1 Define Keras Model

Next we will define our Keras model (adapted from keras.io):

[29]:
num_classes = 10
input_shape = (28, 28, 1)


def get_model(meta):
    model = keras.Sequential(
        [
            keras.Input(input_shape),
            keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dropout(0.5),
            keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
    model.compile(
        loss="sparse_categorical_crossentropy"
    )
    return model

Now let’s define a transformer that we will use to reshape our input from the sklearn shape, (n_samples, 784), to the Keras shape, which will be (n_samples, 28, 28, 1).

[30]:
class MultiDimensionalClassifier(KerasClassifier):

    @property
    def feature_encoder(self):
        return FunctionTransformer(
            func=lambda X: X.reshape(X.shape[0], *input_shape),
        )
[31]:
clf = MultiDimensionalClassifier(
    model=get_model,
    epochs=10,
    batch_size=128,
    validation_split=0.1,
    random_state=0,
)

5.2 Test

Train and score the model (this takes some time).

[32]:
_ = clf.fit(x_train, y_train)
Epoch 1/10
8/8 [==============================] - 2s 205ms/step - loss: 2.2129 - val_loss: 1.9378
Epoch 2/10
8/8 [==============================] - 1s 99ms/step - loss: 1.7355 - val_loss: 1.4332
Epoch 3/10
8/8 [==============================] - 1s 101ms/step - loss: 1.2980 - val_loss: 1.1264
Epoch 4/10
8/8 [==============================] - 1s 100ms/step - loss: 0.9632 - val_loss: 0.8725
Epoch 5/10
8/8 [==============================] - 1s 96ms/step - loss: 0.7774 - val_loss: 0.7834
Epoch 6/10
8/8 [==============================] - 1s 100ms/step - loss: 0.6900 - val_loss: 0.8330
Epoch 7/10
8/8 [==============================] - 1s 89ms/step - loss: 0.6336 - val_loss: 0.6026
Epoch 8/10
8/8 [==============================] - 1s 87ms/step - loss: 0.5648 - val_loss: 0.7786
Epoch 9/10
8/8 [==============================] - 1s 101ms/step - loss: 0.5365 - val_loss: 0.5447
Epoch 10/10
8/8 [==============================] - 1s 98ms/step - loss: 0.4750 - val_loss: 0.6588
[33]:
score = clf.score(x_test, y_test)
print(f"Test score (accuracy): {score:.2f}")
8/8 [==============================] - 0s 23ms/step
Test score (accuracy): 0.76

6. Ragged datasets with tf.data.Dataset

SciKeras provides a third dependency injection point that operates on the entire dataset: X, y & sample_weight. This dataset_transformer is applied after target_encoder and feature_encoder. One use case for this dependency injection point is to transform data from tabular/array-like to the tf.data.Dataset format, which only requires iteration. We can use this to create a tf.data.Dataset of ragged tensors.

Note that dataset_transformer should accept a single dictionary as its argument to transform and fit, and return a single dictionary as well. More details on this are in the docs.
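
As a minimal sketch of that contract, a do-nothing dataset_transformer is just an identity mapping over the dictionary (the keys you will typically see are “x”, “y” and “sample_weight”):

from typing import Any, Dict

from sklearn.preprocessing import FunctionTransformer


def passthrough(data: Dict[str, Any]) -> Dict[str, Any]:
    # `data` bundles "x", "y", "sample_weight" and any other fit-time keyword
    # arguments; returning it unchanged leaves the dataset untouched.
    return data


noop_dataset_transformer = FunctionTransformer(passthrough)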

Let’s start by defining our data. We’ll have an extra “feature” that marks the observation index, but we’ll remove it when we deconstruct our data in the transformer.

[34]:
feature_1 = np.random.uniform(size=(10, ))
feature_2 = np.random.uniform(size=(10, ))
obs = [0, 0, 0, 1, 1, 2, 3, 3, 4, 4]

X = np.column_stack([feature_1, feature_2, obs]).astype("float32")

y = np.array(["class1"] * 5 + ["class2"] * 5, dtype=str)

Next, we define our dataset_transformer. We will do this by defining a custom forward transformation outside of the Keras model. Note that we do not define an inverse transformation since that is never used. Also note that dataset_transformer will always be called with X (i.e. the “x” key will always be populated), but will be called with y=None when used for predict. Thus, you should check whether y and sample_weight are None before doing any operations on them.

[35]:
from typing import Dict, Any

import tensorflow as tf


def ragged_transformer(data: Dict[str, Any]) -> Dict[str, Any]:
    x, y, sample_weight = data["x"], data.get("y", None), data.get("sample_weight", None)
    if y is not None:
        y = y.reshape(-1, 1 if len(y.shape) == 1 else y.shape[1])
        y = y[tf.RaggedTensor.from_value_rowids(y, x[:, -1]).row_starts().numpy()]
    if sample_weight is not None:
        sample_weight = sample_weight.reshape(-1, 1 if len(sample_weight.shape) == 1 else sample_weight.shape[1])
        sample_weight = sample_weight[tf.RaggedTensor.from_value_rowids(sample_weight, x[:, -1]).row_starts().numpy()]
    x = tf.RaggedTensor.from_value_rowids(x[:, :-1], x[:, -1])
    data["x"] = x
    if "y" in data:
        data["y"] = y
    if "sample_weight" in data:
        data["sample_weight"] = sample_weight
    return data

In this case, we chose to keep y and sample_weight as numpy arrays, which will allow us to re-use ClassWeightDataTransformer, the default dataset_transformer for KerasClassifier.

Let’s quickly test our transformer:

[36]:
data = ragged_transformer(dict(x=X, y=y, sample_weight=None))
print(type(data["x"]))
print(data["x"].shape)
<class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'>
(5, None, 2)

And the y=None case:

[37]:
data = ragged_transformer(dict(x=X, y=None, sample_weight=None))
print(type(data["x"]))
print(data["x"].shape)
<class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'>
(5, None, 2)

Everything looks good!

Because Keras will not accept a RaggedTensor directly, we will need to wrap our entire dataset into a tensorflow Dataset. We can do this by adding one more transformation step:

Next, we can add our transformers to our model. We use an sklearn Pipeline (generated via make_pipeline) to keep ClassWeightDataTransformer operational while implementing our custom transformation.

[38]:
def dataset_transformer(data: Dict[str, Any]) -> Dict[str, Any]:
    x_y_s = data["x"], data.get("y", None), data.get("sample_weight", None)
    data["x"] = tf.data.Dataset.from_tensor_slices(x_y_s)
    # don't blindly assign y & sw; if being called from
    # predict they should not just be None, they should not be present at all!
    if "y" in data:
        data["y"] = None
    if "sample_weight" in data:
        data["sample_weight"] = None
    return data
[39]:
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import make_pipeline


class RaggedClassifier(KerasClassifier):

    @property
    def dataset_transformer(self):
        t1 = FunctionTransformer(ragged_transformer)
        t2 = super().dataset_transformer  # ClassWeightDataTransformer
        t3 = FunctionTransformer(dataset_transformer)
        t4 = "passthrough"  # see https://scikit-learn.org/stable/modules/compose.html#pipeline-chaining-estimators
        return make_pipeline(t1, t2, t3, t4)

Now we can define a Model. We need some way to handle/flatten our ragged arrays within our model. For this example, we use a custom mean layer, but you could use an Embedding layer, LSTM, etc.

[40]:
from tensorflow import reduce_mean, reshape
from tensorflow.keras import Sequential, layers


class CustomMean(layers.Layer):

    def __init__(self, axis=None):
        super(CustomMean, self).__init__()
        self._supports_ragged_inputs = True
        self.axis = axis

    def call(self, inputs, **kwargs):
        input_shape = inputs.get_shape()
        return reshape(reduce_mean(inputs, axis=self.axis), (1, *input_shape[1:]))


def get_model(meta):
    inp_shape = meta["X_shape_"][1]-1
    model = Sequential([
        layers.Input(shape=(inp_shape,), ragged=True),
        CustomMean(axis=0),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

And attach our model to our classifier wrapper:

[41]:
clf = RaggedClassifier(get_model, loss="bce")

Finally, let’s train and predict:

[42]:
clf.fit(X, y)
y_pred = clf.predict(X)
y_pred
5/5 [==============================] - 0s 3ms/step - loss: 0.6816
5/5 [==============================] - 0s 2ms/step
[42]:
array(['class1', 'class1', 'class1', 'class1', 'class1'], dtype='<U6')

If we define our custom layers, transformers and wrappers in their own module, we can easily create a self-contained classifier that is able to handle ragged datasets and has a clean Scikit-Learn compatible API:

[43]:
class RaggedClassifier(KerasClassifier):

    @property
    def dataset_transformer(self):
        t1 = FunctionTransformer(ragged_transformer)
        t2 = super().dataset_transformer  # ClassWeightDataTransformer
        t3 = FunctionTransformer(dataset_transformer)
        t4 = "passthrough"  # see https://scikit-learn.org/stable/modules/compose.html#pipeline-chaining-estimators
        return make_pipeline(t1, t2, t3, t4)

    def _keras_build_fn(self):
        inp_shape = self.X_shape_[1] - 1
        model = Sequential([
            layers.Input(shape=(inp_shape,), ragged=True),
            CustomMean(axis=0),
            layers.Dense(1, activation='sigmoid')
        ])
        return model
[44]:
clf = RaggedClassifier(loss="bce")
clf.fit(X, y)
y_pred = clf.predict(X)
y_pred
5/5 [==============================] - 0s 1ms/step - loss: 0.7404
5/5 [==============================] - 0s 2ms/step
[44]:
array(['class2', 'class2', 'class2', 'class2', 'class2'], dtype='<U6')

7. Multi-output class_weight

In this example, we will use dataset_transformer to support multi-output class weights. We will re-use our MultiOutputTransformer from our previous example to split the output, then we will create sample_weight from class_weight.

[45]:
from collections import defaultdict
from typing import Union

from sklearn.utils.class_weight import compute_sample_weight


class DatasetTransformer(BaseEstimator, TransformerMixin):

    def __init__(self, output_names):
        self.output_names = output_names

    def fit(self, data: Dict[str, Any]) -> "DatasetTransformer":
        return self

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        class_weight = data.get("class_weight", None)
        if class_weight is None:
            return data
        if isinstance(class_weight, str):  # handle "balanced"
            class_weight_ = class_weight
            class_weight = defaultdict(lambda: class_weight_)
        y, sample_weight = data.get("y", None), data.get("sample_weight", None)
        assert sample_weight is None, "Cannot use class_weight & sample_weight together"
        if y is not None:
            # y should be a list of arrays, as split up by MultiOutputTransformer
            sample_weight = {
                output_name: compute_sample_weight(class_weight[output_num], output_data)
                for output_num, (output_name, output_data) in enumerate(zip(self.output_names, y))
            }
            # Note: class_weight is expected to be indexable by output_number in sklearn
            # see https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_sample_weight.html
            # It is trivial to change the expected format to match Keras' ({output_name: weights, ...})
            # see https://github.com/keras-team/keras/issues/4735#issuecomment-267473722
            data["sample_weight"] = sample_weight
            data["class_weight"] = None
        return data


def get_model(meta, compile_kwargs):
    inp = keras.layers.Input(shape=(meta["n_features_in_"],))
    x1 = keras.layers.Dense(100, activation="relu")(inp)
    out_bin = keras.layers.Dense(1, activation="sigmoid")(x1)
    out_cat = keras.layers.Dense(meta["n_classes_"][1], activation="softmax")(x1)
    model = keras.Model(inputs=inp, outputs=[out_bin, out_cat])
    model.compile(
        loss=["binary_crossentropy", "sparse_categorical_crossentropy"],
        optimizer=compile_kwargs["optimizer"]
    )
    return model


class CustomClassifier(KerasClassifier):

    @property
    def target_encoder(self):
        return MultiOutputTransformer()

    @property
    def dataset_transformer(self):
        return DatasetTransformer(
            output_names=self.model_.output_names,
        )

Next, we define the data. We’ll use sklearn.datasets.make_blobs to generate a relatively noisy dataset:

[46]:
from sklearn.datasets import make_blobs


X, y = make_blobs(centers=3, random_state=0, cluster_std=20)
# make a binary target: "is this sample the same class as the first sample?"
y_bin = y == y[0]
y = np.column_stack([y_bin, y])

Test the model without specifying class weighting:

[47]:
clf = CustomClassifier(get_model, epochs=100, verbose=0, random_state=0)
clf.fit(X, y)
y_pred = clf.predict(X)
(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)
print(counts_bin)
(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)
print(counts_cat)
[91  9]
[28 30 42]

As you can see, without class_weight="balanced", our classifier mainly predicts a single class for the first output. Now with class_weight="balanced":

[48]:
clf = CustomClassifier(get_model, class_weight="balanced", epochs=100, verbose=0, random_state=0)
clf.fit(X, y)
y_pred = clf.predict(X)
(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)
print(counts_bin)
(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)
print(counts_cat)
[57 43]
[27 27 46]

Now, we get (mostly) balanced classes. But what if we want to specify our class weights manually? You will notice that when we defined DatasetTransformer, we gave it the ability to handle a list of class weights. For demonstration purposes, we will bias heavily towards the second class in each output:

[49]:
clf = CustomClassifier(get_model, class_weight=[{0: 0.1, 1: 1}, {0: 0.1, 1: 1, 2: 0.1}], epochs=100, verbose=0, random_state=0)
clf.fit(X, y)
y_pred = clf.predict(X)
(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)
print(counts_bin)
(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)
print(counts_cat)
[ 7 93]
[ 2 98]

Or mixing the two methods, because our first output is unbalanced but our second is (presumably) balanced:

[50]:
clf = CustomClassifier(get_model, class_weight=["balanced", None], epochs=100, verbose=0, random_state=0)
clf.fit(X, y)
y_pred = clf.predict(X)
(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)
print(counts_bin)
(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)
print(counts_cat)
[57 43]
[30 25 45]

8. Custom validation dataset

Although dataset_transformer is primarily designed for data transformations, because it returns valid **kwargs to fit, it can be used for other advanced use cases. In this example, we use dataset_transformer to implement a custom train/test split for Keras’ internal validation. We’ll use sklearn’s train_test_split, but this could be implemented via an arbitrary user function, e.g. to ensure a balanced class distribution.

[51]:
from sklearn.model_selection import train_test_split


def get_clf(meta: Dict[str, Any]):
    inp = keras.layers.Input(shape=(meta["n_features_in_"],))
    x1 = keras.layers.Dense(100, activation="relu")(inp)
    out = keras.layers.Dense(1, activation="sigmoid")(x1)
    return keras.Model(inputs=inp, outputs=out)


class CustomSplit(BaseEstimator, TransformerMixin):

    def __init__(self, test_size: float):
        self.test_size = test_size

    def fit(self, data: Dict[str, Any]) -> "CustomSplit":
        return self

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        if self.test_size == 0:
            return data
        x, y, sw = data["x"], data.get("y", None), data.get("sample_weight", None)
        if y is None:
            return data
        if sw is None:
            x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=self.test_size, stratify=y)
            validation_data = (x_val, y_val)
            sw_train = None
        else:
            x_train, x_val, y_train, y_val, sw_train, sw_val = train_test_split(x, y, sw, test_size=self.test_size, stratify=y)
            validation_data = (x_val, y_val, sw_val)
        data["validation_data"] = validation_data
        data["x"], data["y"], data["sample_weight"] = x_train, y_train, sw_train
        return data


class CustomClassifier(KerasClassifier):

    @property
    def dataset_transformer(self):
        return CustomSplit(test_size=self.validation_split)

And now let’s test with a toy dataset. We specifically choose to make the targets strings to show that, with this approach, we can preserve all of the nice data pre-processing that SciKeras does for us, while still being able to split the final data before passing it to Keras.

[52]:
y = np.array(["a"] * 900 + ["b"] * 100)
X = np.array([0] * 900 + [1] * 100).reshape(-1, 1)

To get a baseline measurement to compare against, we’ll first run with plain KerasClassifier as a benchmark.

[53]:
clf = KerasClassifier(
    get_clf,
    loss="bce",
    metrics=["binary_accuracy"],
    verbose=False,
    validation_split=0.1,
    shuffle=False,
    random_state=0,
    epochs=10
)

clf.fit(X, y)
print(f"binary_accuracy = {clf.history_['binary_accuracy'][-1]}")
print(f"val_binary_accuracy = {clf.history_['val_binary_accuracy'][-1]}")
binary_accuracy = 1.0
val_binary_accuracy = 0.0

We see that we get near zero validation accuracy. Because one of our classes was only found in the tail end of our dataset and we specified validation_split=0.1, we validated with a class we had never seen before.

We could specify shuffle=True (this is actually the default), but for highly imbalanced classes, this may not be as good as stratified splitting.
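
To see the problem in isolation (a quick sketch using only numpy and sklearn’s train_test_split), compare the classes found in the tail slice that validation_split=0.1 uses against those in a stratified 10% split:

val_tail = y[-100:]  # the last 10% of the data, which validation_split=0.1 validates on
print(np.unique(val_tail))         # only one class is present
_, y_val_strat = train_test_split(y, test_size=0.1, stratify=y)
print(np.unique(y_val_strat))      # both classes are present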

So let’s test our new CustomClassifier.

[54]:
clf = CustomClassifier(
    get_clf,
    loss="bce",
    metrics=["binary_accuracy"],
    verbose=False,
    validation_split=0.1,
    shuffle=False,
    random_state=0,
    epochs=10
)

clf.fit(X, y)
print(f"binary_accuracy = {clf.history_['binary_accuracy'][-1]}")
print(f"val_binary_accuracy = {clf.history_['val_binary_accuracy'][-1]}")
binary_accuracy = 1.0
val_binary_accuracy = 1.0

Much better!

9. Dynamically setting batch_size

In this example, we use the dataset_transformer interface to implement a dynamic batch_size, similar to sklearn’s MLPClassifier. We will implement batch_size as batch_size=min(200, n_samples).

[55]:
from sklearn.model_selection import train_test_split


def check_batch_size(x):
    """Check the batch_size used in training.
    """
    bs = x.shape[0]
    if bs is not None:
        print(f"batch_size={bs}")
    return x


def get_clf(meta: Dict[str, Any]):
    inp = keras.layers.Input(shape=(meta["n_features_in_"],))
    x1 = keras.layers.Dense(100, activation="relu")(inp)
    x2 = keras.layers.Lambda(check_batch_size)(x1)
    out = keras.layers.Dense(1, activation="sigmoid")(x2)
    return keras.Model(inputs=inp, outputs=out)


class DynamicBatch(BaseEstimator, TransformerMixin):

    def fit(self, data: Dict[str, Any]) -> "DynamicBatch":
        return self

    def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:
        n_samples = data["x"].shape[0]
        data["batch_size"] = min(200, n_samples)
        return data


class DynamicBatchClassifier(KerasClassifier):

    @property
    def dataset_transformer(self):
        return DynamicBatch()

Since this happens inside SciKeras, it will work even when we are doing cross-validation (where the size of each split depends on cv).

[56]:
from sklearn.model_selection import cross_val_score

clf = DynamicBatchClassifier(
    get_clf,
    loss="bce",
    verbose=False,
    random_state=0
)

_ = cross_val_score(clf, X, y, cv=6)  # note: 1000 / 6 ≈ 167
batch_size=167
batch_size=167
batch_size=167
batch_size=167
batch_size=166
batch_size=166

But if we train with larger inputs, we can hit the cap of 200 we set:

[57]:
_ = cross_val_score(clf, X, y, cv=5)
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200
batch_size=200