As foundation model training infrastructure scales to tens of thousands of accelerators, efficient utilization of those high-value resources becomes paramount. In particular, as the cluster gets larger, hardware failures become more frequent (on the order of every few hours) and recovery from previously saved checkpoints becomes slower (up to 30 minutes), significantly slowing down training progress. A checkpoint represents the saved state of a model’s training progress at a given time and consists of a set of intermediate model weights and other parameters.

We recently introduced multi-tier checkpointing in AI Hypercomputer, our integrated supercomputing system that incorporates lessons from more than a decade of Google’s expertise in AI. This solution increases the ML Goodput of large training jobs (e.g., by 6.59% in a 35K-chip workload on TPU v5p) by utilizing multiple tiers of storage, including in-cluster memory (RAM) with in-cluster replication and Google Cloud Storage, thereby minimizing lost progress during a training job and improving mean time to recovery (MTTR). The solution is compatible with JAX, using MaxText as a reference architecture, as well as with NeMo for PyTorch on GPUs.

Multi-tier checkpointing architecture: checkpoints are stored in (1) each node’s RAM, (2) in a different slice or superblock, and (3) in Cloud Storage.

What this means is that you can take checkpoints at the optimal frequency, even for the biggest models running across a very large cluster (checkpoint saves scale sub-linearly, completing in under 5 minutes), and restore in under a minute across a cluster with thousands of nodes.

Increases in Goodput can translate directly into decreases in infrastructure costs. For example, consider using accelerator chips to train a model that takes one month to complete. Even for a somewhat smaller training workload, the cost savings with optimal checkpointing can be significant: for a week-long training job spanning 1,000 VMs at $88/hour each (a3-highgpu-8g), a 6.5% increase in Goodput could result in almost $1M in infrastructure savings.
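
For reference, the arithmetic behind that estimate, using the figures quoted above (the only assumption is that the on-demand rate applies for the full job duration):

```python
# Back-of-the-envelope estimate of infrastructure savings from a Goodput increase.
# Figures match the example above: 1,000 VMs at $88/hour each for one week.
num_vms = 1_000
hourly_rate_usd = 88            # approximate per-VM rate for a3-highgpu-8g
job_hours = 7 * 24              # a week-long training job
goodput_gain = 0.065            # 6.5% more of the spend becomes useful training

total_cost = num_vms * hourly_rate_usd * job_hours
savings = total_cost * goodput_gain
print(f"Total spend: ${total_cost:,.0f}")      # ~$14.8M
print(f"Estimated savings: ${savings:,.0f}")   # ~$960K, i.e. almost $1M
```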

More failures require more checkpointing

Probabilistically, the mean time between failures (MTBF) of a training job decreases, meaning failures happen more frequently, as the size of the cluster increases. Therefore, it is important that foundation model producers take checkpoints more frequently so they don’t lose too much progress on their training job. In the past, Google Kubernetes Engine (GKE) customers could only write a checkpoint every 30 minutes (saving it to Cloud Storage) and had to wait up to 30 minutes to read the last saved checkpoint and distribute it to all the nodes in the cluster.
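
To make the scaling concrete, here is a minimal sketch. It assumes independent node failures (so job-level MTBF shrinks roughly linearly with node count) and uses Young’s classic approximation for the checkpoint interval; the node MTBF and cluster size below are illustrative assumptions, not measurements from the workloads above.

```python
import math

# Illustrative numbers (assumptions, not measured values).
node_mtbf_hours = 10_000      # hypothetical MTBF of a single node
num_nodes = 5_000             # cluster size

# With independent failures, the whole job is interrupted when any node fails,
# so job-level MTBF shrinks roughly linearly with cluster size.
cluster_mtbf_hours = node_mtbf_hours / num_nodes
print(f"Cluster MTBF: {cluster_mtbf_hours:.1f} hours")   # ~2 hours

# Young's approximation: optimal checkpoint interval ~ sqrt(2 * C * MTBF),
# where C is the time it takes to write one checkpoint.
for checkpoint_write_hours in (0.5, 5 / 60):             # 30-minute vs. 5-minute saves
    interval = math.sqrt(2 * checkpoint_write_hours * cluster_mtbf_hours)
    print(f"Write cost {checkpoint_write_hours * 60:.0f} min "
          f"-> checkpoint every ~{interval * 60:.0f} min")
```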

Multi-tier checkpointing allows for much faster checkpoint writes and more frequent saves by writing data asynchronously to memory (RAM) on the node, periodically replicating this data inside the cluster, and backing it up to Cloud Storage. In the event of a failure, a job’s progress can be recovered quickly by using data from a nearby neighbor’s in-memory checkpoint. If the checkpoint data isn’t available in a nearby node’s RAM, checkpoints are downloaded from the Cloud Storage backup bucket. With this solution, checkpoint write latency remains constant as the number of nodes in the cluster grows. Read latency is likewise independent of cluster size, enabling faster checkpoint loading and reducing MTTR.
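
The recovery path amounts to a simple preference order across tiers. The sketch below is conceptual rather than the actual Replicator logic; the directory layout and the gcs_fetch helper are hypothetical stand-ins.

```python
from pathlib import Path

def restore_checkpoint(step: int, local_dir: Path, peer_dirs: list[Path],
                       gcs_fetch) -> bytes | None:
    """Conceptual restore order: this node's RAM-backed volume, then an
    in-cluster peer replica, then the Cloud Storage backup. `gcs_fetch`
    stands in for whatever downloads the backup from the bucket."""
    local = local_dir / f"step_{step}"
    if local.exists():                      # Tier 1: local RAM-backed volume
        return local.read_bytes()
    for peer in peer_dirs:                  # Tier 2: replica on a nearby node
        replica = peer / f"step_{step}"
        if replica.exists():
            return replica.read_bytes()
    return gcs_fetch(step)                  # Tier 3: Cloud Storage backup
```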

Architectural details

Conceptually, the multi-tier checkpointing solution provides a single “magic” local filesystem volume that ML training jobs use to save and restore checkpoints. It’s “magic” because while it provides ramdisk-level read/write speeds, it also provides the data durability associated with Cloud Storage.

When enabled, the local volume (node storage) is the only storage tier visible to ML jobs. Checkpoints written there are automatically replicated in-cluster to one or more peer nodes and are regularly backed up to Cloud Storage.

When the job restarts, the checkpoint data specific to the portion of the training job now running on the node (i.e., its NodeRank) automatically appears on the local volume for the ML job to use. Behind the scenes, the necessary data may be fetched from another node in the cluster or from Cloud Storage. Finding the most recent fully written checkpoint, wherever it resides, also happens transparently to ML jobs.
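
From the job’s perspective, checkpointing therefore looks like ordinary reads and writes against a local directory. Below is a minimal JAX sketch using the standard Orbax CheckpointManager against an assumed mount path (/local/checkpoints); the actual MaxText integration and the Replicator configuration are handled outside the training script and may differ.

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

CKPT_DIR = "/local/checkpoints"   # assumed mount path of the multi-tier local volume

# Save frequently: writes land on the RAM-backed local volume, so they are cheap.
options = ocp.CheckpointManagerOptions(save_interval_steps=100, max_to_keep=2)
mngr = ocp.CheckpointManager(CKPT_DIR, options=options)

state = {"params": jnp.zeros((1024, 1024)), "step": 0}  # stand-in for a real train state

for step in range(1_000):
    # ... training step updates `state` ...
    mngr.save(step, args=ocp.args.StandardSave(state))  # asynchronous by default
mngr.wait_until_finished()

# On restart, the latest complete checkpoint simply appears on the local volume,
# whether it came from this node, a peer replica, or Cloud Storage.
latest = mngr.latest_step()
if latest is not None:
    state = mngr.restore(latest, args=ocp.args.StandardRestore(state))
```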

The component responsible for data movement across tiers is called the Replicator; it runs on every node as part of a CSI driver that provides the local volume mount.

Delving deeper, the Replicator performs the following critical functions:

  • Centralized intelligence: It analyzes Cloud Storage backups and the collective in-cluster data to determine the most recent, complete checkpoint with which to restore a job upon restart. Furthermore, it detects successful checkpoint saves by all nodes, signaling when older data can be safely garbage-collected, and strategically decides which checkpoints to back up to Cloud Storage.

  • Smart peer selection: Because it’s aware of the underlying network topology used by both TPUs and GPUs, the Replicator employs smart criteria to select replication peers for each node (see the sketch after this list). It prioritizes a “near” peer with high bandwidth and low latency; because that peer carries a higher risk of correlated failure (e.g., it sits within the same TPU Slice or GPU Superblock), the Replicator also selects a “far” peer, one with slightly higher networking overhead but better resilience to independent failures (e.g., one residing in a different GPU Superblock). In data parallelism scenarios, preference is given to peers that possess identical data.

  • Automatic data deduplication: When data parallelism is employed, multiple nodes run identical training pipelines, resulting in the saving of identical checkpoints. The Replicator’s peer selection ensures these nodes are paired, eliminating the need for actual data replication. Instead, each node verifies the data integrity of its peers; no additional bandwidth is consumed, replication is instantaneous, and local storage usage is significantly reduced. If peers are misconfigured, standard checkpoint copying is maintained.

  • Huge-model mode with data parallelism assumption: Beyond optimization, this mode caters to the largest models, where local node storage is insufficient to house both a node’s own checkpoint as well as a peer’s data. In such cases, the ML job configures the Replicator to assume data parallelism, drastically reducing local storage requirements. This extends to scenarios where dedicated nodes handle Cloud Storage backups rather than the nodes storing the most recent checkpoints themselves.

  • Optimized Cloud Storage utilization: Leveraging data deduplication, all unique data is stored in Cloud Storage only once, optimizing storage space, bandwidth consumption, and associated costs.

  • Automated garbage collection: The Replicator continuously monitors checkpoint saves across all nodes. Once the latest checkpoint is confirmed to have been successfully saved everywhere, it automatically initiates the deletion of older checkpoints, while ensuring that checkpoints still being backed up to Cloud Storage are retained until the process is complete.
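
The near/far peer choice described under “Smart peer selection” can be illustrated with a short sketch. This is a conceptual approximation, not the Replicator’s implementation; the Node fields and the notion of a failure domain (standing in for a TPU slice or GPU superblock) are assumptions for illustration.

```python
from dataclasses import dataclass

@dataclass
class Node:
    rank: int
    failure_domain: str    # e.g., TPU slice or GPU superblock ID (illustrative)
    data_replica_id: int   # nodes sharing this ID hold identical checkpoint shards

def select_peers(me: Node, others: list[Node]) -> tuple[Node, Node]:
    """Illustrative near/far peer choice: prefer peers that hold identical data
    (so replication degenerates to an integrity check), pick a low-latency
    'near' peer in the same failure domain, and a 'far' peer in a different
    domain for resilience to correlated failures."""
    def preference(candidate: Node) -> int:
        return 0 if candidate.data_replica_id == me.data_replica_id else 1

    near = min((n for n in others if n.failure_domain == me.failure_domain),
               key=preference)
    far = min((n for n in others if n.failure_domain != me.failure_domain),
              key=preference)
    return near, far
```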

A wide range of checkpointing solutions

At Google Cloud, we offer a comprehensive portfolio of checkpointing solutions to meet diverse AI training needs. Options like direct Cloud Storage and Cloud Storage FUSE are simpler approaches and serve smaller to medium-scale workloads very effectively. Parallel file systems such as Lustre offer high throughput for large clusters, while multi-tier checkpointing is purpose-built for the most demanding, highest-scale (>1K nodes) training jobs that require very frequent saves and rapid recovery.

Multi-tier checkpointing is currently in preview, focused on JAX for Cloud TPUs and PyTorch on GPUs. Get started with it today by following our user guide, and don’t hesitate to reach out to your account team if you have any questions or feedback.