{
"cells": [
{
"cell_type": "markdown",
"id": "55c76a26",
"metadata": {},
"source": [
"# Meta Estimators in SciKeras\n",
"\n",
"In this notebook, we use scikit-learn ensemble meta-estimators (boosting and bagging) with a Keras MLP model as the base estimator.\n",
"\n",
"## Table of contents\n",
"\n",
"* [1. Setup](#1.-Setup)\n",
"* [2. Defining the Keras Model](#2.-Defining-the-Keras-Model)\n",
" * [2.1 Building a boosting ensemble](#2.1-Building-a-boosting-ensemble)\n",
"* [3. Testing with a toy dataset](#3.-Testing-with-a-toy-dataset)\n",
"* [4. Bagging ensemble](#4.-Bagging-ensemble)\n",
"\n",
"## 1. Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "f7a9aeda",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:19.984494Z",
"iopub.status.busy": "2023-06-28T16:19:19.982576Z",
"iopub.status.idle": "2023-06-28T16:19:22.602122Z",
"shell.execute_reply": "2023-06-28T16:19:22.601350Z"
}
},
"outputs": [],
"source": [
"try:\n",
"    import scikeras\n",
"except ImportError:\n",
"    !python -m pip install scikeras[tensorflow]"
]
},
{
"cell_type": "markdown",
"id": "3d7e42a9",
"metadata": {},
"source": [
"Silence TensorFlow logging to keep output succinct."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0cea31fa",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:22.610419Z",
"iopub.status.busy": "2023-06-28T16:19:22.608679Z",
"iopub.status.idle": "2023-06-28T16:19:22.614106Z",
"shell.execute_reply": "2023-06-28T16:19:22.613448Z"
}
},
"outputs": [],
"source": [
"import warnings\n",
"from tensorflow import get_logger\n",
"get_logger().setLevel('ERROR')\n",
"warnings.filterwarnings(\"ignore\", message=\"Setting the random state for TF\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "935b3e45",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:22.617129Z",
"iopub.status.busy": "2023-06-28T16:19:22.616618Z",
"iopub.status.idle": "2023-06-28T16:19:22.996707Z",
"shell.execute_reply": "2023-06-28T16:19:22.994607Z"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"from scikeras.wrappers import KerasClassifier, KerasRegressor\n",
"from tensorflow import keras"
]
},
{
"cell_type": "markdown",
"id": "f3b9da80",
"metadata": {},
"source": [
"## 2. Defining the Keras Model\n",
"\n",
"We borrow our MLPClassifier implementation from the [MLPClassifier notebook](https://colab.research.google.com/github/adriangb/scikeras/blob/master/notebooks/MLPClassifier_and_MLPRegressor.ipynb)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3185af2b",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:23.004098Z",
"iopub.status.busy": "2023-06-28T16:19:23.000847Z",
"iopub.status.idle": "2023-06-28T16:19:23.015330Z",
"shell.execute_reply": "2023-06-28T16:19:23.013955Z"
}
},
"outputs": [],
"source": [
"from typing import Dict, Iterable, Any\n",
"\n",
"\n",
"def get_clf_model(hidden_layer_sizes: Iterable[int], meta: Dict[str, Any], compile_kwargs: Dict[str, Any]):\n",
"    model = keras.Sequential()\n",
"    # shape must be a tuple, hence the trailing comma\n",
"    inp = keras.layers.Input(shape=(meta[\"n_features_in_\"],))\n",
"    model.add(inp)\n",
"    for hidden_layer_size in hidden_layer_sizes:\n",
"        layer = keras.layers.Dense(hidden_layer_size, activation=\"relu\")\n",
"        model.add(layer)\n",
"    # choose the output layer and loss from the target type detected by SciKeras\n",
"    if meta[\"target_type_\"] == \"binary\":\n",
"        n_output_units = 1\n",
"        output_activation = \"sigmoid\"\n",
"        loss = \"binary_crossentropy\"\n",
"    elif meta[\"target_type_\"] == \"multiclass\":\n",
"        n_output_units = meta[\"n_classes_\"]\n",
"        output_activation = \"softmax\"\n",
"        loss = \"sparse_categorical_crossentropy\"\n",
"    else:\n",
"        raise NotImplementedError(f\"Unsupported task type: {meta['target_type_']}\")\n",
"    out = keras.layers.Dense(n_output_units, activation=output_activation)\n",
"    model.add(out)\n",
"    model.compile(loss=loss, optimizer=compile_kwargs[\"optimizer\"])\n",
"    return model"
]
},
{
"cell_type": "markdown",
"id": "4f52c330",
"metadata": {},
"source": [
"Next, we wrap this Keras model with SciKeras."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e729f67e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:23.020081Z",
"iopub.status.busy": "2023-06-28T16:19:23.019156Z",
"iopub.status.idle": "2023-06-28T16:19:23.023996Z",
"shell.execute_reply": "2023-06-28T16:19:23.023312Z"
}
},
"outputs": [],
"source": [
"clf = KerasClassifier(\n",
"    model=get_clf_model,\n",
"    hidden_layer_sizes=(100,),\n",
"    optimizer=\"adam\",\n",
"    optimizer__learning_rate=0.001,\n",
"    verbose=0,\n",
"    random_state=0,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "9a138419",
"metadata": {},
"source": [
"### 2.1 Building a boosting ensemble\n",
"\n",
"Because SciKeras estimators are fully compliant with the Scikit-Learn API, we can make use of Scikit-Learn's built in utilities. In particular example, we will use `AdaBoostClassifier` from `sklearn.ensemble.AdaBoostClassifier`, but the process is the same for most Scikit-Learn meta-estimators.\n"
]
},
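{
"cell_type": "markdown",
"id": "5e8d2b71",
"metadata": {},
"source": [
"To make concrete what a boosting meta-estimator does with each copy of `clf`, here is a sketch of the textbook multiclass AdaBoost update (often called SAMME). It is an illustration under assumptions, not a description of Scikit-Learn's exact code: the variant Scikit-Learn runs depends on its version and on the `algorithm` and `learning_rate` settings. The symbols $w_i$, $h_m$, $\\alpha_m$, $M$ and $K$ are notation introduced here.\n",
"\n",
"Start from uniform weights $w_i = 1/n$. For each round $m = 1, \\dots, M$, fit a fresh copy $h_m$ of the base estimator using the current weights, then compute:\n",
"\n",
"$$\\mathrm{err}_m = \\frac{\\sum_{i=1}^{n} w_i \\, \\mathbf{1}[h_m(x_i) \\neq y_i]}{\\sum_{i=1}^{n} w_i}$$\n",
"\n",
"$$\\alpha_m = \\log\\frac{1 - \\mathrm{err}_m}{\\mathrm{err}_m} + \\log(K - 1)$$\n",
"\n",
"$$w_i \\leftarrow w_i \\exp\\left(\\alpha_m \\, \\mathbf{1}[h_m(x_i) \\neq y_i]\\right)$$\n",
"\n",
"The ensemble prediction is the weighted vote:\n",
"\n",
"$$\\hat{y}(x) = \\arg\\max_k \\sum_{m=1}^{M} \\alpha_m \\, \\mathbf{1}[h_m(x) = k]$$\n",
"\n",
"For a binary problem ($K = 2$) the $\\log(K - 1)$ term vanishes. Because the weights $w_i$ reach each base estimator through the `sample_weight` argument of `fit`, the wrapped Keras model must accept sample weights, which `KerasClassifier.fit` does."
]
},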
{
"cell_type": "code",
"execution_count": 6,
"id": "c487aa1e",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:23.029096Z",
"iopub.status.busy": "2023-06-28T16:19:23.027159Z",
"iopub.status.idle": "2023-06-28T16:19:23.591280Z",
"shell.execute_reply": "2023-06-28T16:19:23.590549Z"
}
},
"outputs": [],
"source": [
"from sklearn.ensemble import AdaBoostClassifier\n",
"\n",
"\n",
"# `base_estimator` was renamed to `estimator` in scikit-learn 1.2\n",
"adaboost = AdaBoostClassifier(estimator=clf, random_state=0)"
]
},
{
"cell_type": "markdown",
"id": "35412b2e",
"metadata": {},
"source": [
"## 3. Testing with a toy dataset\n",
"\n",
"Before continouing, we will run a small test to make sure we get somewhat reasonable results.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "35f1faa5",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:19:23.597280Z",
"iopub.status.busy": "2023-06-28T16:19:23.595965Z",
"iopub.status.idle": "2023-06-28T16:20:17.369449Z",
"shell.execute_reply": "2023-06-28T16:20:17.368139Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Single score: 0.53\n",
"AdaBoost score: 0.87\n"
]
}
],
"source": [
"from sklearn.datasets import make_moons\n",
"\n",
"\n",
"X, y = make_moons()\n",
"\n",
"single_score = clf.fit(X, y).score(X, y)\n",
"\n",
"adaboost_score = adaboost.fit(X, y).score(X, y)\n",
"\n",
"print(f\"Single score: {single_score:.2f}\")\n",
"print(f\"AdaBoost score: {adaboost_score:.2f}\")"
]
},
{
"cell_type": "markdown",
"id": "94b59342",
"metadata": {},
"source": [
"We see that the AdaBoost classifier scores higher on the training data than a single MLP classifier (0.87 versus 0.53 in the run shown above). We can explore the individual classifiers and see that each one is composed of a Keras Model with its own individual weights.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6eb305f4",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:20:17.374634Z",
"iopub.status.busy": "2023-06-28T16:20:17.373329Z",
"iopub.status.idle": "2023-06-28T16:20:17.391100Z",
"shell.execute_reply": "2023-06-28T16:20:17.390092Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.00950615 0.04678571 0.0866185 -0.20479605 0.11132256]\n",
"[ 0.2300349 0.20139852 0.21442255 0.03159269 -0.0991485 ]\n"
]
}
],
"source": [
"print(adaboost.estimators_[0].model_.get_weights()[0][0, :5]) # first sub-estimator\n",
"print(adaboost.estimators_[1].model_.get_weights()[0][0, :5]) # second sub-estimator"
]
},
{
"cell_type": "markdown",
"id": "24213a6a",
"metadata": {},
"source": [
"## 4. Bagging ensemble\n",
"\n",
"For comparison, we run the same test with an ensemble built using `sklearn.ensemble.BaggingClassifier`."
]
},
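{
"cell_type": "markdown",
"id": "7c41a9e3",
"metadata": {},
"source": [
"For reference, the bagging procedure can be written down as follows. This is a sketch of the standard formulation, assuming bootstrap sampling and base estimators that expose `predict_proba`; it is not a transcription of Scikit-Learn's implementation, and the symbols $B$, $D_b$, $h_b$ and $p_b$ are notation introduced here.\n",
"\n",
"Each of the $B$ base estimators $h_b$ is fit on its own bootstrap sample, drawn with replacement from the $n$ training points:\n",
"\n",
"$$D_b = \\{(x_{i_j}, y_{i_j})\\}_{j=1}^{n}, \\quad i_j \\sim \\mathrm{Uniform}\\{1, \\dots, n\\}, \\quad b = 1, \\dots, B$$\n",
"\n",
"The ensemble averages the class probabilities $p_b(k \\mid x)$ predicted by the individual models and picks the most probable class:\n",
"\n",
"$$\\hat{p}(k \\mid x) = \\frac{1}{B} \\sum_{b=1}^{B} p_b(k \\mid x), \\qquad \\hat{y}(x) = \\arg\\max_k \\hat{p}(k \\mid x)$$\n",
"\n",
"Unlike boosting, the $B$ fits do not depend on each other, so they can run in parallel, which is what the `n_jobs` argument of `BaggingClassifier` controls."
]
},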
{
"cell_type": "code",
"execution_count": 9,
"id": "679f5ba1",
"metadata": {
"execution": {
"iopub.execute_input": "2023-06-28T16:20:17.397745Z",
"iopub.status.busy": "2023-06-28T16:20:17.396201Z",
"iopub.status.idle": "2023-06-28T16:20:48.165340Z",
"shell.execute_reply": "2023-06-28T16:20:48.164428Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Bagging score: 0.74\n"
]
}
],
"source": [
"from sklearn.ensemble import BaggingClassifier\n",
"\n",
"\n",
"# `base_estimator` was renamed to `estimator` in scikit-learn 1.2\n",
"bagging = BaggingClassifier(estimator=clf, random_state=0, n_jobs=-1)\n",
"\n",
"bagging_score = bagging.fit(X, y).score(X, y)\n",
"\n",
"print(f\"Bagging score: {bagging_score:.2f}\")"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}