{ "cells": [ { "cell_type": "raw", "id": "breeding-reader", "metadata": {}, "source": [ "Run in Google Colab" ] }, { "cell_type": "markdown", "id": "mysterious-charity", "metadata": {}, "source": [ "# Data Transformers\n", "\n", "Keras support many types of input and output data formats, including:\n", "\n", "* Multiple inputs\n", "* Multiple outputs\n", "* Higher-dimensional tensors\n", "\n", "This notebook walks through an example of the different data transformations and how SciKeras bridges Keras and Scikit-learn.\n", "It may be helpful to have a general understanding of the dataflow before tackling these examples, which is available in\n", "the [data transformer docs](https://www.adriangb.com/scikeras/refs/heads/master/advanced.html#data-transformers).\n", "\n", "## Table of contents\n", "\n", "* [1. Setup](#1.-Setup)\n", "* [2. Multiple outputs](#2.-Multiple-outputs)\n", " * [2.1 Define Keras Model](#2.1-Define-Keras-Model)\n", " * [2.2 Define output data transformer](#2.2-Define-output-data-transformer)\n", " * [2.3 Test classifier](#2.3-Test-classifier)\n", "* [3. Multiple inputs](#3-multiple-inputs)\n", " * [3.1 Define Keras Model](#3.1-Define-Keras-Model)\n", " * [3.2 Define data transformer](#3.2-Define-data-transformer)\n", " * [3.3 Test regressor](#3.3-Test-regressor)\n", "* [4. Multidimensional inputs with MNIST dataset](#4.-Multidimensional-inputs-with-MNIST-dataset)\n", " * [4.1 Define Keras Model](#4.1-Define-Keras-Model)\n", " * [4.2 Test](#4.2-Test)\n", "* [5. Ragged datasets with tf.data.Dataset](#5.-Ragged-datasets-with-tf.data.Dataset)\n", "* [6. Multi-output class_weight](#6.-Multi-output-class_weight)\n", "* [7. Custom validation dataset](#6.-Custom-validation-dataset)\n", "* [8. Dynamically setting batch_size](#6.-Dynamically-setting-batch_size)" ] }, { "cell_type": "markdown", "id": "compact-cutting", "metadata": {}, "source": [ "## 1. Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "junior-peter", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:17:59.989138Z", "iopub.status.busy": "2021-02-20T16:17:59.988371Z", "iopub.status.idle": "2021-02-20T16:18:02.144746Z", "shell.execute_reply": "2021-02-20T16:18:02.142135Z" } }, "outputs": [], "source": [ "try:\n", " import scikeras\n", "except ImportError:\n", " !python -m pip install scikeras" ] }, { "cell_type": "markdown", "id": "impossible-virus", "metadata": {}, "source": [ "Silence TensorFlow warnings to keep output succint." 
] },
{ "cell_type": "code", "execution_count": 2, "id": "intended-article", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "from tensorflow import get_logger\n", "get_logger().setLevel('ERROR')\n", "warnings.filterwarnings(\"ignore\", message=\"Setting the random state for TF\")" ] },
{ "cell_type": "code", "execution_count": 3, "id": "talented-layer", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from scikeras.wrappers import KerasClassifier, KerasRegressor\n", "from tensorflow import keras" ] },
{ "cell_type": "markdown", "id": "european-headset", "metadata": {}, "source": [ "## 2. Data transformer interface\n", "\n", "SciKeras enables advanced Keras use cases by providing an interface that converts sklearn-compliant data into whatever format your Keras model requires, right before that data is passed to the Keras model.\n", "\n", "This interface is implemented in the form of two sklearn transformers, one for the features (`X`) and one for the target (`y`). SciKeras loads these transformers via the `target_encoder` and `feature_encoder` methods.\n", "\n", "By default, SciKeras implements `target_encoder` for both KerasClassifier and KerasRegressor to facilitate common types of tasks in sklearn. The default implementations are `scikeras.utils.transformers.ClassifierLabelEncoder` and `scikeras.utils.transformers.RegressorTargetEncoder` for KerasClassifier and KerasRegressor respectively.\n",
"Information on the types of tasks that these default transformers are able to perform can be found in the [SciKeras docs](https://www.adriangb.com/scikeras/stable/advanced.html#data-transformers).\n", "\n", "Below is an outline of the inner workings of the data transformer interface, to help you understand when each method is called:" ] },
{ "cell_type": "code", "execution_count": 4, "id": "demonstrated-contest", "metadata": {}, "outputs": [], "source": [ "if False:  # avoid executing pseudocode\n", "    from scikeras.utils.transformers import (\n", "        ClassifierLabelEncoder,\n", "        RegressorTargetEncoder,\n", "    )\n", "\n", "\n", "    class BaseWrapper:\n", "        def fit(self, X, y):\n", "            self.target_encoder_ = self.target_encoder\n", "            self.feature_encoder_ = self.feature_encoder\n", "            y = self.target_encoder_.fit_transform(y)\n", "            X = self.feature_encoder_.fit_transform(X)\n", "            self.model_.fit(X, y)\n", "            return self\n", "\n", "        def predict(self, X):\n", "            X = self.feature_encoder_.transform(X)\n", "            y_pred = self.model_.predict(X)\n", "            return self.target_encoder_.inverse_transform(y_pred)\n", "\n", "\n", "    class KerasClassifier(BaseWrapper):\n", "\n", "        @property\n", "        def target_encoder(self):\n", "            return ClassifierLabelEncoder(loss=self.loss)\n", "\n", "        def predict_proba(self, X):\n", "            X = self.feature_encoder_.transform(X)\n", "            y_pred = self.model_.predict(X)\n", "            return self.target_encoder_.inverse_transform(y_pred, return_proba=True)\n", "\n", "\n", "    class KerasRegressor(BaseWrapper):\n", "\n", "        @property\n", "        def target_encoder(self):\n", "            return RegressorTargetEncoder()" ] },
{ "cell_type": "markdown", "id": "raising-anthony", "metadata": {}, "source": [ "To substitute your own data transformation routine, subclass the wrappers and override one of the encoder-defining properties. You will have access to all attributes of the wrappers, and you can pass these to your transformer, like we do above with `loss`." ] },
{ "cell_type": "code", "execution_count": 5, "id": "considered-repository", "metadata": {}, "outputs": [], "source": [ "from sklearn.base import BaseEstimator, TransformerMixin" ] },
{ "cell_type": "code", "execution_count": 6, "id": "theoretical-surge", "metadata": {}, "outputs": [], "source": [ "if False:  # avoid executing pseudocode\n", "\n", "    class MultiOutputTransformer(BaseEstimator, TransformerMixin):\n", "        ...\n", "\n", "\n", "    class MultiOutputClassifier(KerasClassifier):\n", "\n", "        @property\n", "        def target_encoder(self):\n", "            return MultiOutputTransformer(...)" ] },
{ "cell_type": "markdown", "id": "cardiovascular-process", "metadata": {}, "source": [ "### 2.1 get_metadata method\n", "\n", "SciKeras recognizes an optional `get_metadata` method on the transformers. `get_metadata` is expected to return a dictionary with string keys and arbitrary values.\n",
"SciKeras will add these items to the wrapper's namespace and make them available to your model-building function via the `meta` keyword argument:" ] },
{ "cell_type": "code", "execution_count": 7, "id": "figured-office", "metadata": {}, "outputs": [], "source": [ "if False:  # avoid executing pseudocode\n", "\n", "    class MultiOutputTransformer(BaseEstimator, TransformerMixin):\n", "        def get_metadata(self):\n", "            return {\"my_param_\": \"foobarbaz\"}\n", "\n", "\n", "    class MultiOutputClassifier(KerasClassifier):\n", "\n", "        @property\n", "        def target_encoder(self):\n", "            return MultiOutputTransformer(...)\n", "\n", "\n", "    def get_model(meta):\n", "        print(f\"Got: {meta['my_param_']}\")\n", "\n", "\n", "    clf = MultiOutputClassifier(model=get_model)\n", "    clf.fit(X, y)  # prints: Got: foobarbaz\n", "    print(clf.my_param_)  # foobarbaz" ] },
{ "cell_type": "markdown", "id": "frozen-belgium", "metadata": {}, "source": [ "## 3. Multiple outputs\n", "\n", "Keras makes it straightforward to define models with multiple outputs, that is, a Model with multiple sets of fully-connected heads at the end of the network. This functionality is only available in the Functional Model and subclassed Model definition modes; it is not available when using Sequential.\n", "\n", "In practice, the main thing you need to know about Keras models with multiple outputs as a SciKeras user is that Keras expects `X` or `y` to be a list of arrays/tensors, with one array/tensor for each input/output.\n", "\n", "Note that \"multiple outputs\" in Keras has a slightly different meaning than \"multiple outputs\" in sklearn. Many tasks that would be considered \"multiple output\" tasks in sklearn can be mapped to a single Keras \"output\" with multiple units. This notebook specifically focuses on the cases that require multiple distinct Keras outputs.\n", "\n", "### 3.1 Define Keras Model\n", "\n", "Here we define a simple perceptron that has two outputs, corresponding to one binary classification task and one multiclass classification task. For example, one output might be \"image has car\" (binary) and the other might be \"color of car in image\" (multiclass)." ] },
{ "cell_type": "code", "execution_count": 8, "id": "reported-matthew", "metadata": {}, "outputs": [], "source": [ "def get_clf_model(meta):\n", "    # a single input; its width comes from the transformer-provided metadata\n", "    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", "    x1 = keras.layers.Dense(100, activation=\"relu\")(inp)\n", "    # two output heads: one binary, one multiclass\n", "    out_bin = keras.layers.Dense(1, activation=\"sigmoid\")(x1)\n", "    out_cat = keras.layers.Dense(meta[\"n_classes_\"][1], activation=\"softmax\")(x1)\n", "    model = keras.Model(inputs=inp, outputs=[out_bin, out_cat])\n", "    model.compile(\n", "        loss=[\"binary_crossentropy\", \"sparse_categorical_crossentropy\"]\n", "    )\n", "    return model" ] },
{ "cell_type": "markdown", "id": "binary-vector", "metadata": {}, "source": [ "Let's test that this model works with the kinds of inputs and outputs we expect."
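] },
{ "cell_type": "markdown", "id": "added-heads-note", "metadata": {}, "source": [ "Before fitting, a quick structural sanity check (a sketch: `sanity_model` is just a throwaway name, and the mocked `meta` values mirror the mock built in the next cell) confirms that the model exposes two output heads:" ] },
{ "cell_type": "code", "execution_count": null, "id": "added-heads-check", "metadata": {}, "outputs": [], "source": [ "# sketch: the model should expose exactly two output tensors\n", "sanity_model = get_clf_model(meta={\"n_features_in_\": 10, \"n_classes_\": [2, 5]})\n", "print(len(sanity_model.outputs))  # 2: the binary head and the multiclass head\n", "print([out.shape for out in sanity_model.outputs])  # (None, 1) and (None, 5)"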
] },
{ "cell_type": "code", "execution_count": 9, "id": "extraordinary-baltimore", "metadata": {}, "outputs": [], "source": [ "X = np.random.random(size=(100, 10))\n", "y_bin = np.random.randint(0, 2, size=(100,))\n", "y_cat = np.random.randint(0, 5, size=(100,))\n", "y = [y_bin, y_cat]\n", "\n", "# build mock meta\n", "meta = {\n", "    \"n_features_in_\": 10,\n", "    \"n_classes_\": [2, 5]  # note that we made this a list, one entry per output\n", "}\n", "\n", "model = get_clf_model(meta=meta)\n", "\n", "model.fit(X, y, verbose=0)\n", "y_pred = model.predict(X)" ] },
{ "cell_type": "code", "execution_count": 10, "id": "subtle-oregon", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.5036109 ]\n", " [0.49061587]]\n" ] } ], "source": [ "print(y_pred[0][:2, :])" ] },
{ "cell_type": "code", "execution_count": 11, "id": "behind-range", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.16864142 0.21402779 0.22270538 0.2222759  0.17234947]\n", " [0.18863055 0.21719907 0.16864772 0.2111154  0.21440724]]\n" ] } ], "source": [ "print(y_pred[1][:2, :])" ] },
{ "cell_type": "markdown", "id": "chubby-plane", "metadata": {}, "source": [ "As you can see, our `predict` output is also a list of arrays, except that it contains probabilities instead of class predictions.\n", "\n", "Our data transformer's job will be to convert from a single numpy array (which is what the sklearn ecosystem works with) to the list of arrays and then back. Additionally, for classifiers, we will want to be able to convert probabilities to class predictions.\n", "\n", "We will structure our data on the sklearn side by column-stacking our list of arrays.\n",
"This works well in this case since we have the same number of data points in each array.\n", "\n", "### 3.2 Define output data transformer\n", "\n", "Let's go ahead and prototype this data transformer:" ] },
{ "cell_type": "code", "execution_count": 12, "id": "american-corps", "metadata": {}, "outputs": [], "source": [ "from typing import List\n", "\n", "from sklearn.base import BaseEstimator, TransformerMixin\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "\n", "class MultiOutputTransformer(BaseEstimator, TransformerMixin):\n", "\n", "    def fit(self, y):\n", "        y_bin, y_cat = y[:, 0], y[:, 1]\n", "        # Create internal encoders to ensure labels are 0, 1, 2...\n", "        self.bin_encoder_ = LabelEncoder()\n", "        self.cat_encoder_ = LabelEncoder()\n", "        # Fit them to the input data\n", "        self.bin_encoder_.fit(y_bin)\n", "        self.cat_encoder_.fit(y_cat)\n", "        # Save the number of classes\n", "        self.n_classes_ = [\n", "            self.bin_encoder_.classes_.size,\n", "            self.cat_encoder_.classes_.size,\n", "        ]\n", "        # Save number of expected outputs in the Keras model\n", "        # SciKeras will automatically use this to do error-checking\n", "        self.n_outputs_expected_ = 2\n", "        return self\n", "\n", "    def transform(self, y: np.ndarray) -> List[np.ndarray]:\n", "        y_bin, y_cat = y[:, 0], y[:, 1]\n", "        # Apply transformers to input array\n", "        y_bin = self.bin_encoder_.transform(y_bin)\n", "        y_cat = self.cat_encoder_.transform(y_cat)\n", "        # Split the data into a list\n", "        return [y_bin, y_cat]\n", "\n", "    def inverse_transform(self, y: List[np.ndarray], return_proba: bool = False) -> np.ndarray:\n", "        y_pred_proba = y  # rename for clarity: what Keras gives us are probabilities\n", "        if return_proba:\n", "            return np.column_stack(y_pred_proba)\n", "        # Get class predictions from probabilities\n", "        y_pred_bin = (y_pred_proba[0] > 0.5).astype(int).reshape(-1,)\n", "        y_pred_cat = np.argmax(y_pred_proba[1], axis=1)\n", "        # Pass back through LabelEncoder\n", "        y_pred_bin = self.bin_encoder_.inverse_transform(y_pred_bin)\n", "        y_pred_cat = self.cat_encoder_.inverse_transform(y_pred_cat)\n", "        return np.column_stack([y_pred_bin, y_pred_cat])\n", "\n", "    def get_metadata(self):\n", "        return {\n", "            \"n_classes_\": self.n_classes_,\n", "            \"n_outputs_expected_\": self.n_outputs_expected_,\n", "        }" ] },
{ "cell_type": "markdown", "id": "processed-browse", "metadata": {}, "source": [ "Note that in addition to the usual `transform` and `inverse_transform` methods, we implement the `get_metadata` method to return the `n_classes_` attribute.\n", "\n", "Let's test our transformer with the same dataset we previously used to test our model:" ] },
{ "cell_type": "code", "execution_count": 13, "id": "consolidated-chinese", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "`y`, as will be passed to Keras:\n", "[array([1, 0, 1, 0]), array([4, 2, 2, 1])]\n" ] } ], "source": [ "tf = MultiOutputTransformer()\n", "\n", "y_sklearn = np.column_stack(y)\n", "\n", "y_keras = tf.fit_transform(y_sklearn)\n", "print(\"`y`, as will be passed to Keras:\")\n",
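"# one array per Keras output head: the binary target first, then the multiclass target\n",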
"print([y_keras[0][:4], y_keras[1][:4]])" ] }, { "cell_type": "code", "execution_count": 14, "id": "comparative-device", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:03.529751Z", "iopub.status.busy": "2021-02-20T16:18:03.528686Z", "iopub.status.idle": "2021-02-20T16:18:03.542611Z", "shell.execute_reply": "2021-02-20T16:18:03.543346Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "`y_pred`, as will be returned to sklearn:\n" ] }, { "data": { "text/plain": [ "array([[1, 2],\n", " [0, 1],\n", " [0, 3],\n", " [1, 1],\n", " [0, 2]])" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred_sklearn = tf.inverse_transform(y_pred)\n", "print(\"`y_pred`, as will be returned to sklearn:\")\n", "y_pred_sklearn[:5]" ] }, { "cell_type": "code", "execution_count": 15, "id": "pregnant-bibliography", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:03.546988Z", "iopub.status.busy": "2021-02-20T16:18:03.545916Z", "iopub.status.idle": "2021-02-20T16:18:03.551671Z", "shell.execute_reply": "2021-02-20T16:18:03.552404Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "metadata = {'n_classes_': [2, 5], 'n_outputs_expected_': 2}\n" ] } ], "source": [ "print(f\"metadata = {tf.get_metadata()}\")" ] }, { "cell_type": "markdown", "id": "generous-webcam", "metadata": {}, "source": [ "Since this looks good, we move on to integrating our transformer into our classifier." ] }, { "cell_type": "code", "execution_count": 16, "id": "living-headline", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:03.555994Z", "iopub.status.busy": "2021-02-20T16:18:03.554825Z", "iopub.status.idle": "2021-02-20T16:18:03.561864Z", "shell.execute_reply": "2021-02-20T16:18:03.562613Z" } }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score\n", "\n", "\n", "class MultiOutputClassifier(KerasClassifier):\n", "\n", " @property\n", " def target_encoder(self):\n", " return MultiOutputTransformer()\n", " \n", " @staticmethod\n", " def scorer(y_true, y_pred, **kwargs):\n", " y_bin, y_cat = y_true[:, 0], y_true[:, 1]\n", " y_pred_bin, y_pred_cat = y_pred[:, 0], y_pred[:, 1]\n", " # Keras by default uses the mean of losses of each outputs, so here we do the same\n", " return np.mean([accuracy_score(y_bin, y_pred_bin), accuracy_score(y_cat, y_pred_cat)])" ] }, { "cell_type": "markdown", "id": "interstate-climb", "metadata": {}, "source": [ "### 2.3 Test classifier" ] }, { "cell_type": "code", "execution_count": 17, "id": "collaborative-majority", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:03.566976Z", "iopub.status.busy": "2021-02-20T16:18:03.565729Z", "iopub.status.idle": "2021-02-20T16:18:03.571109Z", "shell.execute_reply": "2021-02-20T16:18:03.574840Z" } }, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "# Use labels as features, just to make sure we can learn correctly\n", "X = y_sklearn\n", "X = StandardScaler().fit_transform(X)" ] }, { "cell_type": "code", "execution_count": 18, "id": "surprising-elevation", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:03.578436Z", "iopub.status.busy": "2021-02-20T16:18:03.577346Z", "iopub.status.idle": "2021-02-20T16:18:04.386154Z", "shell.execute_reply": "2021-02-20T16:18:04.386973Z" } }, "outputs": [ { "data": { "text/plain": [ "0.365" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf = 
"clf = MultiOutputClassifier(model=get_clf_model, verbose=0, random_state=0)\n", "\n", "clf.fit(X, y_sklearn).score(X, y_sklearn)" ] },
{ "cell_type": "markdown", "id": "therapeutic-mention", "metadata": {}, "source": [ "## 4. Multiple inputs\n", "\n", "The process for multiple inputs is similar, but instead of overriding the transformer in `target_encoder` we override `feature_encoder`.\n", "\n", "```python .noeval\n", "class MultiInputTransformer(BaseEstimator, TransformerMixin):\n", "    ...\n", "\n", "class MultiInputClassifier(KerasClassifier):\n", "    @property\n", "    def feature_encoder(self):\n", "        return MultiInputTransformer(...)\n", "```\n", "\n", "### 4.1 Define Keras Model\n", "\n", "Let's define a Keras **regression** Model with 2 inputs:" ] },
{ "cell_type": "code", "execution_count": 19, "id": "cooperative-deadline", "metadata": {}, "outputs": [], "source": [ "def get_reg_model():\n", "\n", "    inp1 = keras.layers.Input(shape=(1,))\n", "    inp2 = keras.layers.Input(shape=(1,))\n", "\n", "    x1 = keras.layers.Dense(100, activation=\"relu\")(inp1)\n", "    x2 = keras.layers.Dense(50, activation=\"relu\")(inp2)\n", "\n", "    concat = keras.layers.Concatenate(axis=-1)([x1, x2])\n", "\n", "    out = keras.layers.Dense(1)(concat)\n", "\n", "    model = keras.Model(inputs=[inp1, inp2], outputs=out)\n", "    model.compile(loss=\"mse\")\n", "\n", "    return model" ] },
{ "cell_type": "markdown", "id": "damaged-equivalent", "metadata": {}, "source": [ "And test it with a small mock dataset:" ] },
{ "cell_type": "code", "execution_count": 20, "id": "mounted-peeing", "metadata": {}, "outputs": [], "source": [ "X = np.random.random(size=(100, 2))\n", "y = np.sum(X, axis=1)\n", "X = np.split(X, 2, axis=1)\n", "\n", "model = get_reg_model()\n", "\n", "model.fit(X, y, verbose=0)\n", "y_pred = model.predict(X).squeeze()" ] },
{ "cell_type": "code", "execution_count": 21, "id": "thorough-dietary", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-5.3344010320539" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import r2_score\n", "\n", "r2_score(y, y_pred)" ] },
{ "cell_type": "markdown", "id": "novel-disorder", "metadata": {}, "source": [ "Having verified that our model builds without errors and accepts the input types we expect, we move on to integrating a transformer into our SciKeras model.\n", "\n", "### 4.2 Define data transformer\n", "\n", "Just like when overriding `target_encoder`, we only need to define a sklearn transformer and drop it into our SciKeras wrapper.\n",
"Since we hardcoded the input shapes into our model and do not rely on any transformer-generated metadata, we can simply use `sklearn.preprocessing.FunctionTransformer`:" ] },
{ "cell_type": "code", "execution_count": 22, "id": "dominant-problem", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import FunctionTransformer\n", "\n", "\n", "class MultiInputRegressor(KerasRegressor):\n", "\n", "    @property\n", "    def feature_encoder(self):\n", "        # split the 2D sklearn array into a list with one column per Keras input\n", "        return FunctionTransformer(\n", "            func=lambda X: [X[:, 0], X[:, 1]],\n", "        )" ] },
{ "cell_type": "markdown", "id": "muslim-comparison", "metadata": {}, "source": [ "Note that we did **not** implement `inverse_transform` (that is, we did not pass an `inverse_func` argument to `FunctionTransformer`) because features are never converted back to their original form.\n", "\n", "### 4.3 Test regressor" ] },
{ "cell_type": "code", "execution_count": 23, "id": "unauthorized-marathon", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-3.2624139786426314" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reg = MultiInputRegressor(model=get_reg_model, verbose=0, random_state=0)\n", "\n", "X_sklearn = np.column_stack(X)\n", "\n", "reg.fit(X_sklearn, y).score(X_sklearn, y)" ] },
{ "cell_type": "markdown", "id": "geographic-myanmar", "metadata": {}, "source": [ "## 5. Multidimensional inputs with MNIST dataset\n", "\n", "In this example, we look at how we can use SciKeras to process the MNIST dataset. The dataset is composed of 60,000 images of digits, each of which is a 2D 28x28 image.\n", "\n", "The dataset and Keras Model architecture used come from a [Keras example](https://keras.io/examples/vision/mnist_convnet/). It may be beneficial to understand the Keras model by reviewing that example first."
] },
{ "cell_type": "code", "execution_count": 24, "id": "insured-attitude", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(60000, 28, 28)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", "x_train.shape" ] },
{ "cell_type": "markdown", "id": "special-addition", "metadata": {}, "source": [ "The outputs (labels) are the digits 0-9:" ] },
{ "cell_type": "code", "execution_count": 25, "id": "mathematical-cassette", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(60000,)\n", "[0 1 2 3 4 5 6 7 8 9]\n" ] } ], "source": [ "print(y_train.shape)\n", "print(np.unique(y_train))" ] },
{ "cell_type": "markdown", "id": "faced-split", "metadata": {}, "source": [ "First, we will \"flatten\" the data into an array of shape `(n_samples, 28*28)` (i.e. a 2D array). This will allow us to use sklearn ecosystem utilities, for example `sklearn.preprocessing.MinMaxScaler`." ] },
{ "cell_type": "code", "execution_count": 26, "id": "prospective-instruction", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import MinMaxScaler\n", "\n", "n_samples_train = x_train.shape[0]\n", "n_samples_test = x_test.shape[0]\n", "\n", "x_train = x_train.reshape((n_samples_train, -1))\n", "x_test = x_test.reshape((n_samples_test, -1))\n", "# fit the scaler on the training data only, then apply it to both splits\n", "scaler = MinMaxScaler()\n", "x_train = scaler.fit_transform(x_train)\n", "x_test = scaler.transform(x_test)\n", "\n", "# reduce dataset size for faster training\n", "n_samples = 1000\n", "x_train, y_train, x_test, y_test = x_train[:n_samples], y_train[:n_samples], x_test[:n_samples], y_test[:n_samples]" ] },
{ "cell_type": "code", "execution_count": 27, "id": "artistic-sympathy", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(784,)\n" ] } ], "source": [ "print(x_train.shape[1:])  # 784 = 28*28" ] },
{ "cell_type": "code", "execution_count": 28, "id": "helpful-andorra", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0 1.0\n" ] } ], "source": [ "print(np.min(x_train), np.max(x_train))  # scaled 0-1" ] },
{ "cell_type": "markdown", "id": "cleared-equality", "metadata": {}, "source": [ "Of course, in this case we could have just as easily used numpy functions to scale our data, but we use `MinMaxScaler` to demonstrate use of the sklearn ecosystem.\n",
"\n", "### 5.1 Define Keras Model\n", "\n", "Next we will define our Keras model (adapted from [keras.io](https://keras.io/examples/vision/mnist_convnet/)):" ] },
{ "cell_type": "code", "execution_count": 29, "id": "eastern-frontier", "metadata": {}, "outputs": [], "source": [ "num_classes = 10\n", "input_shape = (28, 28, 1)\n", "\n", "\n", "def get_model(meta):\n", "    model = keras.Sequential(\n", "        [\n", "            keras.Input(input_shape),\n", "            keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n", "            keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", "            keras.layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n", "            keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", "            keras.layers.Flatten(),\n", "            keras.layers.Dropout(0.5),\n", "            keras.layers.Dense(num_classes, activation=\"softmax\"),\n", "        ]\n", "    )\n", "    model.compile(\n", "        loss=\"sparse_categorical_crossentropy\"\n", "    )\n", "    return model" ] },
{ "cell_type": "markdown", "id": "experienced-remark", "metadata": {}, "source": [ "Now let's define a transformer that we will use to reshape our input from the sklearn shape (`(n_samples, 784)`) to the Keras shape (which will be `(n_samples, 28, 28, 1)`)." ] },
{ "cell_type": "code", "execution_count": 30, "id": "found-faith", "metadata": {}, "outputs": [], "source": [ "class MultiDimensionalClassifier(KerasClassifier):\n", "\n", "    @property\n", "    def feature_encoder(self):\n", "        # reshape the flat sklearn array back into a stack of 28x28x1 images\n", "        return FunctionTransformer(\n", "            func=lambda X: X.reshape(X.shape[0], *input_shape),\n", "        )" ] },
{ "cell_type": "code", "execution_count": 31, "id": "dominant-refrigerator", "metadata": {}, "outputs": [], "source": [ "clf = MultiDimensionalClassifier(\n", "    model=get_model,\n", "    epochs=10,\n", "    batch_size=128,\n", "    validation_split=0.1,\n", "    random_state=0,\n", ")" ] },
{ "cell_type": "markdown", "id": "champion-waterproof", "metadata": {}, "source": [ "### 5.2 Test\n", "\n", "Train and score the model (this takes some time):" ] },
{ "cell_type": "code", "execution_count": 32, "id": "leading-growing", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "8/8 [==============================] - 2s 205ms/step - loss: 2.2129 - val_loss: 1.9378\n", "Epoch 2/10\n", "8/8 [==============================] - 1s 99ms/step - loss: 1.7355 - val_loss: 1.4332\n", "Epoch 3/10\n", "8/8 [==============================] - 1s 101ms/step - loss: 1.2980 - val_loss: 1.1264\n", "Epoch 4/10\n", "8/8 [==============================] - 1s 100ms/step - loss: 0.9632 - val_loss: 0.8725\n", "Epoch 5/10\n", "8/8 [==============================] - 1s 96ms/step - loss: 0.7774 - val_loss: 0.7834\n", "Epoch 6/10\n", "8/8 [==============================] - 1s 100ms/step - loss: 0.6900 - val_loss: 0.8330\n", "Epoch 7/10\n", "8/8 [==============================] - 1s 89ms/step - loss: 0.6336 - val_loss: 0.6026\n", "Epoch 8/10\n", "8/8 [==============================] - 1s 87ms/step - loss: 0.5648 - val_loss: 0.7786\n", "Epoch 9/10\n", "8/8 [==============================] - 1s 101ms/step - loss: 0.5365 - val_loss: 0.5447\n", "Epoch 10/10\n", "8/8 [==============================] - 1s 98ms/step - loss: 0.4750 - val_loss: 0.6588\n" ] } ], "source": [ "_ = clf.fit(x_train, y_train)" ] },
{ "cell_type": "code", "execution_count": 33, "id": "impressed-deposit", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "8/8 [==============================] - 0s 23ms/step\n", "Test score (accuracy): 0.76\n" ] } ], "source": [ "score = clf.score(x_test, y_test)\n", "print(f\"Test score (accuracy): {score:.2f}\")" ] },
{ "cell_type": "markdown", "id": "geological-international", "metadata": {}, "source": [ "## 6. Ragged datasets with tf.data.Dataset\n", "\n", "SciKeras provides a third dependency injection point that operates on the entire dataset: `X`, `y` & `sample_weight`.\n", "This `dataset_transformer` is applied after `target_encoder` and `feature_encoder`.\n", "One use case for this dependency injection point is to transform data from tabular/array-like formats to the `tf.data.Dataset` format, which only requires iteration.\n", "We can use this to create a `tf.data.Dataset` of ragged tensors.\n", "\n", "Note that `dataset_transformer` should accept a single dictionary as the argument to `transform` and `fit`, and return a single dictionary as well.\n", "More details on this are in the [docs](https://www.adriangb.com/scikeras/refs/heads/master/advanced.html#data-transformers).\n", "\n", "Let's start by defining our data. We'll have an extra \"feature\" that marks the observation index, but we'll remove it when we deconstruct our data in the transformer." ] },
{ "cell_type": "code", "execution_count": 34, "id": "ordinary-information", "metadata": {}, "outputs": [], "source": [ "feature_1 = np.random.uniform(size=(10,))\n", "feature_2 = np.random.uniform(size=(10,))\n", "obs = [0, 0, 0, 1, 1, 2, 3, 3, 4, 4]\n", "\n", "X = np.column_stack([feature_1, feature_2, obs]).astype(\"float32\")\n", "\n", "y = np.array([\"class1\"] * 5 + [\"class2\"] * 5, dtype=str)" ] },
{ "cell_type": "markdown", "id": "personal-delicious", "metadata": {}, "source": [ "Next, we define our `dataset_transformer`. We will do this by defining a custom forward transformation outside of the Keras model. Note that we do not define an inverse transformation, since that is never used.\n", "Also note that `dataset_transformer` will _always_ be called with `X` (i.e. the \"x\" key will always be populated), but will be called with `y=None` when used for `predict`. Thus,\n", "you should check whether `y` and `sample_weight` are None before doing any operations on them."
] },
{ "cell_type": "code", "execution_count": 35, "id": "geological-rochester", "metadata": {}, "outputs": [], "source": [ "from typing import Any, Dict\n", "\n", "import tensorflow as tf\n", "\n", "\n", "def ragged_transformer(data: Dict[str, Any]) -> Dict[str, Any]:\n", "    x, y, sample_weight = data[\"x\"], data.get(\"y\", None), data.get(\"sample_weight\", None)\n", "    if y is not None:\n", "        y = y.reshape(-1, 1 if len(y.shape) == 1 else y.shape[1])\n", "        # keep one label per observation group (rows of x sharing an index in x[:, -1])\n", "        y = y[tf.RaggedTensor.from_value_rowids(y, x[:, -1]).row_starts().numpy()]\n", "    if sample_weight is not None:\n", "        sample_weight = sample_weight.reshape(-1, 1 if len(sample_weight.shape) == 1 else sample_weight.shape[1])\n", "        sample_weight = sample_weight[tf.RaggedTensor.from_value_rowids(sample_weight, x[:, -1]).row_starts().numpy()]\n", "    # drop the observation-index column and gather the remaining features into ragged rows\n", "    x = tf.RaggedTensor.from_value_rowids(x[:, :-1], x[:, -1])\n", "    data[\"x\"] = x\n", "    if \"y\" in data:\n", "        data[\"y\"] = y\n", "    if \"sample_weight\" in data:\n", "        data[\"sample_weight\"] = sample_weight\n", "    return data" ] },
{ "cell_type": "markdown", "id": "urban-contemporary", "metadata": {}, "source": [ "In this case, we chose to keep `y` and `sample_weight` as numpy arrays, which will allow us to re-use ClassWeightDataTransformer,\n", "the default `dataset_transformer` for `KerasClassifier`.\n", "\n", "Let's quickly test our transformer:" ] },
{ "cell_type": "code", "execution_count": 36, "id": "electrical-inside", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'>\n", "(5, None, 2)\n" ] } ], "source": [ "data = ragged_transformer(dict(x=X, y=y, sample_weight=None))\n", "print(type(data[\"x\"]))\n", "print(data[\"x\"].shape)" ] },
{ "cell_type": "markdown", "id": "unauthorized-organic", "metadata": {}, "source": [ "And the `y=None` case:" ] },
{ "cell_type": "code", "execution_count": 37, "id": "hairy-glass", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'tensorflow.python.ops.ragged.ragged_tensor.RaggedTensor'>\n", "(5, None, 2)\n" ] } ], "source": [ "data = ragged_transformer(dict(x=X, y=None, sample_weight=None))\n", "print(type(data[\"x\"]))\n", "print(data[\"x\"].shape)" ] },
{ "cell_type": "markdown", "id": "protecting-range", "metadata": {}, "source": [ "Everything looks good!\n", "\n", "Because Keras will not accept a RaggedTensor directly, we will need to wrap our entire dataset in a tensorflow `Dataset`. We can do this by adding one more transformation step, shown below.\n", "\n", "Next, we can add our transformers to our model. We use an sklearn `Pipeline` (generated via `make_pipeline`) to keep ClassWeightDataTransformer operational while implementing our custom transformation."
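] },
{ "cell_type": "markdown", "id": "added-slices-note", "metadata": {}, "source": [ "Before wiring that step in, here is a standalone sketch (the variable names are ours) showing that `tf.data.Dataset.from_tensor_slices` can slice a ragged tensor into per-observation rows of varying length, which is what lets Keras iterate over our ragged data:" ] },
{ "cell_type": "code", "execution_count": null, "id": "added-slices-sketch", "metadata": {}, "outputs": [], "source": [ "# standalone sketch: a Dataset can iterate over the rows of a RaggedTensor\n", "rt = tf.RaggedTensor.from_row_lengths(tf.zeros((6, 2)), row_lengths=[3, 1, 2])\n", "ds = tf.data.Dataset.from_tensor_slices(rt)\n", "for row in ds:\n", "    print(row.shape)  # (3, 2), then (1, 2), then (2, 2)"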
] }, { "cell_type": "code", "execution_count": 38, "id": "divided-garbage", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:16.783941Z", "iopub.status.busy": "2021-02-20T16:18:16.782759Z", "iopub.status.idle": "2021-02-20T16:18:16.795179Z", "shell.execute_reply": "2021-02-20T16:18:16.794168Z" } }, "outputs": [], "source": [ "def dataset_transformer(data: Dict[str, Any]) -> Dict[str, Any]:\n", " x_y_s = data[\"x\"], data.get(\"y\", None), data.get(\"sample_weight\", None)\n", " data[\"x\"] = tf.data.Dataset.from_tensor_slices(x_y_s)\n", " # don't blindly assign y & sw; if being called from\n", " # predict they should not just be None, they should not be present at all!\n", " if \"y\" in data:\n", " data[\"y\"] = None\n", " if \"sample_weight\" in data:\n", " data[\"sample_weight\"] = None\n", " return data" ] }, { "cell_type": "code", "execution_count": 39, "id": "demonstrated-finger", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:16.803796Z", "iopub.status.busy": "2021-02-20T16:18:16.803242Z", "iopub.status.idle": "2021-02-20T16:18:16.808543Z", "shell.execute_reply": "2021-02-20T16:18:16.809291Z" } }, "outputs": [], "source": [ "from sklearn.preprocessing import FunctionTransformer\n", "from sklearn.pipeline import make_pipeline\n", "\n", "\n", "class RaggedClassifier(KerasClassifier):\n", "\n", " @property\n", " def dataset_transformer(self):\n", " t1 = FunctionTransformer(ragged_transformer)\n", " t2 = super().dataset_transformer # ClassWeightDataTransformer\n", " t3 = FunctionTransformer(dataset_transformer)\n", " t4 = \"passthrough\" # see https://scikit-learn.org/stable/modules/compose.html#pipeline-chaining-estimators\n", " return make_pipeline(t1, t2, t3, t4)" ] }, { "cell_type": "markdown", "id": "pacific-landing", "metadata": {}, "source": [ "Now we can define a Model. We need some way to handle/flatten our ragged arrays within our model. For this example, we use a custom mean layer, but you could use an Embedding layer, LSTM, etc." 
] }, { "cell_type": "code", "execution_count": 40, "id": "martial-argentina", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:16.818134Z", "iopub.status.busy": "2021-02-20T16:18:16.817553Z", "iopub.status.idle": "2021-02-20T16:18:16.828981Z", "shell.execute_reply": "2021-02-20T16:18:16.827085Z" } }, "outputs": [], "source": [ "from tensorflow import reduce_mean, reshape\n", "from tensorflow.keras import Sequential, layers\n", "\n", "\n", "class CustomMean(layers.Layer):\n", "\n", " def __init__(self, axis=None):\n", " super(CustomMean, self).__init__()\n", " self._supports_ragged_inputs = True\n", " self.axis = axis\n", "\n", " def call(self, inputs, **kwargs):\n", " input_shape = inputs.get_shape()\n", " return reshape(reduce_mean(inputs, axis=self.axis), (1, *input_shape[1:]))\n", "\n", "\n", "def get_model(meta):\n", " inp_shape = meta[\"X_shape_\"][1]-1\n", " model = Sequential([ \n", " layers.Input(shape=(inp_shape,), ragged=True),\n", " CustomMean(axis=0),\n", " layers.Dense(1, activation='sigmoid')\n", " ])\n", " return model" ] }, { "cell_type": "markdown", "id": "electric-boulder", "metadata": {}, "source": [ "And attach our model to our classifier wrapper:" ] }, { "cell_type": "code", "execution_count": 41, "id": "affecting-horror", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:16.837696Z", "iopub.status.busy": "2021-02-20T16:18:16.837146Z", "iopub.status.idle": "2021-02-20T16:18:16.841780Z", "shell.execute_reply": "2021-02-20T16:18:16.841286Z" } }, "outputs": [], "source": [ "clf = RaggedClassifier(get_model, loss=\"bce\")" ] }, { "cell_type": "markdown", "id": "grand-pitch", "metadata": {}, "source": [ "Finally, let's train and predict:" ] }, { "cell_type": "code", "execution_count": 42, "id": "extraordinary-ethiopia", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:16.846747Z", "iopub.status.busy": "2021-02-20T16:18:16.845614Z", "iopub.status.idle": "2021-02-20T16:18:17.475364Z", "shell.execute_reply": "2021-02-20T16:18:17.476056Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/5 [=====>........................] - ETA: 1s - loss: 0.6282" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "5/5 [==============================] - 0s 3ms/step - loss: 0.6816\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/5 [=====>........................] - ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "5/5 [==============================] - 0s 2ms/step\n" ] }, { "data": { "text/plain": [ "array(['class1', 'class1', 'class1', 'class1', 'class1'], dtype='........................] - ETA: 1s - loss: 0.8259" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "5/5 [==============================] - 0s 1ms/step - loss: 0.7404\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "1/5 [=====>........................] 
- ETA: 0s" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "5/5 [==============================] - 0s 2ms/step\n" ] }, { "data": { "text/plain": [ "array(['class2', 'class2', 'class2', 'class2', 'class2'], dtype=' \"DatasetTransformer\":\n", " return self\n", "\n", " def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:\n", " class_weight = data.get(\"class_weight\", None)\n", " if class_weight is None:\n", " return data\n", " if isinstance(class_weight, str): # handle \"balanced\"\n", " class_weight_ = class_weight\n", " class_weight = defaultdict(lambda: class_weight_)\n", " y, sample_weight = data.get(\"y\", None), data.get(\"sample_weight\", None)\n", " assert sample_weight is None, \"Cannot use class_weight & sample_weight together\"\n", " if y is not None:\n", " # y should be a list of arrays, as split up by MultiOutputTransformer\n", " sample_weight = {\n", " output_name: compute_sample_weight(class_weight[output_num], output_data)\n", " for output_num, (output_name, output_data) in enumerate(zip(self.output_names, y))\n", " }\n", " # Note: class_weight is expected to be indexable by output_number in sklearn\n", " # see https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_sample_weight.html\n", " # It is trivial to change the expected format to match Keras' ({output_name: weights, ...})\n", " # see https://github.com/keras-team/keras/issues/4735#issuecomment-267473722\n", " data[\"sample_weight\"] = sample_weight\n", " data[\"class_weight\"] = None\n", " return data\n", "\n", "\n", "def get_model(meta, compile_kwargs):\n", " inp = keras.layers.Input(shape=(meta[\"n_features_in_\"]))\n", " x1 = keras.layers.Dense(100, activation=\"relu\")(inp)\n", " out_bin = keras.layers.Dense(1, activation=\"sigmoid\")(x1)\n", " out_cat = keras.layers.Dense(meta[\"n_classes_\"][1], activation=\"softmax\")(x1)\n", " model = keras.Model(inputs=inp, outputs=[out_bin, out_cat])\n", " model.compile(\n", " loss=[\"binary_crossentropy\", \"sparse_categorical_crossentropy\"],\n", " optimizer=compile_kwargs[\"optimizer\"]\n", " )\n", " return model\n", "\n", "\n", "class CustomClassifier(KerasClassifier):\n", "\n", " @property\n", " def target_encoder(self):\n", " return MultiOutputTransformer()\n", " \n", " @property\n", " def dataset_transformer(self):\n", " return DatasetTransformer(\n", " output_names=self.model_.output_names,\n", " )" ] }, { "cell_type": "markdown", "id": "pending-comfort", "metadata": {}, "source": [ "Next, we define the data. 
"Next, we define the data. We'll use `sklearn.datasets.make_blobs` to generate a relatively noisy dataset:" ] }, { "cell_type": "code", "execution_count": 46, "id": "spatial-wagon", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:18.041166Z", "iopub.status.busy": "2021-02-20T16:18:18.040543Z", "iopub.status.idle": "2021-02-20T16:18:18.087386Z", "shell.execute_reply": "2021-02-20T16:18:18.088022Z" } }, "outputs": [], "source": [ "from sklearn.datasets import make_blobs\n", "\n", "\n", "X, y = make_blobs(centers=3, random_state=0, cluster_std=20)\n", "# make a binary target: \"is this sample the same class as the first sample?\"\n", "y_bin = y == y[0]\n", "y = np.column_stack([y_bin, y])" ] }, { "cell_type": "markdown", "id": "medical-louisville", "metadata": {}, "source": [ "Test the model without specifying class weighting:" ] }, { "cell_type": "code", "execution_count": 47, "id": "handmade-integration", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:18.090864Z", "iopub.status.busy": "2021-02-20T16:18:18.090315Z", "iopub.status.idle": "2021-02-20T16:18:20.003170Z", "shell.execute_reply": "2021-02-20T16:18:20.003910Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[91  9]\n", "[28 30 42]\n" ] } ], "source": [ "clf = CustomClassifier(get_model, epochs=100, verbose=0, random_state=0)\n", "clf.fit(X, y)\n", "y_pred = clf.predict(X)\n", "(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)\n", "print(counts_bin)\n", "(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)\n", "print(counts_cat)" ] }, { "cell_type": "markdown", "id": "hollywood-monday", "metadata": {}, "source": [ "As you can see, without `class_weight=\"balanced\"`, our classifier mainly predicts a single class for the first output. Now with `class_weight=\"balanced\"`:" ] }, { "cell_type": "code", "execution_count": 48, "id": "seventh-belief", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:20.009572Z", "iopub.status.busy": "2021-02-20T16:18:20.008041Z", "iopub.status.idle": "2021-02-20T16:18:21.970004Z", "shell.execute_reply": "2021-02-20T16:18:21.970776Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[57 43]\n", "[27 27 46]\n" ] } ], "source": [ "clf = CustomClassifier(get_model, class_weight=\"balanced\", epochs=100, verbose=0, random_state=0)\n", "clf.fit(X, y)\n", "y_pred = clf.predict(X)\n", "(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)\n", "print(counts_bin)\n", "(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)\n", "print(counts_cat)" ] }, { "cell_type": "markdown", "id": "strong-sociology", "metadata": {}, "source": [ "Now, we get (mostly) balanced classes. But what if we want to specify class weights manually? You will notice that when we defined `DatasetTransformer`, we gave it the ability to handle\n",
"a list of class weights. For demonstration purposes, we will heavily bias towards the second class in each output:" ] }, { "cell_type": "code", "execution_count": 49, "id": "killing-consequence", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:21.974452Z", "iopub.status.busy": "2021-02-20T16:18:21.973333Z", "iopub.status.idle": "2021-02-20T16:18:23.805388Z", "shell.execute_reply": "2021-02-20T16:18:23.804655Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 7 93]\n", "[ 2 98]\n" ] } ], "source": [ "clf = CustomClassifier(get_model, class_weight=[{0: 0.1, 1: 1}, {0: 0.1, 1: 1, 2: 0.1}], epochs=100, verbose=0, random_state=0)\n", "clf.fit(X, y)\n", "y_pred = clf.predict(X)\n", "(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)\n", "print(counts_bin)\n", "(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)\n", "print(counts_cat)" ] }, { "cell_type": "markdown", "id": "improving-venture", "metadata": {}, "source": [ "Or we can mix the two methods, because our first output is unbalanced but our second is (presumably) balanced:" ] }, { "cell_type": "code", "execution_count": 50, "id": "gentle-suggestion", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:23.816262Z", "iopub.status.busy": "2021-02-20T16:18:23.814551Z", "iopub.status.idle": "2021-02-20T16:18:25.089362Z", "shell.execute_reply": "2021-02-20T16:18:25.089862Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[57 43]\n", "[30 25 45]\n" ] } ], "source": [ "clf = CustomClassifier(get_model, class_weight=[\"balanced\", None], epochs=100, verbose=0, random_state=0)\n", "clf.fit(X, y)\n", "y_pred = clf.predict(X)\n", "(_, counts_bin) = np.unique(y_pred[:, 0], return_counts=True)\n", "print(counts_bin)\n", "(_, counts_cat) = np.unique(y_pred[:, 1], return_counts=True)\n", "print(counts_cat)" ] }, { "cell_type": "markdown", "id": "united-melissa", "metadata": {}, "source": [ "## 7. Custom validation dataset\n", "\n", "Although `dataset_transformer` is primarily designed for data transformations, because it returns valid `**kwargs` for `fit`, it can be used for other advanced use cases.\n", "In this example, we use `dataset_transformer` to implement a custom train/test split for Keras' internal validation. We'll use sklearn's\n", "`train_test_split`, but this could be implemented via an arbitrary user function, e.g. to ensure balanced class distribution."
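, "\n", "Conceptually (a simplified sketch with placeholder variables, not SciKeras' exact internals), whatever keys remain in the final dictionary are forwarded as keyword arguments to `Model.fit`:\n", "\n", "```python\n", "# placeholders: x/y are training data, x_val/y_val a held-out split\n", "data = {\"x\": x, \"y\": y, \"validation_data\": (x_val, y_val)}\n", "model.fit(**data)  # extra keys such as validation_data become fit kwargs\n", "```"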
] }, { "cell_type": "code", "execution_count": 51, "id": "upper-times", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:25.095957Z", "iopub.status.busy": "2021-02-20T16:18:25.095388Z", "iopub.status.idle": "2021-02-20T16:18:25.108742Z", "shell.execute_reply": "2021-02-20T16:18:25.109206Z" } }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "\n", "def get_clf(meta: Dict[str, Any]):\n", " inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", " x1 = keras.layers.Dense(100, activation=\"relu\")(inp)\n", " out = keras.layers.Dense(1, activation=\"sigmoid\")(x1)\n", " return keras.Model(inputs=inp, outputs=out)\n", "\n", "\n", "class CustomSplit(BaseEstimator, TransformerMixin):\n", "\n", " def __init__(self, test_size: float):\n", " self.test_size = test_size\n", " \n", " def fit(self, data: Dict[str, Any]) -> \"CustomSplit\":\n", " return self\n", "\n", " def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:\n", " if self.test_size == 0:\n", " return data\n", " x, y, sw = data[\"x\"], data.get(\"y\", None), data.get(\"sample_weight\", None)\n", " if y is None:\n", " return data\n", " if sw is None:\n", " x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=self.test_size, stratify=y)\n", " validation_data = (x_val, y_val)\n", " sw_train = None\n", " else:\n", " x_train, x_val, y_train, y_val, sw_train, sw_val = train_test_split(x, y, sw, test_size=self.test_size, stratify=y)\n", " validation_data = (x_val, y_val, sw_val)\n", " data[\"validation_data\"] = validation_data\n", " data[\"x\"], data[\"y\"], data[\"sample_weight\"] = x_train, y_train, sw_train\n", " return data\n", "\n", "\n", "class CustomClassifier(KerasClassifier):\n", "\n", " @property\n", " def dataset_transformer(self):\n", " return CustomSplit(test_size=self.validation_split)" ] }, { "cell_type": "markdown", "id": "mental-mozambique", "metadata": {}, "source": [ "And now lets test with a toy dataset. We specifically choose to make the target strings to show\n", "that with this approach, we can preserve all of the nice data pre-processing that SciKeras does\n", "for us, while still being able to split the final data before passing it to Keras." ] }, { "cell_type": "code", "execution_count": 52, "id": "burning-bookmark", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:25.111885Z", "iopub.status.busy": "2021-02-20T16:18:25.111351Z", "iopub.status.idle": "2021-02-20T16:18:25.115853Z", "shell.execute_reply": "2021-02-20T16:18:25.116274Z" } }, "outputs": [], "source": [ "y = np.array([\"a\"] * 900 + [\"b\"] * 100)\n", "X = np.array([0] * 900 + [1] * 100).reshape(-1, 1)" ] }, { "cell_type": "markdown", "id": "devoted-dominant", "metadata": {}, "source": [ "To get a base measurment to compare against, we'll run first with KerasClassifier as a benchmark." 
] }, { "cell_type": "code", "execution_count": 53, "id": "center-liquid", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:25.119656Z", "iopub.status.busy": "2021-02-20T16:18:25.118503Z", "iopub.status.idle": "2021-02-20T16:18:26.530645Z", "shell.execute_reply": "2021-02-20T16:18:26.530146Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "binary_accuracy = 1.0\n", "val_binary_accuracy = 0.0\n" ] } ], "source": [ "clf = KerasClassifier(\n", " get_clf,\n", " loss=\"bce\",\n", " metrics=[\"binary_accuracy\"],\n", " verbose=False,\n", " validation_split=0.1,\n", " shuffle=False,\n", " random_state=0,\n", " epochs=10\n", ")\n", "\n", "clf.fit(X, y)\n", "print(f\"binary_accuracy = {clf.history_['binary_accuracy'][-1]}\")\n", "print(f\"val_binary_accuracy = {clf.history_['val_binary_accuracy'][-1]}\")" ] }, { "cell_type": "markdown", "id": "driving-hebrew", "metadata": {}, "source": [ "We see that we get near zero validation accuracy. Because one of our classes was only found in the tail end of our dataset and we specified `validation_split=0.1`, we validated with a class we had never seen before.\n", "\n", "We could specify `shuffle=True` (this is actually the default), but for highly imbalanced classes, this may not be as good as stratified splitting.\n", "\n", "So lets test our new `CustomClassifier`." ] }, { "cell_type": "code", "execution_count": 54, "id": "studied-content", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:26.538037Z", "iopub.status.busy": "2021-02-20T16:18:26.535340Z", "iopub.status.idle": "2021-02-20T16:18:27.637081Z", "shell.execute_reply": "2021-02-20T16:18:27.637908Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "binary_accuracy = 1.0\n", "val_binary_accuracy = 1.0\n" ] } ], "source": [ "clf = CustomClassifier(\n", " get_clf,\n", " loss=\"bce\",\n", " metrics=[\"binary_accuracy\"],\n", " verbose=False,\n", " validation_split=0.1,\n", " shuffle=False,\n", " random_state=0,\n", " epochs=10\n", ")\n", "\n", "clf.fit(X, y)\n", "print(f\"binary_accuracy = {clf.history_['binary_accuracy'][-1]}\")\n", "print(f\"val_binary_accuracy = {clf.history_['val_binary_accuracy'][-1]}\")" ] }, { "cell_type": "markdown", "id": "lovely-assistant", "metadata": {}, "source": [ "Much better!" ] }, { "cell_type": "markdown", "id": "polyphonic-sweet", "metadata": {}, "source": [ "## 8. Dynamically setting batch_size" ] }, { "cell_type": "markdown", "id": "gentle-blake", "metadata": {}, "source": [ "In this tutorial, we use the `data_transformer` interface to implement a dynamic batch_size, similar to sklearn's [MLPClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html). We will implement `batch_size` as `batch_size=min(200, n_samples)`." 
] }, { "cell_type": "code", "execution_count": 55, "id": "interracial-spouse", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:27.641733Z", "iopub.status.busy": "2021-02-20T16:18:27.640598Z", "iopub.status.idle": "2021-02-20T16:18:27.650954Z", "shell.execute_reply": "2021-02-20T16:18:27.651725Z" } }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "\n", "def check_batch_size(x):\n", " \"\"\"Check the batch_size used in training.\n", " \"\"\"\n", " bs = x.shape[0]\n", " if bs is not None:\n", " print(f\"batch_size={bs}\")\n", " return x\n", "\n", "\n", "def get_clf(meta: Dict[str, Any]):\n", " inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n", " x1 = keras.layers.Dense(100, activation=\"relu\")(inp)\n", " x2 = keras.layers.Lambda(check_batch_size)(x1)\n", " out = keras.layers.Dense(1, activation=\"sigmoid\")(x2)\n", " return keras.Model(inputs=inp, outputs=out)\n", "\n", "\n", "class DynamicBatch(BaseEstimator, TransformerMixin):\n", "\n", " def fit(self, data: Dict[str, Any]) -> \"DynamicBatch\":\n", " return self\n", "\n", " def transform(self, data: Dict[str, Any]) -> Dict[str, Any]:\n", " n_samples = data[\"x\"].shape[0]\n", " data[\"batch_size\"] = min(200, n_samples)\n", " return data\n", "\n", "\n", "class DynamicBatchClassifier(KerasClassifier):\n", "\n", " @property\n", " def dataset_transformer(self):\n", " return DynamicBatch()" ] }, { "cell_type": "markdown", "id": "portuguese-warning", "metadata": {}, "source": [ "Since this is happening inside SciKeras, this will work even if we are doing cross validation (which adjusts the split according to `cv`)." ] }, { "cell_type": "code", "execution_count": 56, "id": "bottom-citation", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:27.655319Z", "iopub.status.busy": "2021-02-20T16:18:27.654236Z", "iopub.status.idle": "2021-02-20T16:18:30.696383Z", "shell.execute_reply": "2021-02-20T16:18:30.695887Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch_size=167\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=167\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=167\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=167\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=166\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=166\n" ] } ], "source": [ "from sklearn.model_selection import cross_val_score\n", "\n", "clf = DynamicBatchClassifier(\n", " get_clf,\n", " loss=\"bce\",\n", " verbose=False,\n", " random_state=0\n", ")\n", "\n", "_ = cross_val_score(clf, X, y, cv=6) # note: 1000 / 6 = 167" ] }, { "cell_type": "markdown", "id": "broke-drill", "metadata": {}, "source": [ "But if we train with larger inputs, we can hit the cap of 200 we set:" ] }, { "cell_type": "code", "execution_count": 57, "id": "bibliographic-somerset", "metadata": { "execution": { "iopub.execute_input": "2021-02-20T16:18:30.701194Z", "iopub.status.busy": "2021-02-20T16:18:30.700645Z", "iopub.status.idle": "2021-02-20T16:18:33.129174Z", "shell.execute_reply": "2021-02-20T16:18:33.128702Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n", "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n", "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n", "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n", "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "batch_size=200\n" ] } ], "source": [ "_ = cross_val_score(clf, X, y, cv=5)" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.7" } }, "nbformat": 4, "nbformat_minor": 5 }