{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(sec-dask-ml-distributed-training)=\n", "# Distributed Machine Learning\n", "\n", "When the training data is too large for a single machine, Dask-ML provides distributed machine learning that trains models on big data across a cluster. Dask currently offers two kinds of distributed machine learning APIs:\n", "\n", "* scikit-learn style\n", "* XGBoost and LightGBM (gradient-boosted decision trees)\n", "\n", "## scikit-learn API\n", "\n", "Building on the distributed computing capabilities of Dask Array, Dask DataFrame, and Dask Delayed, Dask-ML implements distributed versions of scikit-learn-style machine learning algorithms. For example, `dask_ml.linear_model` provides linear regression [`LinearRegression`](https://ml.dask.org/modules/generated/dask_ml.linear_model.LinearRegression.html) and logistic regression [`LogisticRegression`](https://ml.dask.org/modules/generated/dask_ml.linear_model.LogisticRegression.html), and `dask_ml.cluster` provides [`KMeans`](https://ml.dask.org/modules/generated/dask_ml.cluster.KMeans.html). Dask-ML aims to keep the usage of these algorithms consistent with scikit-learn.\n", "\n", "The following example uses the linear models from `dask_ml.linear_model` on a Dask cluster of 2 compute nodes. Each node has 90 GiB of memory. We randomly generate a dataset of about 37 GiB and split it into a training set and a testing set." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'svg'\n", "import time\n", "\n", "import seaborn as sns\n", "import pandas as pd\n", "\n", "from dask.distributed import Client, LocalCluster" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dask_ml.datasets\n", "import sklearn.linear_model\n", "import dask_ml.linear_model\n", "from dask_ml.model_selection import train_test_split" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Client-ad77e682-0ae4-11ef-8730-000012e4fe80 (Connection method: Direct; Scheduler: tcp://10.0.0.3:8786; Workers: 2; Total threads: 128; Total memory: 180.00 GiB)
" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "client = Client(\"10.0.0.3:8786\")\n", "client" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", "
Bytes: 37.25 GiB (array), 381.47 MiB (chunk)
Shape: (10000000, 500) (array), (100000, 500) (chunk)
Dask graph: 100 chunks in 1 graph layer
Data type: float64 numpy.ndarray
" ], "text/plain": [ "dask.array" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = dask_ml.datasets.make_classification(n_samples=10_000_000, \n", " n_features=500, \n", " random_state=42,\n", " chunks=10_000_000 // 100\n", ")\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n", "X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the scikit-learn-style `fit()` method:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", " warnings.warn(\n" ] } ], "source": [ "lr = dask_ml.linear_model.LogisticRegression(solver=\"lbfgs\").fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The trained model can make predictions with `predict()` and compute accuracy with `score()`." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ True, False, True, True, True])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_predicted = lr.predict(X_test)\n", "y_predicted[:5].compute()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.668674" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr.score(X_test, y_test).compute()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training a scikit-learn model on the same amount of data on a single machine would fail due to insufficient memory.\n", "\n", "Dask-ML's distributed training API closely mirrors scikit-learn, and Dask-ML can use multiple cores or even a whole cluster, whereas many scikit-learn estimators run on a single core. That does not mean Dask-ML is the right choice in every scenario: it is not always the best option in terms of performance or cost. The situation resembles the relationship between Dask DataFrame and pandas: if the dataset fits into the memory of a single machine, native pandas, NumPy, and scikit-learn generally offer the best performance and compatibility.\n", "\n", "The following code benchmarks training on datasets of different sizes. When the data volume is small and the code runs on a single multi-core machine, Dask-ML is not faster than scikit-learn. Reasons include:\n", "\n", "* Many machine learning algorithms are iterative. scikit-learn implements iterative algorithms with native Python `for` loops, and Dask-ML follows the same approach. In a Dask task graph, however, `for` loops can bloat the graph considerably, and executing such a graph is not very efficient.\n", "* The distributed implementation must distribute and collect data across processes, which adds substantial data synchronization and communication overhead compared with a single process on a single machine.\n", "\n", "You can run the same benchmark with data sizes that match the memory you have available." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LocalCluster 1872fd25 (Status: running; Using processes: True; Workers: 8; Total threads: 64, 8 per worker; Total memory: 90.00 GiB, 11.25 GiB per worker)
" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "client = Client(LocalCluster())\n", "client" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n", "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", " warnings.warn(\n", "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n", "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", " warnings.warn(\n", "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n", "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/config.py:789: FutureWarning: Dask configuration key 'fuse_ave_width' has been deprecated; please use 'optimization.fuse.ave-width' instead\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2024-05-05T21:42:17.956299\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.8.4, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " 
\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "num_sample = [500_000, 1_000_000, 1_500_000]\n", "num_feature = 1_000\n", "timings = []\n", "\n", "for n in num_sample:\n", " X, y = dask_ml.datasets.make_classification(n_samples=n, \n", " n_features=num_feature, \n", " random_state=42,\n", " chunks=n // 10\n", " )\n", " t1 = time.time()\n", " sklearn.linear_model.LogisticRegression(solver=\"lbfgs\").fit(X, y)\n", " timings.append(('scikit-learn', n, time.time() - t1))\n", " t1 = time.time()\n", " dask_ml.linear_model.LogisticRegression(solver=\"lbfgs\").fit(X, y)\n", " timings.append(('dask-ml', n, time.time() - t1))\n", "\n", "df = pd.DataFrame(timings, columns=['method', '# of samples', 'time'])\n", "sns.barplot(data=df, x='# of samples', y='time', hue='method')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this logistic regression benchmark, Dask-ML offers no significant advantage over scikit-learn on a single multi-core machine. Moreover, many traditional machine learning algorithms do not need very large training sets: their performance does not improve significantly as more training data is added. Learning curves, which plot model performance against the amount of training data, make this relationship visible; for algorithms such as Naive Bayes, the gains from additional training data are quite limited.
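The learning-curve check described above can be sketched on a single machine with scikit-learn's `learning_curve`. This is a minimal illustration under stated assumptions, not the benchmark used in this section: the synthetic data from `make_classification`, the choice of `GaussianNB` as the Naive Bayes variant, the five training-set sizes, and `cv=5` are all arbitrary, and the printed accuracies depend on the data rather than being a general result.

```python
# Minimal single-machine learning-curve sketch (assumptions: synthetic data,
# GaussianNB as the Naive Bayes model, arbitrary training-set sizes).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import learning_curve
from sklearn.naive_bayes import GaussianNB

# Synthetic binary classification data; the size is chosen only for illustration.
X, y = make_classification(n_samples=20_000, n_features=20, random_state=0)

# Fit on 10%, 32.5%, 55%, 77.5% and 100% of each cross-validation training
# split and record the validation accuracy for every size and fold.
train_sizes, train_scores, val_scores = learning_curve(
    GaussianNB(),
    X,
    y,
    train_sizes=np.linspace(0.1, 1.0, 5),
    cv=5,
)

# A flat validation curve means that extra data brings little improvement.
for n, score in zip(train_sizes, val_scores.mean(axis=1)):
    print(f'{n:>6} training samples: mean validation accuracy {score:.3f}')
```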
\n", "\n", "If an algorithm cannot be trained in a distributed manner, or if distributed training is expensive, consider sampling the training data down to a size that fits into the memory of a single machine and training with a single-machine framework such as scikit-learn.\n", "\n", "In summary, when the training data exceeds the memory of a single machine, weigh several factors: whether the algorithm supports distributed training, how much distributed training costs, and whether more data actually improves the model.\n", "\n", "## XGBoost and LightGBM\n", "\n", "XGBoost and LightGBM are two gradient-boosted decision tree frameworks that lend themselves naturally to distributed training and are integrated with Dask. The example below shows distributed training with Dask and XGBoost; LightGBM works similarly.\n", "\n", "XGBoost models can be trained either with the `train()` function or with the scikit-learn-style `fit()` method, and both support distributed training on Dask.\n", "\n", "The code below compares single-machine XGBoost with Dask distributed training. With Dask, replace [`xgboost.DMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.DMatrix) with [`xgboost.dask.DaskDMatrix`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.DaskDMatrix), which converts distributed Dask Arrays or Dask DataFrames into the data format XGBoost requires. Also replace [`xgboost.train()`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.train) with [`xgboost.dask.train()`](https://xgboost.readthedocs.io/en/stable/python/python_api.html#xgboost.dask.train) and pass in the Dask client `client`." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n", "[22:13:43] task [xgboost.dask-0]:tcp://127.0.0.1:44219 got new rank 0\n", "[22:13:43] task [xgboost.dask-1]:tcp://127.0.0.1:41549 got new rank 1\n", "[22:13:43] task [xgboost.dask-2]:tcp://127.0.0.1:42877 got new rank 2\n", "[22:13:43] task [xgboost.dask-3]:tcp://127.0.0.1:34321 got new rank 3\n", "[22:13:43] task [xgboost.dask-4]:tcp://127.0.0.1:36039 got new rank 4\n", "[22:13:43] task [xgboost.dask-5]:tcp://127.0.0.1:35057 got new rank 5\n", "[22:13:43] task [xgboost.dask-6]:tcp://127.0.0.1:36811 got new rank 6\n", "[22:13:43] task [xgboost.dask-7]:tcp://127.0.0.1:42081 got new rank 7\n", "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask/base.py:1462: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n", "[22:16:27] task [xgboost.dask-0]:tcp://127.0.0.1:44219 got new rank 0\n", "[22:16:27] task [xgboost.dask-1]:tcp://127.0.0.1:41549 got new rank 1\n", "[22:16:27] task [xgboost.dask-2]:tcp://127.0.0.1:42877 got new rank 2\n", "[22:16:27] task [xgboost.dask-3]:tcp://127.0.0.1:34321 got new rank 3\n", "[22:16:27] task [xgboost.dask-4]:tcp://127.0.0.1:36039 got new rank 4\n", "[22:16:27] task [xgboost.dask-5]:tcp://127.0.0.1:35057 got new rank 5\n", "[22:16:28] task [xgboost.dask-6]:tcp://127.0.0.1:36811 got new rank 6\n", "[22:16:28] task [xgboost.dask-7]:tcp://127.0.0.1:42081 got new rank 7\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 
2024-05-05T22:16:40.172826\n", " image/svg+xml\n", " Matplotlib v3.8.4, https://matplotlib.org/\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import xgboost as xgb\n", "\n", "num_sample = [100_000, 500_000]\n", "num_feature = 1_000\n", "xgb_timings = []\n", "\n", "for n in num_sample:\n", " X, y = dask_ml.datasets.make_classification(n_samples=n, \n", " n_features=num_feature, \n", " random_state=42,\n", " chunks=n // 10\n", " )\n", " dtrain = xgb.DMatrix(X, y)\n", " t1 = time.time()\n", " xgb.train(\n", " {\"tree_method\": \"hist\", \"objective\": \"binary:hinge\"},\n", " dtrain,\n", " num_boost_round=4,\n", " evals=[(dtrain, \"train\")],\n", " verbose_eval=False,\n", " )\n", " xgb_timings.append(('xgboost', n, time.time() - t1))\n", " dtrain_dask = xgb.dask.DaskDMatrix(client, X, y)\n", " t1 = time.time()\n", " xgb.dask.train(\n", " client,\n", " {\"tree_method\": \"hist\", \"objective\": \"binary:hinge\"},\n", " dtrain_dask,\n", " num_boost_round=4,\n", " evals=[(dtrain_dask, \"train\")],\n", " verbose_eval=False,\n", " )\n", " xgb_timings.append(('dask-ml', n, time.time() - t1))\n", "\n", "df = pd.DataFrame(xgb_timings, columns=['method', '# of samples', 'time'])\n", "sns.barplot(data=df, x='# of samples', y='time', hue='method')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When using XGBoost's scikit-learn-style API, replace [`xgboost.XGBClassifier`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBClassifier) with [`xgboost.dask.DaskXGBClassifier`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.dask.DaskXGBClassifier), or [`xgboost.XGBRegressor`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.XGBRegressor) with [`xgboost.dask.DaskXGBRegressor`](https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.dask.DaskXGBRegressor).\n", "\n", "### Distributed GPU Training\n", "\n", "Dask can also manage multiple GPUs, so XGBoost can train on several GPUs through Dask. Launching a multi-GPU Dask cluster requires installing Dask-CUDA, whose `LocalCUDACluster` starts one worker per GPU by default." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/distributed/deploy/spec.py:324: UserWarning: Port 8787 is already in use.\n", "Perhaps you already have a cluster running?\n", "Hosting the HTTP server on port 44607 instead\n", " self.scheduler = cls(**self.scheduler_spec.get(\"options\", {}))\n" ] }, { "data": { "text/html": [ "
LocalCUDACluster e461dd92 (Status: running; Using processes: True; Workers: 4; Total threads: 4, 1 per worker; Total memory: 90.00 GiB, 22.50 GiB per worker)
" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask_cuda import LocalCUDACluster\n", "import xgboost as xgb\n", "client = Client(LocalCUDACluster())\n", "client" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "clf = xgb.dask.DaskXGBClassifier(verbosity=1)\n", "clf.set_params(tree_method=\"hist\", device=\"cuda\")\n", "clf.client = client" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/fs/fast/u20200002/envs/dispy/lib/python3.11/site-packages/dask_ml/datasets.py:373: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " informative_idx, beta = dask.compute(\n", "[23:01:19] task [xgboost.dask-0]:tcp://127.0.0.1:45305 got new rank 0\n", "[23:01:19] task [xgboost.dask-1]:tcp://127.0.0.1:38835 got new rank 1\n", "[23:01:19] task [xgboost.dask-2]:tcp://127.0.0.1:46315 got new rank 2\n", "[23:01:20] task [xgboost.dask-3]:tcp://127.0.0.1:38331 got new rank 3\n" ] } ], "source": [ "X, y = dask_ml.datasets.make_classification(n_samples=100_000, \n", " n_features=1_000, \n", " random_state=42,\n", " chunks=100_000 // 100\n", ")\n", "clf.fit(X, y, eval_set=[(X, y)], verbose=False)\n", "prediction = clf.predict(X)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 2 }