
This article explores advanced strategies for parallelizing the Muon optimizer in large-scale machine learning models, focusing on sharding and replication techniques to enhance performance.
In large-scale machine learning, optimizer performance is a constant challenge. A recent social-media discussion of the Kimi Moonlight paper prompted a closer look at how to parallelize the Muon optimizer. This article walks through the technical details of splitting Muon's update across ranks using different sharding and replication schemes.
Let's start with a model that is already partitioned using PP (Pipeline Parallelism), FSDP (Fully Sharded Data Parallelism), and TP (Tensor Parallelism). We can ignore PP here: it assigns whole layers to different stages, so it doesn't change how the optimizer sees each parameter. What matters is how FSDP and TP shard the parameters themselves.
Each parameter \( W \) is split into shards \( w_1, w_2, \ldots, w_n \), which are distributed evenly across the ranks, and each parameter shard \( w_i \) has a corresponding gradient shard \( g_i \). For simplicity, we'll also ignore small parameters such as norms or modulation weights, since analyzing them correctly is a question of latency rather than bandwidth or compute.
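As a concrete picture of this setup, here is a minimal sketch that splits a parameter matrix and its gradient into equal row shards, one per rank. Row-wise (dim-0) sharding, plain Python lists, and the helper name `shard_rows` are illustrative assumptions; real FSDP and TP layouts may shard along other dimensions.

```python
def shard_rows(matrix, n_ranks):
    """Split a matrix (a list of rows) into n_ranks equal, contiguous row shards."""
    rows = len(matrix)
    if rows % n_ranks != 0:
        raise ValueError("rows must divide evenly across ranks")
    k = rows // n_ranks
    return [matrix[i * k:(i + 1) * k] for i in range(n_ranks)]

# Hypothetical 8x4 parameter W and a matching gradient G.
W = [[float(4 * r + c) for c in range(4)] for r in range(8)]
G = [[0.1 * x for x in row] for row in W]

w_shards = shard_rows(W, 4)  # w_1..w_4, two rows each
g_shards = shard_rows(G, 4)  # g_i lines up row-for-row with w_i

# Concatenating the shards in rank order recovers the full parameter.
W_rebuilt = [row for shard in w_shards for row in shard]
```

The important property is that \( w_i \) and \( g_i \) cover exactly the same slice of the matrix, so any elementwise update can be applied shard by shard.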
Here are the key ways to parallelize the Muon optimizer:
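The simplest baseline is to replicate the momentum and compute it on every GPU: all-gather the gradient shards \( g_i \) into the full gradient \( G \), update a full copy of the momentum \( M_t \) on each rank, run the orthogonalization on it, and apply each rank's slice of the result to its weight shard \( w_i \). This is inefficient because every rank must hold an unsharded momentum buffer and redundantly perform the full momentum update.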
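To put the memory cost of a replicated momentum buffer in numbers, this sketch compares per-rank storage when the buffer is replicated versus sharded. The 4096×4096 fp32 matrix and the 8 ranks are hypothetical figures chosen for illustration, and `momentum_bytes_per_rank` is an illustrative helper.

```python
def momentum_bytes_per_rank(rows, cols, n_ranks, bytes_per_elem, sharded):
    """Per-rank memory for one momentum buffer, either replicated or evenly sharded."""
    full = rows * cols * bytes_per_elem
    return full // n_ranks if sharded else full

MiB = 2 ** 20
replicated = momentum_bytes_per_rank(4096, 4096, 8, 4, sharded=False)
sharded = momentum_bytes_per_rank(4096, 4096, 8, 4, sharded=True)
print(replicated // MiB, sharded // MiB)  # prints: 64 8
```

Replication costs the full 64 MiB per matrix on every rank, while sharding cuts it to 1/n of that; across every weight matrix in a large model, the difference adds up quickly.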

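A more efficient strategy, NS Replication, shards the updates for \( M_t \) and \( W_t \) while keeping the Newton-Schulz orthogonalization step \( O_t = \ldots \) replicated on every rank. It adds no communication over the baseline as long as \( g_i \) and \( M_t \) share a data type: each rank all-gathers a shard derived from its momentum instead of its gradient shard, so the same number of bytes moves over the network.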
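The "free" claim is about bytes on the wire, so a back-of-the-envelope sketch helps. It assumes an evenly row-sharded 4096×4096 matrix on 8 ranks (hypothetical sizes) and counts the data each rank must receive in an all-gather, which is the other \( n - 1 \) shards regardless of the collective algorithm. `allgather_recv_bytes` is an illustrative helper.

```python
def allgather_recv_bytes(rows, cols, n_ranks, bytes_per_elem):
    """Bytes each rank receives when all-gathering an evenly row-sharded matrix:
    it already holds its own shard and must receive the other n_ranks - 1."""
    shard = (rows // n_ranks) * cols * bytes_per_elem
    return shard * (n_ranks - 1)

rows, cols, n = 4096, 4096, 8
baseline = allgather_recv_bytes(rows, cols, n, 2)          # gather bf16 gradients
ns_same_dtype = allgather_recv_bytes(rows, cols, n, 2)     # gather bf16 momentum
ns_fp32_momentum = allgather_recv_bytes(rows, cols, n, 4)  # gather fp32 momentum
```

With matching dtypes the two schemes move the same 29,360,128 bytes per rank; keeping fp32 momentum against bf16 gradients doubles the traffic, which is exactly the case where NS Replication stops being free.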
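Let \( m_i \) be the local shard of \( M \). Each rank updates \( m_i \) from its own gradient shard \( g_i \) without any communication, all-gathers the resulting shards into the full matrix, runs Newton-Schulz on that full matrix redundantly, and applies only its own slice of \( O_t \) to \( w_i \).
To implement the NS Replication strategy, \( m_i \) must be sharded exactly like \( g_i \) and \( w_i \), so that the all-gathered shards reassemble the full matrix in the right order and each rank can slice out the portion of \( O_t \) that matches its weight shard.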
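Here's a PyTorch sketch of the NS Replication step. It assumes each tensor is split into equal-size shards along dim 0 over the default process group, and `orthogonalize` stands in for a Newton-Schulz routine supplied by the caller:
```python
# NS Replication for Muon on sharded parameters (PyTorch sketch).
# Assumes equal-size dim-0 shards in the default process group.
import torch
import torch.distributed as dist

@torch.no_grad()
def muon_optimizer_step(grad_shard, momentum_shard, weight_shard,
                        learning_rate, mu, orthogonalize):
    # 1. Sharded momentum update: purely local, no communication
    momentum_shard.mul_(mu).add_(grad_shard)
    # 2. Nesterov-style lookahead, still on the local shard
    x_shard = mu * momentum_shard + grad_shard
    # 3. All-gather the shards into the full matrix on every rank
    #    (same bytes as gathering G when grad and momentum dtypes match)
    world_size, rank = dist.get_world_size(), dist.get_rank()
    X = x_shard.new_empty((world_size * x_shard.shape[0], *x_shard.shape[1:]))
    dist.all_gather_into_tensor(X, x_shard)
    # 4. Replicated step: every rank runs Newton-Schulz on the full matrix
    O = orthogonalize(X)
    # 5. Sharded weight update: keep only this rank's slice of O
    o_shard = O.chunk(world_size, dim=0)[rank]
    weight_shard.sub_(learning_rate * o_shard)
    return weight_shard, momentum_shard
```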
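To check that sharding the momentum update doesn't change the result, this single-process NumPy sketch simulates several ranks, using `np.concatenate` in place of a real all-gather and `np.split` for the per-rank slices, and compares NS Replication against the fully replicated baseline. The quintic Newton-Schulz coefficients follow the reference Muon implementation; the shape-dependent update scaling and weight decay that Muon and Moonlight apply are omitted, and all function names are illustrative.

```python
import numpy as np

def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G with Muon's quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)  # Frobenius norm bounds singular values by 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:  # iterate on the wide orientation so X @ X.T is the small Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X

def baseline_step(W, G, M, lr, mu):
    """Replicated baseline: every rank holds the full G and the full M."""
    M = mu * M + G
    return W - lr * newton_schulz(mu * M + G), M

def ns_replicated_step(w_shards, g_shards, m_shards, lr, mu):
    """NS Replication: sharded momentum/weight updates, replicated Newton-Schulz."""
    m_shards = [mu * m + g for m, g in zip(m_shards, g_shards)]  # local, no comms
    x_shards = [mu * m + g for m, g in zip(m_shards, g_shards)]
    X = np.concatenate(x_shards, axis=0)  # stands in for the all-gather
    O = newton_schulz(X)                  # identical on every rank
    o_shards = np.split(O, len(w_shards), axis=0)
    w_shards = [w - lr * o for w, o in zip(w_shards, o_shards)]
    return w_shards, m_shards

rng = np.random.default_rng(0)
W, G, M = (rng.standard_normal((8, 4)) for _ in range(3))
W_ref, M_ref = baseline_step(W, G, M, lr=0.02, mu=0.95)
w_s, m_s = ns_replicated_step(np.split(W, 4), np.split(G, 4), np.split(M, 4),
                              lr=0.02, mu=0.95)
```

Because the momentum update is elementwise, the all-gathered matrix is identical to the baseline's, so both schemes produce the same weights; the only thing that changes is where memory and compute are spent.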
Original Sources
https://main-horse.github.io/posts/parallelizing-muon/?utm_source=tldrai
About the author
Kai built ML infrastructure at a Bay Area startup before developing an obsession with transformer architectures and inference optimisation that eventually pulled him out of product work entirely. A stint at a compute research lab sharpened his instinct for what actually matters in a model release versus what is marketing. He writes from the inside — from the perspective of someone who has debugged the systems he is describing at three in the morning. He is allergic to hype and instinctively drawn to the unglamorous plumbing questions that everyone else skips over.
This Week's Edition: 24 February 2025