==========================================================
Interactive Distributed Applications with Monarch
==========================================================

**Author**: `Amir Afzali <https://github.com/amirafzali>`_

Introduction
------------

As deep learning models continue to grow in size and complexity, training them
efficiently requires coordinating computation across multiple GPUs and nodes.
In this tutorial, you will learn how to set up and run large-scale distributed
workflows on a SLURM-managed cluster using Monarch's actor framework together
with TorchTitan. Monarch lets us drive a large cluster of machines (organized
into a mesh) as if we were developing in a single-host, single-process
environment.

What is Monarch?
^^^^^^^^^^^^^^^^

Monarch is an actor framework designed to streamline the development of
distributed applications. At its core, Monarch provides:

- **Actor-based programming model**: Encapsulate stateful computations in actors that can run on remote processes and machines
- **Process mesh abstractions**: Easily manage and coordinate distributed processes across your cluster, with scalable actor messaging
- **Fault tolerance**: Actors and processes form a tree, and failures propagate up the tree, providing good default error behavior and enabling fine-grained fault recovery
- **Flexible resource management**: Support for multiple cluster schedulers including SLURM, Kubernetes, custom host management, and local processes
- **Integrated monitoring**: Stream logs from remote processes back to your client for easy debugging and aggregation

For more details, see the `Monarch documentation <https://meta-pytorch.org/monarch/generated/examples/getting_started.html>`_.

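To make the actor model concrete before wiring it into TorchTitan, here is a
minimal sketch of a stateful actor. ``CounterActor`` and ``count_on_all_ranks``
are illustrative names rather than anything provided by Monarch, and the sketch
assumes a ``proc_mesh`` obtained the way Step 4 creates one; it only uses the
``Actor``, ``endpoint``, ``current_rank``, ``spawn``, and ``call`` APIs that
appear later in this tutorial.

.. code-block:: python

    from monarch.actor import Actor, current_rank, endpoint


    class CounterActor(Actor):
        """Illustrative stateful actor: each instance keeps a private counter."""

        def __init__(self, start: int = 0) -> None:
            self.count = start
            self.rank = current_rank().rank  # this actor's position in the mesh

        @endpoint
        async def increment(self) -> int:
            # State lives inside the remote actor process, not on the client
            self.count += 1
            return self.count


    async def count_on_all_ranks(proc_mesh) -> None:
        # Spawn one CounterActor per process in the mesh, then message them all
        counters = proc_mesh.spawn("counters", CounterActor, 0)
        await counters.increment.call()  # invoke the endpoint on every rank
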
Why Use Monarch?
^^^^^^^^^^^^^^^^

TorchTitan is a PyTorch-native library for pre-training at scale. While
TorchTitan provides excellent primitives for distributed training, launching
and managing these jobs across clusters can slow down iteration. Monarch
addresses this with:

1. **Simplified cluster interaction**: Reserve and manage compute resources with simple async Python calls instead of writing bash scripts
2. **Interactive development**: Modify and re-run training code on existing allocations without waiting for new resources
3. **Unified workflow**: Seamlessly move between local testing and cluster execution with the same code

Prerequisites
-------------

We rely on a nightly build of TorchTitan for this tutorial, so please ensure
that your other PyTorch libraries are also tracking nightly builds:

1. **Monarch nightly installed:** `Install script <https://github.com/meta-pytorch/monarch/blob/main/scripts/install_nightly.py>`_
2. **TorchTitan nightly installed:** `TorchTitan install instructions <https://github.com/pytorch/torchtitan?tab=readme-ov-file#nightly-builds>`_
3. **A valid TorchTitan model config** and **tokenizer** in your working directory (e.g., ``debug_model.toml`` from `TorchTitan configs <https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/train_configs/debug_model.toml>`_).
4. **SLURM cluster access:**

   - Sufficient permissions to reserve nodes and launch jobs.
   - CUDA environment configured for distributed GPU training.

Now let's implement this step by step!

Step 1: Reserve Machine Resources
---------------------------------

First, we'll define a function to programmatically reserve a machine allocation.

**Monarch Highlight**: Instead of submitting an SBATCH script, you can reserve
and manage resources interactively from Python.

The ``JobTrait`` design pattern allows for interfacing with custom schedulers,
such as SLURM and Kubernetes, through a consistent API.

.. code-block:: python

    from monarch.job import SlurmJob, JobTrait

    def create_slurm_job(
        mesh_name: str,
        num_nodes: int,
        gpus_per_node: int,
        time_limit: str = "06:00:00"
    ) -> SlurmJob:
        """
        Args:
            mesh_name: Name assigned to the primary mesh for this example.
                A JobTrait can consist of multiple meshes, and Monarch allows
                for re-attaching to ongoing jobs.
            num_nodes: Number of nodes allocated per mesh
            gpus_per_node: Number of GPUs per node in the mesh
            time_limit: Maximum wall-clock time for the allocation

        Note:
            SlurmJob is just one instance of a Monarch scheduler interface.
            Consult the JobTrait documentation to find one that's right for
            your use case.
        """
        default_job_name = "monarch_titan"
        return SlurmJob(
            meshes={mesh_name: num_nodes},
            job_name=default_job_name,
            time_limit=time_limit,
            gpus_per_nodes=gpus_per_node,
            # ... additional args can be passed here
        )

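As a quick usage sketch (the node and GPU counts below are hypothetical),
reserving an allocation then becomes a single call; Step 4 shows how this
handle is used end to end.

.. code-block:: python

    # Hypothetical sizes: reserve two 8-GPU nodes for up to six hours
    slurm_job = create_slurm_job("mesh0", num_nodes=2, gpus_per_node=8)

    # The returned SlurmJob handle is later used to materialize a process
    # mesh (Step 4) and to release the allocation when we're done (Step 5)
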
Step 2: Define the Trainer Actor
--------------------------------

Now we create a Monarch Actor that wraps TorchTitan's Trainer. This is the key
abstraction that allows TorchTitan to run in Monarch's distributed environment.

**Monarch Highlight**: The Actor pattern provides several benefits:

1. **Remote execution**: Methods marked with ``@endpoint`` can be called remotely
2. **Lifecycle management**: Monarch handles initialization, execution, and cleanup
3. **Error handling**: Exceptions are properly propagated back to the client, enabling progressive error handling

.. code-block:: python

    import torch

    from monarch.actor import Actor, current_rank, endpoint
    from monarch.utils import setup_env_for_distributed
    from torchtitan.tools.logging import init_logger, logger
    from torchtitan.train import Trainer


    class TrainerActor(Actor):
        """
        Monarch Actor wrapper for TorchTitan's Trainer.

        This actor encapsulates a complete TorchTitan training process,
        handling initialization, training loop execution, and cleanup.
        Each instance runs on a single GPU in the distributed training job.

        The actor's lifetime:
        1. __init__: Initialize with job configuration
        2. start_training: Execute the training loop, then destroy the
           process group and release resources

        Attributes:
            job_config: TorchTitan configuration for this trainer
            uid: Unique identifier for logging (includes rank)
        """

        def __init__(self, job_config: "JobConfig") -> None:
            """
            Initialize the trainer actor.

            Args:
                job_config: TorchTitan JobConfig with training parameters
            """
            self.job_config = job_config
            # current_rank() provides access to this actor's rank in the process mesh
            self.rank = current_rank().rank
            self.uid = f"[trainer_{self.rank}]"

        @endpoint
        async def ping_rank(self) -> None:
            """
            A dummy logging function we will use for demonstration purposes.
            """
            logger.info(f"{self.uid} Ping!")

        @endpoint
        async def start_training(self) -> None:
            """
            Execute the TorchTitan training loop.

            This remote endpoint:
            1. Initializes TorchTitan's logger
            2. Creates a Trainer instance with the job configuration
            3. Runs the training loop
            4. Handles cleanup and error conditions

            The @endpoint decorator makes this method callable from the
            Monarch client, even though it runs on a remote GPU node.

            Raises:
                Exception: Any exception from TorchTitan training is
                    propagated back to the client
            """
            init_logger()
            trainer: Trainer | None = None

            try:
                # Initialize TorchTitan trainer
                trainer = Trainer(self.job_config)
                logger.info(f"{self.uid} initialized successfully and starting training")

                # Run the training loop
                trainer.train()
            except Exception as e:
                logger.error(f"{self.uid} training failed: {e}")
                if trainer:
                    trainer.close()
                # Note: error is propagated back to the controller
                raise e
            else:
                # Training completed successfully
                trainer.close()
                logger.info(f"{self.uid} training completed successfully")
            finally:
                # Clean up distributed process group
                torch.distributed.destroy_process_group()
                logger.info(f"{self.uid} trainer cleaned up")

Actor endpoints can be invoked in a variety of patterns. We'll explore a
concrete example in `Step 4: Execute the Training Workflow`_, but here is some
pseudocode with common usages:

.. code-block:: python

    try:
        # where mesh0 is made of N nodes, each node having 8 GPUs
        proc_mesh = mesh0.spawn_procs({"gpus": 8})
        trainer_actors = proc_mesh.spawn("trainers", TrainerActor, ...)

        # Call on all ranks
        await trainer_actors.ping_rank.call()

        # Call-and-forget on all ranks
        trainer_actors.ping_rank.broadcast()

        # Call on ONE random rank
        await trainer_actors.ping_rank.choose()

        # Call on the first 3 ranks of node 0
        await trainer_actors.slice(hosts=0, gpus=slice(0, 3)).ping_rank.call()
    except Exception as e:
        # handle SupervisionEvents from remote actor failures
        pass

Remote actor endpoints can also use Python-native breakpoints, enabling
interactive debugging sessions. For a complete deep dive into Monarch
debuggers, please `refer to the documentation <https://meta-pytorch.org/monarch/generated/examples/debugging.html>`_.

.. code-block:: python

    @endpoint
    async def ping_debuggable_rank(self) -> None:
        logger.info(f"{self.uid} Ping!")
        if self.rank == 0:
            breakpoint()
        logger.info(f"{self.uid} Pong!")

Step 3: Define Training Parameters
-----------------------------------

Next, we define some common parameters for our training job and cluster
resources. This configuration determines both the scale of training (number of
nodes and GPUs) and some of the training hyperparameters.

.. code-block:: python

    from dataclasses import dataclass


    @dataclass
    class RunParams:
        """
        Configuration for cluster resources and training parameters.

        Attributes:
            training_steps: Number of training iterations to run
            model_config: Path to TorchTitan model configuration file
            tokenizer: Path to tokenizer directory
            dataset: Dataset to use for training (e.g., 'c4', 'c4_test')
            num_nodes: Number of compute nodes to request
            gpus_per_node: Number of GPUs per node

        Adjust these values based on your model size and available resources.
        """
        training_steps: int = 50
        model_config: str = "debug_model.toml"
        tokenizer: str = "tokenizer"
        dataset: str = "c4"
        num_nodes: int = 2
        gpus_per_node: int = 8

TorchTitan uses a JobConfig object to control all aspects of training. Here we
create a function that parses this configuration from our RunParams.

.. code-block:: python

    import os

    from torchtitan.config import ConfigManager, JobConfig

    def make_job_config() -> JobConfig:
        """
        Create a TorchTitan JobConfig from RunParams.

        This function constructs the complete training configuration,
        including parallelism settings, model architecture, and dataset paths.
        """
        # Calculate total parallelism based on cluster size
        data_parallel_shard_degree = RunParams.num_nodes * RunParams.gpus_per_node

        output_path = "./outputs"

        # Construct paths relative to script directory
        script_dir = os.getcwd()

        # Build argument list for TorchTitan's ConfigManager
        # These override defaults from the model config file
        default_args = [
            "--job.config_file", os.path.join(script_dir, RunParams.model_config),
            "--model.tokenizer_path", os.path.join(script_dir, RunParams.tokenizer),
            "--parallelism.data_parallel_shard_degree", str(data_parallel_shard_degree),
            "--training.steps", str(RunParams.training_steps),
            "--training.dataset", RunParams.dataset,
            "--job.dump_folder", output_path,
            # continue to configure as needed
        ]

        config_manager = ConfigManager()
        job_config = config_manager.parse_args(default_args)
        return job_config

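As a quick sanity check of the arithmetic above, the standalone snippet below
(plain Python, no TorchTitan import) prints the key overrides produced for the
default ``RunParams``, where 2 nodes × 8 GPUs map to 16-way data-parallel
sharding.

.. code-block:: python

    # Standalone illustration of the overrides make_job_config builds,
    # using the RunParams defaults above (2 nodes x 8 GPUs)
    num_nodes, gpus_per_node, training_steps = 2, 8, 50
    data_parallel_shard_degree = num_nodes * gpus_per_node  # 16-way sharding

    overrides = [
        "--parallelism.data_parallel_shard_degree", str(data_parallel_shard_degree),
        "--training.steps", str(training_steps),
    ]
    print(" ".join(overrides))
    # --parallelism.data_parallel_shard_degree 16 --training.steps 50
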
Step 4: Execute the Training Workflow
--------------------------------------

With all components defined, we now orchestrate the complete workflow. This is
where Monarch's power becomes most apparent.

**Monarch Highlights**:

1. **Interactive iteration**: After reserving the machine allocation, you can adjust your logic and re-spawn actors without requesting new resources. SLURM's shared filesystem ensures that framework/workspace changes are synchronized across workers.
2. **Transparent logging**: All logs from remote workers stream back to your client in real time, making debugging feel like local execution

**Workflow**: Reserve Machines → Create Proc Mesh → Configure Logging → Spawn Actors → Train → Cleanup

.. code-block:: python

    async def execute_training() -> None:
        """
        Execute the complete distributed training workflow.
        """
        job_config = make_job_config()
        slurm_job = None
        mesh_name = "mesh0"

        try:
            # 1. Create a SLURM job with N nodes
            # This leverages Monarch to reserve a persistent machine allocation
            slurm_job = create_slurm_job(mesh_name, RunParams.num_nodes, RunParams.gpus_per_node)
            job_state = slurm_job.state()

            # 2. Create a process mesh on the machine allocation
            # This creates one process per GPU across all allocated nodes
            logger.info("Creating process mesh...")
            proc_mesh = job_state.mesh0.spawn_procs({"gpus": RunParams.gpus_per_node})

            # 3. Configure remote logging behavior
            # - stream_to_client: Forward all remote logs to your local console
            # - aggregate_window_sec: Batch logs for efficiency
            logger.info("Configuring logging...")
            await proc_mesh.logging_option(
                stream_to_client=True,
                # aggregate_window_sec=None  # Uncomment to disable log batching
            )

            # 4. Set up the environment for torch.distributed
            # This configures torch.distributed across all processes in the mesh
            logger.info("Setting up distributed environment...")
            await setup_env_for_distributed(proc_mesh)

            # 5. Spawn a TrainerActor on each GPU
            # Each process in the mesh creates its own TrainerActor instance
            logger.info("Spawning trainer actors...")
            trainer = proc_mesh.spawn(
                "trainer_actor",  # Name for the actor group
                TrainerActor,     # Actor class to instantiate
                job_config,       # Arguments to __init__
            )

            # 6. Execute the training job across all actors
            # The .call() method invokes start_training() on all actors in parallel
            logger.info("Starting distributed training...")
            await trainer.start_training.call()

            logger.info("Training completed successfully!")
        except Exception as e:
            logger.error(f"Training workflow failed: {e}")
        finally:
            # Always clean up the machine allocation
            if slurm_job:
                await cleanup_job(slurm_job)

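To illustrate the interactive-iteration highlight, here is a hypothetical
helper, ``rerun_training`` (the name and its ``new_steps`` parameter are
illustrative, not part of Monarch or TorchTitan), sketching how you might spawn
a fresh trainer group with an adjusted config on the same ``proc_mesh`` without
touching the scheduler. It assumes the previous run has finished and torn down
its process group, as the ``finally`` block in ``start_training`` does.

.. code-block:: python

    async def rerun_training(proc_mesh, new_steps: int) -> None:
        """Hypothetical sketch: re-run training on an existing process mesh."""
        # Adjust the run parameters and rebuild the TorchTitan config
        RunParams.training_steps = new_steps
        job_config = make_job_config()

        # Spawn a new actor group on the same allocation; no new SLURM request
        trainer = proc_mesh.spawn("trainer_actor_v2", TrainerActor, job_config)
        await trainer.start_training.call()
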
Step 5: Clean Up Resources
--------------------------

After training completes (or when you're done experimenting), it's important to
free up cluster resources by terminating the SLURM job.

**Monarch Highlight**: While you can keep allocations alive for multiple
training runs during development, always remember to release cluster resources.

.. code-block:: python

    async def cleanup_job(job: JobTrait) -> None:
        """
        This function cancels the SLURM job, releasing all reserved nodes
        back to the cluster for other users.

        Args:
            job: A JobTrait, like the one returned from create_slurm_job()

        Note:
            The job will also terminate automatically when the configured TTL
            is exceeded, but explicit cleanup is recommended for long-running
            notebooks or scripts.
        """
        job.kill()
        logger.info("Job terminated successfully")

Step 6: Run the Complete Pipeline
---------------------------------

Finally, we tie everything together in a main function that kicks off the
workflow.

.. code-block:: python

    import asyncio

    if __name__ == "__main__":
        """
        Run the complete workflow: reserve resources, train, and clean up.
        """
        logger.info("Starting Monarch + TorchTitan Distributed Training")
        asyncio.run(execute_training())
        logger.info("Workflow completed!")

Conclusion
-----------

Congrats! In this tutorial, you learned how to apply Monarch's actor framework
with TorchTitan for scalable distributed training.

**Further Reading**

- Monarch also integrates with TorchFT to provide per-step fault tolerance across replicated workers. You can find a comprehensive `proof of concept <https://github.com/meta-pytorch/torchft/tree/main/examples/monarch>`_ of this integration in the TorchFT repo.
- For an interactive notebook covering similar topics to this tutorial, please consult `this Monarch example <https://github.com/meta-pytorch/monarch/blob/main/examples/slurm_titan.ipynb>`_.