{ "cells": [ { "cell_type": "raw", "id": "dbfc17f3", "metadata": {}, "source": [ "Run in Google Colab" ] }, { "cell_type": "markdown", "id": "fcbc5690", "metadata": {}, "source": [ "# Basic usage\n", "\n", "`SciKeras` is designed to maximize interoperability between `sklearn` and `Keras/TensorFlow`. The aim is to keep 99% of the flexibility of `Keras` while being able to leverage most features of `sklearn`. Below, we show the basic usage of `SciKeras` and how it can be combined with `sklearn`.\n", "\n", "This notebook shows you how to use the basic functionality of `SciKeras`.\n", "\n", "## Table of contents\n", "\n", "* [1. Setup](#1.-Setup)\n", "* [2. Training a classifier and making predictions](#2.-Training-a-classifier-and-making-predictions)\n", " * [2.1 A toy binary classification task](#2.1-A-toy-binary-classification-task)\n", " * [2.2 Definition of the Keras classification Model](#2.2-Definition-of-the-Keras-classification-Model)\n", " * [2.3 Defining and training the neural net classifier](#2.3-Defining-and-training-the-neural-net-classifier)\n", " * [2.4 Making predictions, classification](#2.4-Making-predictions-classification)\n", "* [3 Training a regressor](#3.-Training-a-regressor)\n", " * [3.1 A toy regression task](#3.1-A-toy-regression-task)\n", " * [3.2 Definition of the Keras regression Model](#3.2-Definition-of-the-Keras-regression-Model)\n", " * [3.3 Defining and training the neural net regressor](#3.3-Defining-and-training-the-neural-net-regressor)\n", " * [3.4 Making predictions, regression](#3.4-Making-predictions-regression)\n", "* [4. Saving and loading a model](#4.-Saving-and-loading-a-model)\n", " * [4.1 Saving the whole model](#4.1-Saving-the-whole-model)\n", " * [4.2 Saving using Keras' saving methods](#4.2-Saving-using-Keras-saving-methods)\n", "* [5. Usage with an sklearn Pipeline](#5.-Usage-with-an-sklearn-Pipeline)\n", "* [6. Callbacks](#6.-Callbacks)\n", "* [7. Usage with sklearn GridSearchCV](#7.-Usage-with-sklearn-GridSearchCV)\n", " * [7.1 Special prefixes](#7.1-Special-prefixes)\n", " * [7.2 Performing a grid search](#7.2-Performing-a-grid-search)\n", "\n", "## 1. Setup" ] }, { "cell_type": "code", "execution_count": 1, "id": "ea482e2d", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:20.157742Z", "iopub.status.busy": "2024-12-12T21:41:20.157253Z", "iopub.status.idle": "2024-12-12T21:41:27.768285Z", "shell.execute_reply": "2024-12-12T21:41:27.767583Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1734039680.907708 2099 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1734039680.910718 2099 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] } ], "source": [ "try:\n", " import scikeras\n", "except ImportError:\n", " !python -m pip install scikeras" ] }, { "cell_type": "markdown", "id": "5bfa4f2f", "metadata": {}, "source": [ "Silence TensorFlow logging to keep output succinct." ] }, { "cell_type": "code", "execution_count": 2, "id": "49efb8b5", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:27.771115Z", "iopub.status.busy": "2024-12-12T21:41:27.770420Z", "iopub.status.idle": "2024-12-12T21:41:27.774709Z", "shell.execute_reply": "2024-12-12T21:41:27.773903Z" } }, "outputs": [], "source": [ "import warnings\n", "from tensorflow import get_logger\n", "get_logger().setLevel('ERROR')\n", "warnings.filterwarnings(\"ignore\", message=\"Setting the random state for TF\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "e0ded38d", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:27.776888Z", "iopub.status.busy": "2024-12-12T21:41:27.776500Z", "iopub.status.idle": "2024-12-12T21:41:29.165204Z", "shell.execute_reply": "2024-12-12T21:41:29.164423Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from scikeras.wrappers import KerasClassifier, KerasRegressor\n", "import keras" ] }, { "cell_type": "markdown", "id": "988f74be", "metadata": {}, "source": [ "## 2. Training a classifier and making predictions\n", "\n", "### 2.1 A toy binary classification task\n", "\n", "We load a toy classification task from `sklearn`." ] }, { "cell_type": "code", "execution_count": 4, "id": "08757ad6", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:29.167603Z", "iopub.status.busy": "2024-12-12T21:41:29.167280Z", "iopub.status.idle": "2024-12-12T21:41:29.300290Z", "shell.execute_reply": "2024-12-12T21:41:29.299703Z" } }, "outputs": [ { "data": { "text/plain": [ "((1000, 20), (1000,), np.float64(0.5))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "from sklearn.datasets import make_classification\n", "\n", "\n", "X, y = make_classification(1000, 20, n_informative=10, random_state=0)\n", "\n", "X.shape, y.shape, y.mean()" ] }, { "cell_type": "markdown", "id": "0edaefe4", "metadata": {}, "source": [ "### 2.2 Definition of the Keras classification Model\n", "\n", "We define a vanilla neural network with.\n", "\n", "Because we are dealing with 2 classes, the output layer can be constructed in\n", "two different ways:\n", "\n", "1. Single unit with a `\"sigmoid\"` nonlinearity. The loss must be `\"binary_crossentropy\"`.\n", "2. Two units (one for each class) and a `\"softmax\"` nonlinearity. The loss must be `\"sparse_categorical_crossentropy\"`.\n", "\n", "In this example, we choose the first option, which is what you would usually\n", "do for binary classification. The second option is usually reserved for when\n", "you have >2 classes." ] }, { "cell_type": "code", "execution_count": 5, "id": "2443adc5", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:29.302493Z", "iopub.status.busy": "2024-12-12T21:41:29.301998Z", "iopub.status.idle": "2024-12-12T21:41:29.306103Z", "shell.execute_reply": "2024-12-12T21:41:29.305492Z" } }, "outputs": [], "source": [ "import keras\n", "\n", "\n", "def get_clf(meta, hidden_layer_sizes, dropout):\n", " n_features_in_ = meta[\"n_features_in_\"]\n", " n_classes_ = meta[\"n_classes_\"]\n", " model = keras.models.Sequential()\n", " model.add(keras.layers.Input(shape=(n_features_in_,)))\n", " for hidden_layer_size in hidden_layer_sizes:\n", " model.add(keras.layers.Dense(hidden_layer_size, activation=\"relu\"))\n", " model.add(keras.layers.Dropout(dropout))\n", " model.add(keras.layers.Dense(1, activation=\"sigmoid\"))\n", " return model" ] }, { "cell_type": "markdown", "id": "996e8b40", "metadata": {}, "source": [ "### 2.3 Defining and training the neural net classifier\n", "\n", "We use `KerasClassifier` because we're dealing with a classifcation task. The first argument should be a callable returning a `Keras.Model`, in this case, `get_clf`. As additional arguments, we pass the number of loss function (required) and the optimizer, but the later is optional. We must also pass all of the arguments to `get_clf` as keyword arguments to `KerasClassifier` if they don't have a default value in `get_clf`. Note that if you do not pass an argument to `KerasClassifier`, it will not be avilable for hyperparameter tuning. Finally, we also pass `random_state=0` for reproducible results." ] }, { "cell_type": "code", "execution_count": 6, "id": "aafcf0cd", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:29.308323Z", "iopub.status.busy": "2024-12-12T21:41:29.307943Z", "iopub.status.idle": "2024-12-12T21:41:29.310782Z", "shell.execute_reply": "2024-12-12T21:41:29.310324Z" } }, "outputs": [], "source": [ "from scikeras.wrappers import KerasClassifier\n", "\n", "\n", "clf = KerasClassifier(\n", " model=get_clf,\n", " loss=\"binary_crossentropy\",\n", " hidden_layer_sizes=(100,),\n", " dropout=0.5,\n", ")" ] }, { "cell_type": "markdown", "id": "b3a5f38f", "metadata": {}, "source": [ "As in `sklearn`, we call `fit` passing the input data `X` and the targets `y`." ] }, { "cell_type": "code", "execution_count": 7, "id": "500ca846", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:29.312932Z", "iopub.status.busy": "2024-12-12T21:41:29.312591Z", "iopub.status.idle": "2024-12-12T21:41:29.949883Z", "shell.execute_reply": "2024-12-12T21:41:29.949213Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m15s\u001b[0m 495ms/step - loss: 1.0027" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 2/32\u001b[0m \u001b[32m━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 4ms/step - loss: 0.9915 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 993us/step - loss: 0.8469\n" ] } ], "source": [ "clf.fit(X, y);" ] }, { "cell_type": "markdown", "id": "d9ddcf42", "metadata": {}, "source": [ "Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model.\n", "\n", "### 2.4 Making predictions, classification" ] }, { "cell_type": "code", "execution_count": 8, "id": "ed59ce35", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:29.952948Z", "iopub.status.busy": "2024-12-12T21:41:29.952472Z", "iopub.status.idle": "2024-12-12T21:41:30.023338Z", "shell.execute_reply": "2024-12-12T21:41:30.022704Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 32ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 33ms/step\n" ] }, { "data": { "text/plain": [ "array([1, 0, 0, 0, 0])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = clf.predict(X[:5])\n", "y_pred" ] }, { "cell_type": "code", "execution_count": 9, "id": "37c8158f", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.025556Z", "iopub.status.busy": "2024-12-12T21:41:30.025155Z", "iopub.status.idle": "2024-12-12T21:41:30.077300Z", "shell.execute_reply": "2024-12-12T21:41:30.076601Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 14ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 15ms/step\n" ] }, { "data": { "text/plain": [ "array([[0.31051397, 0.689486 ],\n", " [0.720068 , 0.27993205],\n", " [0.77182484, 0.2281752 ],\n", " [0.80232203, 0.19767794],\n", " [0.89132524, 0.10867473]], dtype=float32)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_proba = clf.predict_proba(X[:5])\n", "y_proba" ] }, { "cell_type": "markdown", "id": "420e0e44", "metadata": {}, "source": [ "## 3 Training a regressor\n", "\n", "### 3.1 A toy regression task" ] }, { "cell_type": "code", "execution_count": 10, "id": "7b5ea4c6", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.079738Z", "iopub.status.busy": "2024-12-12T21:41:30.079146Z", "iopub.status.idle": "2024-12-12T21:41:30.085753Z", "shell.execute_reply": "2024-12-12T21:41:30.085186Z" } }, "outputs": [ { "data": { "text/plain": [ "((1000, 20),\n", " (1000,),\n", " np.float64(-649.0148244404172),\n", " np.float64(615.4505181286091))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.datasets import make_regression\n", "\n", "\n", "X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)\n", "\n", "X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()" ] }, { "cell_type": "markdown", "id": "1bbc60de", "metadata": {}, "source": [ "### 3.2 Definition of the Keras regression Model\n", "\n", "Again, define a vanilla neural network. The main difference is that the output layer always has a single unit and does not apply any nonlinearity." ] }, { "cell_type": "code", "execution_count": 11, "id": "d5260ef4", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.087712Z", "iopub.status.busy": "2024-12-12T21:41:30.087281Z", "iopub.status.idle": "2024-12-12T21:41:30.091826Z", "shell.execute_reply": "2024-12-12T21:41:30.091140Z" } }, "outputs": [], "source": [ "def get_reg(meta, hidden_layer_sizes, dropout):\n", " n_features_in_ = meta[\"n_features_in_\"]\n", " model = keras.models.Sequential()\n", " model.add(keras.layers.Input(shape=(n_features_in_,)))\n", " for hidden_layer_size in hidden_layer_sizes:\n", " model.add(keras.layers.Dense(hidden_layer_size, activation=\"relu\"))\n", " model.add(keras.layers.Dropout(dropout))\n", " model.add(keras.layers.Dense(1))\n", " return model" ] }, { "cell_type": "markdown", "id": "ee8eb2d0", "metadata": {}, "source": [ "### 3.3 Defining and training the neural net regressor\n", "\n", "Training a regressor has nearly the same data flow as training a classifier. The differences include using `KerasRegressor` instead of `KerasClassifier` and adding `keras.metrics.R2Score` as a metric. Most of the Scikit-learn regressors use the coefficient of determination or R^2 as a metric function, which measures correlation between the true labels and predicted labels." ] }, { "cell_type": "code", "execution_count": 12, "id": "2292f9fc", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.093879Z", "iopub.status.busy": "2024-12-12T21:41:30.093364Z", "iopub.status.idle": "2024-12-12T21:41:30.097633Z", "shell.execute_reply": "2024-12-12T21:41:30.096922Z" } }, "outputs": [], "source": [ "import keras\n", "import keras.models\n", "from scikeras.wrappers import KerasRegressor\n", "\n", "\n", "reg = KerasRegressor(\n", " model=get_reg,\n", " loss=\"mse\",\n", " metrics=[keras.metrics.R2Score],\n", " hidden_layer_sizes=(100,),\n", " dropout=0.5,\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "c4b56f6c", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.099984Z", "iopub.status.busy": "2024-12-12T21:41:30.099500Z", "iopub.status.idle": "2024-12-12T21:41:30.909607Z", "shell.execute_reply": "2024-12-12T21:41:30.908890Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m22s\u001b[0m 710ms/step - loss: 50730.8281 - r2_score: -0.0202" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 2/32\u001b[0m \u001b[32m━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 51634.2109 - r2_score: -0.0163 " ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 3/32\u001b[0m \u001b[32m━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 3ms/step - loss: 53001.1055 - r2_score: -0.0111" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 1ms/step - loss: 46382.5273 - r2_score: -0.0051\n" ] } ], "source": [ "reg.fit(X_regr, y_regr);" ] }, { "cell_type": "markdown", "id": "d5fc87f5", "metadata": {}, "source": [ "### 3.4 Making predictions, regression\n", "\n", "You may call `predict` or `predict_proba` on the fitted model. For regressions, both methods return the same value." ] }, { "cell_type": "code", "execution_count": 14, "id": "aff145df", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.912369Z", "iopub.status.busy": "2024-12-12T21:41:30.911707Z", "iopub.status.idle": "2024-12-12T21:41:30.994921Z", "shell.execute_reply": "2024-12-12T21:41:30.994231Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 38ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 39ms/step\n" ] }, { "data": { "text/plain": [ "array([ 1.9931117 , -0.07530521, 1.3747294 , 0.2566314 , 0.79165256],\n", " dtype=float32)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred = reg.predict(X_regr[:5])\n", "y_pred" ] }, { "cell_type": "markdown", "id": "7bf1990f", "metadata": {}, "source": [ "## 4. Saving and loading a model\n", "\n", "Save and load either the whole model by using pickle, or use Keras' specialized save methods on the `KerasClassifier.model_` or `KerasRegressor.model_` attribute that is created after fitting. You will want to use Keras' model saving utilities if any of the following apply:\n", "\n", "1. You wish to save only the weights or only the training configuration of your model.\n", "2. You wish to share your model with collaborators. Pickle is a relatively unsafe protocol and it is not recommended to share or load pickle objects publically.\n", "3. You care about performance, especially if doing in-memory serialization.\n", "\n", "For more information, see Keras' [saving documentation](https://www.tensorflow.org/guide/keras/save_and_serialize).\n", "\n", "### 4.1 Saving the whole model" ] }, { "cell_type": "code", "execution_count": 15, "id": "dd63fe7e", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:30.997505Z", "iopub.status.busy": "2024-12-12T21:41:30.997253Z", "iopub.status.idle": "2024-12-12T21:41:31.153645Z", "shell.execute_reply": "2024-12-12T21:41:31.152881Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 35ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 36ms/step\n" ] }, { "data": { "text/plain": [ "array([ 1.9931117 , -0.07530521, 1.3747294 , 0.2566314 , 0.79165256],\n", " dtype=float32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pickle\n", "\n", "\n", "bytes_model = pickle.dumps(reg)\n", "new_reg = pickle.loads(bytes_model)\n", "new_reg.predict(X_regr[:5]) # model is still trained" ] }, { "cell_type": "markdown", "id": "3e93fc2c", "metadata": {}, "source": [ "### 4.2 Saving using Keras' saving methods\n", "\n", "This efficiently and safely saves the model to disk, including trained weights.\n", "You should use this method if you plan on sharing your saved models." ] }, { "cell_type": "code", "execution_count": 16, "id": "2c012502", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:31.156464Z", "iopub.status.busy": "2024-12-12T21:41:31.155914Z", "iopub.status.idle": "2024-12-12T21:41:31.303316Z", "shell.execute_reply": "2024-12-12T21:41:31.302483Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1s\u001b[0m 43ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 746us/step\n" ] } ], "source": [ "# Save to disk\n", "pred_old = reg.predict(X_regr)\n", "reg.model_.save(\"/tmp/my_model.keras\") # saves just the Keras model" ] }, { "cell_type": "code", "execution_count": 17, "id": "10761cb5", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:31.305518Z", "iopub.status.busy": "2024-12-12T21:41:31.305323Z", "iopub.status.idle": "2024-12-12T21:41:31.499790Z", "shell.execute_reply": "2024-12-12T21:41:31.498981Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1s\u001b[0m 42ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 1ms/step \n" ] } ], "source": [ "# Load the model back into memory\n", "new_reg_model = keras.saving.load_model(\"/tmp/my_model.keras\")\n", "# Now we need to instantiate a new SciKeras object\n", "# since we only saved the Keras model\n", "reg_new = KerasRegressor(new_reg_model)\n", "# use initialize to avoid re-fitting\n", "reg_new.initialize(X_regr, y_regr)\n", "pred_new = reg_new.predict(X_regr)\n", "np.testing.assert_allclose(pred_old, pred_new)" ] }, { "cell_type": "markdown", "id": "6beed768", "metadata": {}, "source": [ "## 5. Usage with an sklearn Pipeline\n", "\n", "It is possible to put the `KerasClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier.\n" ] }, { "cell_type": "code", "execution_count": 18, "id": "57de5cc7", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:31.502537Z", "iopub.status.busy": "2024-12-12T21:41:31.502087Z", "iopub.status.idle": "2024-12-12T21:41:32.202800Z", "shell.execute_reply": "2024-12-12T21:41:32.202151Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m15s\u001b[0m 484ms/step - loss: 0.7475" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 829us/step - loss: 0.7135 \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "\u001b[1m 1/32\u001b[0m \u001b[37m━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 29ms/step" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", "\u001b[1m32/32\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step \n" ] } ], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "\n", "pipe = Pipeline([\n", " ('scale', StandardScaler()),\n", " ('clf', clf),\n", "])\n", "\n", "\n", "y_proba = pipe.fit(X, y).predict(X)" ] }, { "cell_type": "markdown", "id": "7dba90ba", "metadata": {}, "source": [ "To save the whole pipeline, including the Keras model, use `pickle`.\n", "\n", "## 6. Callbacks\n", "\n", "Adding a new callback to the model is straightforward. Below we define a threashold callback\n", "to avoid training past a certain accuracy. This a rudimentary for of\n", "[early stopping](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping)." ] }, { "cell_type": "code", "execution_count": 19, "id": "43b5b9e3", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:32.205274Z", "iopub.status.busy": "2024-12-12T21:41:32.204723Z", "iopub.status.idle": "2024-12-12T21:41:32.209349Z", "shell.execute_reply": "2024-12-12T21:41:32.208793Z" } }, "outputs": [], "source": [ "class MaxValLoss(keras.callbacks.Callback):\n", "\n", " def __init__(self, monitor: str, threashold: float):\n", " self.monitor = monitor\n", " self.threashold = threashold\n", "\n", " def on_epoch_end(self, epoch, logs=None):\n", " if logs[self.monitor] > self.threashold:\n", " print(\"Threashold reached; stopping training\") \n", " self.model.stop_training = True" ] }, { "cell_type": "markdown", "id": "576738f8", "metadata": {}, "source": [ "Define a test dataset:" ] }, { "cell_type": "code", "execution_count": 20, "id": "dbd5ac49", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:32.211733Z", "iopub.status.busy": "2024-12-12T21:41:32.211346Z", "iopub.status.idle": "2024-12-12T21:41:32.215408Z", "shell.execute_reply": "2024-12-12T21:41:32.214723Z" } }, "outputs": [], "source": [ "from sklearn.datasets import make_moons\n", "\n", "\n", "X, y = make_moons(n_samples=100, noise=0.2, random_state=0)" ] }, { "cell_type": "markdown", "id": "7b0aa5e3", "metadata": {}, "source": [ "And try fitting it with and without the callback:" ] }, { "cell_type": "code", "execution_count": 21, "id": "51f508df", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:32.218083Z", "iopub.status.busy": "2024-12-12T21:41:32.217398Z", "iopub.status.idle": "2024-12-12T21:41:33.873284Z", "shell.execute_reply": "2024-12-12T21:41:33.872632Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trained 20 epochs\n", "Final accuracy: 0.8999999761581421\n" ] } ], "source": [ "kwargs = dict(\n", " model=get_clf,\n", " loss=\"binary_crossentropy\",\n", " dropout=0.5,\n", " hidden_layer_sizes=(100,),\n", " metrics=[\"binary_accuracy\"],\n", " fit__validation_split=0.2,\n", " epochs=20,\n", " verbose=False,\n", " random_state=0\n", ")\n", "\n", "# First test without the callback\n", "clf = KerasClassifier(**kwargs)\n", "clf.fit(X, y)\n", "print(f\"Trained {len(clf.history_['loss'])} epochs\")\n", "print(f\"Final accuracy: {clf.history_['val_binary_accuracy'][-1]}\") # get last value of last fit/partial_fit call" ] }, { "cell_type": "markdown", "id": "e5c8f626", "metadata": {}, "source": [ "And with:" ] }, { "cell_type": "code", "execution_count": 22, "id": "72ab8109", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:33.876493Z", "iopub.status.busy": "2024-12-12T21:41:33.875630Z", "iopub.status.idle": "2024-12-12T21:41:35.043552Z", "shell.execute_reply": "2024-12-12T21:41:35.042771Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Threashold reached; stopping training\n", "Trained 4 epochs\n", "Final accuracy: 0.8999999761581421\n" ] } ], "source": [ "# Test with the callback\n", "\n", "cb = MaxValLoss(monitor=\"val_binary_accuracy\", threashold=0.75)\n", "\n", "clf = KerasClassifier(\n", " **kwargs,\n", " callbacks=[cb]\n", ")\n", "clf.fit(X, y)\n", "print(f\"Trained {len(clf.history_['loss'])} epochs\")\n", "print(f\"Final accuracy: {clf.history_['val_binary_accuracy'][-1]}\") # get last value of last fit/partial_fit call" ] }, { "cell_type": "markdown", "id": "9322fec1", "metadata": {}, "source": [ "For information on how to write custom callbacks, have a look at the\n", "[Advanced Usage](https://nbviewer.jupyter.org/github/adriangb/scikeras/blob/master/notebooks/Advanced_Usage.ipynb) notebook.\n", "\n", "## 7. Usage with sklearn GridSearchCV\n", "\n", "### 7.1 Special prefixes\n", "\n", "SciKeras allows to direct access to all parameters passed to the wrapper constructors, including deeply nested routed parameters. This allows tunning of\n", "paramters like `hidden_layer_sizes` as well as `optimizer__learning_rate`.\n", "\n", "This is exactly the same logic that allows to access estimator parameters in `sklearn Pipeline`s and `FeatureUnion`s.\n", "\n", "This feature is useful in several ways. For one, it allows to set those parameters in the model definition. Furthermore, it allows you to set parameters in an `sklearn GridSearchCV` as shown below.\n", "\n", "To differentiate paramters like `callbacks` which are accepted by both `keras.Model.fit` and `keras.Model.predict` you can add a `fit__` or `predict__` routing suffix respectively. Similar, the `model__` prefix may be used to specify that a paramter is destined only for `get_clf`/`get_reg` (or whatever callable you pass as your `model` argument).\n", "\n", "For more information on parameter routing with special prefixes, see the [Advanced Usage Docs](https://www.adriangb.com/scikeras/stable/advanced.html#routed-parameters)\n", "\n", "### 7.2 Performing a grid search\n", "\n", "Below we show how to perform a grid search over the learning rate (`optimizer__learning_rate`), the model's number of hidden layers (`model__hidden_layer_sizes`), the model's dropout rate (`model__dropout`)." ] }, { "cell_type": "code", "execution_count": 23, "id": "4620890f", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:35.046218Z", "iopub.status.busy": "2024-12-12T21:41:35.045962Z", "iopub.status.idle": "2024-12-12T21:41:35.098002Z", "shell.execute_reply": "2024-12-12T21:41:35.097220Z" } }, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "\n", "clf = KerasClassifier(\n", " model=get_clf,\n", " loss=\"binary_crossentropy\",\n", " optimizer=\"adam\",\n", " optimizer__learning_rate=0.1,\n", " model__hidden_layer_sizes=(100,),\n", " model__dropout=0.5,\n", " verbose=False,\n", ")" ] }, { "cell_type": "markdown", "id": "a68a90da", "metadata": {}, "source": [ "*Note*: We set the verbosity level to zero (`verbose=False`) to prevent too much print output from being shown." ] }, { "cell_type": "code", "execution_count": 24, "id": "93feb6cc", "metadata": { "execution": { "iopub.execute_input": "2024-12-12T21:41:35.100726Z", "iopub.status.busy": "2024-12-12T21:41:35.100452Z", "iopub.status.idle": "2024-12-12T21:42:00.901135Z", "shell.execute_reply": "2024-12-12T21:42:00.900375Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 5 folds for each of 8 candidates, totalling 40 fits\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1734039696.998217 2658 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1734039697.013633 2658 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1734039697.205539 2659 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1734039697.215263 2659 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1734039697.353293 2664 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1734039697.363972 2664 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", "E0000 00:00:1734039697.607431 2657 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "E0000 00:00:1734039697.627304 2657 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7ff7559860c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:5 out of the last 5 calls to .one_step_on_data_distributed at 0x7ff75587db20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7f85058720c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:5 out of the last 5 calls to .one_step_on_data_distributed at 0x7f8505765b20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7f95539a5e40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7fd67d8f5ee0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:5 out of the last 5 calls to .one_step_on_data_distributed at 0x7f95538a1a80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:5 out of the last 5 calls to .one_step_on_data_distributed at 0x7fd67d7edb20> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7ff75577e5c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:6 out of the last 6 calls to .one_step_on_data_distributed at 0x7ff75560bc40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7f85056625c0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:6 out of the last 6 calls to .one_step_on_data_distributed at 0x7f85054ebc40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7fd67d6ea520> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:5 out of the last 13 calls to .one_step_on_iterator at 0x7f95537a2480> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", "WARNING:tensorflow:6 out of the last 6 calls to .one_step_on_data_distributed at 0x7fd67d777ba0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING:tensorflow:6 out of the last 6 calls to .one_step_on_data_distributed at 0x7f955363fa60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0.85 {'model__dropout': 0, 'model__hidden_layer_sizes': (50, 50), 'optimizer__learning_rate': 0.1}\n" ] } ], "source": [ "params = {\n", " 'optimizer__learning_rate': [0.05, 0.1],\n", " 'model__hidden_layer_sizes': [(100, ), (50, 50, )],\n", " 'model__dropout': [0, 0.5],\n", "}\n", "\n", "gs = GridSearchCV(clf, params, scoring='accuracy', n_jobs=-1, verbose=True)\n", "\n", "gs.fit(X, y)\n", "\n", "print(gs.best_score_, gs.best_params_)" ] }, { "cell_type": "markdown", "id": "a331514d", "metadata": {}, "source": [ "Of course, we could further nest the `KerasClassifier` within an `sklearn.pipeline.Pipeline`,\n", "in which case we just prefix the parameter by the name of the net (e.g. `clf__model__hidden_layer_sizes`)." ] } ], "metadata": { "jupytext": { "formats": "ipynb,md" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 5 }