Checkpointing Deep Dive: Current Limitations And Future Plans
Hey guys! Let's dive into the nitty-gritty of checkpointing in our systems, specifically focusing on why we can't easily save everything just yet and what we're cooking up for the future. This is a critical topic for ensuring the resilience and reproducibility of our machine learning models, so let's get into it.
TL;DR: The Current State of Checkpointing
We can't cleanly checkpoint everything (the dataloader, replay buffer, random number generator (RNG), and other crucial components) because of how things are currently structured. The checkpointer in the Titan system lives inside trainer.engine and is solely responsible for managing the step-<N> folders. The other components I mentioned exist in separate actors, and there's currently no safe, centralized way to write all of their state into the same step folder. So we're putting off full multi-component checkpointing until after PTC (I'll explain this later). The good news is, thanks to this pull request, we can already save and resume model weights, optimizer states, and learning rate schedulers using Titan's checkpointer.
Context: How Checkpointing Works Today
Let's start with how things work currently. Our system spins up various components in the main script:
(
    dataloader,          # DatasetActor (actor)
    policy,              # Policy (service)
    trainer,             # RLTrainer (actor)
    replay_buffer,       # ReplayBuffer (actor)
    compute_advantages,  # ComputeAdvantages (actor)
    ref_model,           # ReferenceModel (service)
    reward_actor,        # RewardActor (service)
) = await asyncio.gather(...)
Model checkpointing is handled by the TorchTitan library, specifically through the trainer's engine. The trainer is responsible for creating the engine and loading any existing checkpoint:
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.step)
During each training step (inside the train_step function), the trainer instructs Titan to save the necessary information:
self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
)
Titan's checkpointer then writes the model weights, optimizer state, and learning rate schedulers into a folder structure that looks like this:
<folder>/step-<N>/__0_0.distcp
We also want to save and load other essential states, like the data step, replay buffer data, and RNG states. These additions will allow us to fully restore the training process from any point, improving the robustness of our models. This is where the core challenge lies: the current architecture doesn't naturally support saving all these different components together.
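To make "other essential states" concrete, here is a rough sketch of the kind of payload we'd want to capture alongside each step. collect_extra_state and the field names are illustrative, not an existing API:

import random

import numpy as np
import torch

def collect_extra_state(data_step: int) -> dict:
    # Illustrative only: the exact fields and their owners are still being designed.
    return {
        "data_step": data_step,                # how far the dataloader has advanced
        "python_rng": random.getstate(),
        "numpy_rng": np.random.get_state(),
        "torch_rng": torch.get_rng_state(),
    }

# e.g. torch.save(collect_extra_state(step), f"{folder}/step-{step}/extra_state.pt")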
The Problem: Obstacles to Comprehensive Checkpointing
So, what's the deal? Why can't we just save everything? The problem boils down to a couple of key issues.
Problem 1: Step-Folder Ownership
Currently, we have one Titan-owned directory per step (e.g., step-200), created internally by trainer.engine.checkpointer. Other actors, like the dataloader and replay buffer, don't have access to the trainer's engine or to Titan's internal folder-naming conventions. This setup leaves us with two not-so-great choices.
- Two folders per step:

  checkpoint/
    step-200/            # Titan
      __0_0.distcp
    step-200-other/      # Ours
      dataloader.json
      replay_buffer.bin
      rng.json
This option is clunky. It creates a messy user experience, makes it difficult to atomically purge or retain checkpoints, and is prone to errors. You can easily get the components out of sync.
- Single folder per step (the preferred option): to save everything in the same step-200/ folder, we'd have to jump through some hoops:
  - Call Titan's private _create_checkpoint_id to get the folder name. However, other components (like the dataloader) don't have access to the engine. It is a big NO-NO.
  - Reimplement a similar function and hope it never deviates from the original. This is also not a good approach.
  - (Preferred solution) Add a path parameter to the checkpointer.save function. This lets us specify the save location, giving us more control, and is the most promising option; a sketch of the hypothetical call follows below.
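To make the preferred option concrete, here is roughly what the trainer-side call could look like if the proposed path argument existed. Titan's save does not accept path today, so this is purely a sketch of the proposed change:

# Hypothetical: assumes the proposed `path` argument is added to Titan's checkpointer.save.
step_dir = f"{checkpoint_folder}/step-{self.step}"  # the one folder every component writes into

self.engine.checkpointer.save(
    curr_step=self.step,
    last_step=self.step == self.num_training_steps,
    path=step_dir,  # would override Titan's internal _create_checkpoint_id
)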
Problem 2: Lack of a Unified Saving Mechanism
Currently, the different states are managed by separate actors or services (e.g., the dataloader), and the trainer isn't aware of them. Moreover, Titan's checkpointer resides within trainer.engine.checkpointer and is only responsible for the model, optimizer, and learning rate scheduler. There's no central place to coordinate saving all of these components into the same step-<N> directory. This is a critical problem.
Proposed Solutions: How We Can Move Forward
Now, let's explore some solutions to address these problems.
Option 1: Trainer as the Central Owner
class RLTrainer:
    def __init__(self):
        # The trainer would directly own every other stateful component.
        self.dataloader = ...
        self.replay_buffer = ...
        self.rng = ...
This would mean that the trainer owns all the other components. It's a quick fix in terms of implementation, but it causes tight coupling, violates the actor/service separation, and ultimately hurts scalability and reusability. It's not a great long-term solution.
Option 2: Reimplement Checkpointing
We could create our own model, optimizer, and learning rate scheduler checkpointing from scratch. This would give us full control over the layout and atomicity of the process. However, this is a risky, high-effort task, and we're guaranteed to diverge from Titan over time. It's not the best approach.
Option 3: Checkpoint Coordinator
This solution involves introducing a Checkpoint Coordinator. The coordinator would sit above the existing actors and handle the checkpointing process. Here’s how it would work:
- The coordinator calls Titan to save the model, learning rate scheduler, and optimizers to a specified path (this requires a small API update so that Titan's save function accepts a path parameter).
- The coordinator then asks each actor for its state_dict() and writes it to the same folder (e.g., step-200/dataloader.json, etc.).
- On loading, after Titan resolves which step to load, the coordinator restores each component's state by calling its load_state() function.
from typing import Dict, Optional

class CheckpointCoordinator:
    def __init__(self):
        self._trainer: Optional[RLTrainer] = None
        self._components: Dict[str, ForgeActor | ServiceInterface] = {}

    def set_trainer(self, trainer: RLTrainer):
        self._trainer = trainer

    def register(self, name: str, comp: ForgeActor | ServiceInterface):
        self._components[name] = comp

    async def save(self, step: int, path: str):
        # `path` already points at the step folder, e.g. "<folder>/step-<N>".
        if self._trainer:
            self._trainer.engine.checkpointer.save(path=path)
        for name, comp in self._components.items():
            states = comp.state_dict()
            save(states, f"{path}/{name}.json")

    async def load(self, step: int, path: str):
        ...
The changes needed in grpo/main would look like this:
coord = CheckpointCoordinator()
coord.set_trainer(trainer)
coord.register("dataloader", dataloader)
coord.register("replay_buffer", replay_buffer)
...
await coord.load(step, path=step_dir)
await coord.save(step, path=step_dir)
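For registration to be useful, each registered component has to expose the two hooks the coordinator relies on. Here is a minimal sketch for the dataloader side, using the state_dict()/load_state() convention from above; the internal fields and the _fast_forward helper are made up for illustration:

class DatasetActor(ForgeActor):
    # Sketch: only the checkpoint-related hooks are shown.

    def state_dict(self) -> dict:
        # Everything needed to resume iteration from the same position.
        return {
            "data_step": self._data_step,
            "sampler_epoch": self._sampler_epoch,
        }

    def load_state(self, state: dict) -> None:
        self._data_step = state["data_step"]
        self._sampler_epoch = state["sampler_epoch"]
        self._fast_forward(self._data_step)  # hypothetical helper that skips already-consumed batches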
This option is relatively easy to implement and leverages much of the existing infrastructure. However, it still has some drawbacks:
- It has a nested structure: coord.save calls self._trainer.engine.checkpointer.save.
- It is slightly specific to our current grpo script.
Option 4: Standalone ForgeCheckpointManager
In the long run, we could create a standalone ForgeCheckpointManager that inherits from Titan's CheckpointManager. This manager would orchestrate both Titan and any additional components in a single save()/load() call. Actors would register their export_state/import_state hooks with the manager, and main would only need to call this single manager. This is a more elegant and scalable solution; a rough skeleton is sketched below.
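The sketch assumes Titan's CheckpointManager can be subclassed directly; the _step_path helper and the exact save/load signatures are assumptions for illustration, not existing API:

import torch

class ForgeCheckpointManager(CheckpointManager):
    """Sketch: one manager that owns both Titan's state and the extra component state."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._components: dict[str, object] = {}

    def register(self, name: str, comp) -> None:
        self._components[name] = comp

    def save(self, curr_step: int, last_step: bool = False) -> None:
        # Titan writes the model, optimizer, and lr scheduler as before.
        super().save(curr_step, last_step=last_step)
        step_dir = self._step_path(curr_step)  # hypothetical helper resolving "<folder>/step-<N>"
        for name, comp in self._components.items():
            torch.save(comp.export_state(), f"{step_dir}/{name}.pt")

    def load(self, step: int = -1) -> bool:
        loaded = super().load(step=step)
        # ...then call each component's import_state() with its matching file.
        return loaded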
- Open question: Where would the ForgeCheckpointManager reside if the engine remains within the trainer? And how can it read/write the model, optimizer, and learning rate scheduler states without re-nesting the trainer or breaking the decoupling of actors?
That's the plan, folks! We're actively working on improving our checkpointing capabilities to make our models even more robust and reliable. Stay tuned for updates as we continue to refine this process.