9.1. Ray Train#
Ray Train builds on Ray’s Actor and Task abstractions to support machine learning and deep learning training, scaling single-machine training jobs out horizontally. In short, in Ray Train each Actor holds an independent copy of the model and can carry out the training task on its own. Using the horizontal scaling capability of Actors, Ray Train spreads training jobs across a Ray cluster.
Ray Train wraps common machine learning libraries such as PyTorch, PyTorch Lightning, HuggingFace Transformers, XGBoost, and LightGBM, and exposes them to users through unified interfaces. Users do not need to write Actor code; only a few modifications to the original single-machine workflow are needed to switch to cluster mode. Taking PyTorch as an example, this section describes how to scale training tasks horizontally based on data parallelism. For details on the principle of data parallelism, see sec-data-parallel.
Key steps#
To adapt stand-alone PyTorch training code to Ray Train, the following changes are needed:
1. Define train_loop, a single-node training function that loads data and updates the model parameters.
2. Define ScalingConfig, which describes how to scale this training job horizontally: how many workers are needed, whether to use GPUs, and so on.
3. Define Trainer, which glues train_loop and ScalingConfig together, then call the Trainer.fit() method to start training.
Fig. 9.1 shows the key parts of adapting Ray Train.
The code mainly includes:
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
def train_loop():
    ...
scaling_config = ScalingConfig(num_workers=..., use_gpu=...)
trainer = TorchTrainer(train_loop_per_worker=train_loop, scaling_config=scaling_config)
result = trainer.fit()
Example: Image Classification#
Below is a complete training example. This example uses the ResNet model provided by PyTorch [He et al., 2016]. Readers can set ScalingConfig
based on the number of GPUs in their environment.
import os
import tempfile
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import ray
import ray.train.torch
from ray.train import Checkpoint
def train_func(model, optimizer, criterion, train_loader):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for data, target in train_loader:
        # No need to manually send images and labels to a specific GPU
        # `prepare_data_loader` helps with this process
        # data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def test_func(model, data_loader):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            # data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return correct / total
data_dir = os.path.join(os.getcwd(), "../data")
def train_loop():
    # Load data and perform data augmentation
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         torchvision.transforms.Normalize((0.5,), (0.5,))]
    )

    train_loader = DataLoader(
        torchvision.datasets.FashionMNIST(root=data_dir, train=True, download=True, transform=transform),
        batch_size=128,
        shuffle=True)
    test_loader = DataLoader(
        torchvision.datasets.FashionMNIST(root=data_dir, train=False, download=True, transform=transform),
        batch_size=128,
        shuffle=True)

    # 1. Distribute data to multiple computing nodes
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    test_loader = ray.train.torch.prepare_data_loader(test_loader)

    # The original resnet is designed for 3-channel images
    # FashionMNIST is 1 channel, modify the first layer of resnet to adapt to this input
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )

    # 2. Distribute the model to multiple computing nodes and GPUs
    model = ray.train.torch.prepare_model(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train for 10 epochs
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        train_func(model, optimizer, criterion, train_loader)
        acc = test_func(model, test_loader)

        # 3. Monitor training metrics and save checkpoints
        metrics = {"acc": acc, "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)
# 4. Configure `ScalingConfig`; Ray Train scales the training task out to the cluster according to this configuration
scaling_config = ray.train.ScalingConfig(num_workers=4, use_gpu=True)

# 5. Start parallel training using TorchTrainer
trainer = ray.train.torch.TorchTrainer(
    train_loop_per_worker=train_loop,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(
        storage_path=os.path.join(data_dir, "torch_ckpt"),
        name="exp_fashionmnist_resnet18",
    )
)
result = trainer.fit()
Tune Status

Current time: 2024-04-10 09:41:32
Running for:  00:01:33.99
Memory:       31.5/90.0 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 1.0/64 CPUs, 4.0/4 GPUs (0.0/1.0 accelerator_type:TITAN)

Trial Status

Trial name               | status     | loc            | iter | total time (s) | acc    | epoch
-------------------------|------------|----------------|------|----------------|--------|------
TorchTrainer_3d3d1_00000 | TERMINATED | 10.0.0.3:49324 | 10   | 80.9687        | 0.8976 | 9
(RayTrainWorker pid=49399) Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=49400) [W Utils.hpp:133] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt)
(TorchTrainer pid=49324) Started distributed worker processes:
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49399) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49400) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49401) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=49324) - (ip=10.0.0.3, pid=49402) world_rank=3, local_rank=3, node_rank=0
(RayTrainWorker pid=49399) Moving model to device: cuda:0
(RayTrainWorker pid=49399) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=49401) [rank2]:[W Utils.hpp:106] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarString)
(RayTrainWorker pid=49400) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000000)
(RayTrainWorker pid=49402) [W Utils.hpp:133] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarInt) [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
(RayTrainWorker pid=49402) [rank3]:[W Utils.hpp:106] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function getCvarString) [repeated 3x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8604, 'epoch': 0}
(RayTrainWorker pid=49399) {'acc': 0.8808, 'epoch': 1}
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000001) [repeated 4x across cluster]
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000002) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8852, 'epoch': 2}
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000003) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8964, 'epoch': 3}
(RayTrainWorker pid=49399) {'acc': 0.8972, 'epoch': 4}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000004) [repeated 4x across cluster]
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000005) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8968, 'epoch': 5}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000006) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8948, 'epoch': 6}
(RayTrainWorker pid=49399) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000007) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.894, 'epoch': 7}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000008) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.894, 'epoch': 8}
(RayTrainWorker pid=49401) Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name/TorchTrainer_3d3d1_00000_0_2024-04-10_09-39-58/checkpoint_000009) [repeated 4x across cluster]
(RayTrainWorker pid=49399) {'acc': 0.8976, 'epoch': 9}
2024-04-10 09:41:32,109 WARNING experiment_state.py:205 -- Experiment state snapshotting has been triggered multiple times in the last 5.0 seconds. A snapshot is forced if `CheckpointConfig(num_to_keep)` is set, and a trial has checkpointed >= `num_to_keep` times since the last snapshot.
You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
You can suppress this error by setting the environment variable TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S to a smaller value than the current threshold (5.0).
2024-04-10 09:41:32,112 INFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/home/u20200002/distributed-python/ch-ray-train-tune/../data/torch_ckpt/experiment_name' in 0.0057s.
2024-04-10 09:41:32,120 INFO tune.py:1048 -- Total run time: 94.05 seconds (93.99 seconds for the tuning loop).
Difference from native PyTorch#
Difference from stand-alone program#
Ray Train helps users distribute the model and data to multiple computing nodes. Users only need to wrap them with model = ray.train.torch.prepare_model(model) and train_loader = ray.train.torch.prepare_data_loader(train_loader). After doing so, there is no need to explicitly call model.to("cuda"), nor to copy data to the GPU with code such as images, labels = images.to("cuda"), labels.to("cuda").
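The change can be sketched as follows; this is a minimal illustration, assuming model and train_loader are an ordinary PyTorch model and DataLoader (the commented-out lines show what the stand-alone version would otherwise need):

import ray.train.torch

# Stand-alone PyTorch needs explicit device management:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# data, target = data.to(device), target.to(device)

# With Ray Train, device placement is handled for you:
model = ray.train.torch.prepare_model(model)                      # moves the model to the assigned GPU and wraps it for data parallelism
train_loader = ray.train.torch.prepare_data_loader(train_loader)  # adds a distributed sampler and moves batches to the right device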
Difference from DistributedDataParallel#
PyTorch’s DistributedDataParallel can also implement data parallelism. Ray Train hides the complex details of DistributedDataParallel and only requires slight changes to the stand-alone code: the distributed environment (World) and process index (Rank) of torch.distributed do not need to be managed by hand. For the concepts of World and Rank, please refer to sec-mpi-hello-world.
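For comparison, the boilerplate that native DistributedDataParallel requires looks roughly like the sketch below. A toy model and dataset are used purely for illustration; the script would be launched with torchrun, which sets the RANK, WORLD_SIZE, and LOCAL_RANK environment variables:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Launched with `torchrun --nproc_per_node=4 train.py`
dist.init_process_group(backend="nccl")          # set up the distributed environment in every process
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Toy model and dataset, just to show the wrapping steps
model = torch.nn.Linear(8, 2).to(local_rank)
model = DDP(model, device_ids=[local_rank])      # wrap the model manually

dataset = TensorDataset(torch.randn(1024, 8), torch.randint(0, 2, (1024,)))
sampler = DistributedSampler(dataset)            # shard the data across ranks manually
loader = DataLoader(dataset, batch_size=128, sampler=sampler)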
Data reading#
If the stand-alone code reads data with PyTorch’s DataLoader, you can use ray.train.torch.prepare_data_loader() to adapt the original DataLoader. Alternatively, you can use Ray Data for data loading and preprocessing, as sketched below.
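A rough sketch of the Ray Data route: the dataset is passed to the trainer, each worker then reads its own shard via ray.train.get_dataset_shard() and iterates over it with iter_torch_batches(). The Parquet path and the column names below are illustrative assumptions, not part of the example above:

import ray
import ray.train
from ray.train.torch import TorchTrainer

# Illustrative dataset; replace the path and column names with your own
train_ds = ray.data.read_parquet("s3://my-bucket/fashion-mnist/train")

def train_loop_with_ray_data():
    shard = ray.train.get_dataset_shard("train")        # each worker gets its own shard
    for epoch in range(10):
        for batch in shard.iter_torch_batches(batch_size=128):
            images, labels = batch["image"], batch["label"]
            ...  # forward/backward pass as usual

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_with_ray_data,
    datasets={"train": train_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
)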
ScalingConfig#
The num_workers parameter of ScalingConfig(num_workers=..., use_gpu=...) controls the parallelism of the task, and the use_gpu parameter controls whether GPU resources are used. num_workers can be understood as the number of Ray Actors started, each of which performs the training task independently. If use_gpu=True, by default each Actor is assigned 1 GPU and, accordingly, the CUDA_VISIBLE_DEVICES environment variable of each Actor contains a single GPU. To give each Actor access to multiple GPUs, set the resources_per_worker parameter: resources_per_worker={"GPU": n}.
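For example, a configuration that gives each of 4 workers 2 GPUs might look like this (the numbers are purely illustrative):

from ray.train import ScalingConfig

# 4 workers, each reserving 2 GPUs and 8 CPUs (illustrative numbers)
scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"GPU": 2, "CPU": 8},
)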
Monitoring#
In distributed training, each worker runs its own copy of the training loop, but in most cases you only need to monitor the process whose rank is 0. By default, ray.train.report(metrics=...) reports the metrics of rank 0.
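The pattern used in the example above is therefore: every worker calls ray.train.report(), and only rank 0 prints (or forwards metrics to an external logging service). The snippet below repeats the relevant lines from the training loop:

metrics = {"acc": acc, "epoch": epoch}
ray.train.report(metrics)                            # called on every worker
if ray.train.get_context().get_world_rank() == 0:
    print(metrics)                                   # avoid duplicated output from the other ranks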
Checkpoint#
The checkpointing process is as follows:
The checkpoint is first written to a local directory. Here you can directly use the model-saving interfaces provided by PyTorch, PyTorch Lightning, or TensorFlow. For example, in the example above:
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
    torch.save(
        model.module.state_dict(),
        os.path.join(temp_checkpoint_dir, "model.pt")
    )
When ray.train.report(metrics=..., checkpoint=...) is called, the newly saved local checkpoint is uploaded to a persistent file system (for example, S3 or HDFS) that is accessible to all compute nodes. The local checkpoint is only a cache: after the checkpoint has been uploaded to the persistent file system, the local copy is deleted. The persistent storage directory is configured on the TorchTrainer:
TorchTrainer(
    train_loop,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(
        storage_path=...,
        name="experiment_name",
    )
)
In data-parallel training, every rank holds a full copy of the model weights, so the checkpoint saved locally by each rank is identical to the checkpoint on the persistent file system. With other parallel strategies, such as pipeline parallelism (sec-pipeline-parallel), each rank holds only part of the model and saves only its own shard of the weights. In that case, the checkpoint files should carry rank-specific prefixes or suffixes to distinguish them, for example:
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
    rank = ray.train.get_context().get_world_rank()
    torch.save(
        ...,
        os.path.join(temp_checkpoint_dir, f"model-rank={rank}.pt"),
    )
    ray.train.report(
        metrics,
        checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
    )
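After training finishes, the checkpoint can be read back from the Result object returned by trainer.fit(). Below is a minimal sketch, assuming the unwrapped weights were saved as model.pt with model.module.state_dict() as shown above (saving the wrapped model’s state_dict instead would leave a module. prefix on the keys):

# Load the trained weights from the checkpoint returned by trainer.fit()
checkpoint = result.checkpoint
with checkpoint.as_directory() as checkpoint_dir:
    state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location="cpu")

model = resnet18(num_classes=10)
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.load_state_dict(state_dict)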