convert RMSNorm to NNX #1728

Draft · wants to merge 1 commit into base: test_748311465
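
This PR converts RMSNorm call sites from the Linen class to the rms_norm factory, which takes an explicit features argument: the NNX-backed module needs its shapes at construction time instead of inferring them from the first input. A minimal before/after sketch of the call-site change (names taken from the gemma.py hunk below; illustration only, not an excerpt from the diff):

    # Before: Linen RMSNorm module, scale shape inferred from the input on first call.
    lnx = RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_norm",
        kernel_axes=("norm",),
    )(inputs)

    # After: rms_norm factory (NNX-backed), feature size passed explicitly.
    lnx = rms_norm(
        features=inputs.shape[-1],
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_norm",
        kernel_axes=("norm",),
    )(inputs)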
13 changes: 9 additions & 4 deletions MaxText/layers/attentions.py
@@ -62,6 +62,7 @@ class AttentionType(enum.Enum):
DenseGeneral = linears.DenseGeneral
dense_general = linears.dense_general
RMSNorm = linears.RMSNorm
rms_norm = linears.rms_norm
RotaryEmbedding = embeddings.RotaryEmbedding
YarnRotaryEmbedding = embeddings.YarnRotaryEmbedding
NdInitializer = initializers.NdInitializer
@@ -1345,15 +1346,17 @@ def __call__(
is_llama4_decoder_block = self.config.decoder_block == "llama4"
# NOTE: llama 4 does L2 normalization after RoPE
if self.use_qk_norm and not is_llama4_decoder_block:
query = RMSNorm(
query = rms_norm(
features=query.shape[-1],
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
name="query_norm",
epsilon=self.config.normalization_layer_epsilon,
kernel_axes=("norm",),
)(query)

key = RMSNorm(
key = rms_norm(
features=key.shape[-1],
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
name="key_norm",
@@ -1486,7 +1489,8 @@ def setup(self):
quant=self.quant,
matmul_precision=self.config.matmul_precision,
)
self.q_norm = RMSNorm(
self.q_norm = rms_norm(
features=self.q_lora_rank,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
name="q_norm",
@@ -1519,7 +1523,8 @@ def setup(self):
quant=self.quant,
matmul_precision=self.config.matmul_precision,
)
self.kv_norm = RMSNorm(
self.kv_norm = rms_norm(
features=self.kv_lora_rank,
dtype=self.config.dtype,
weight_dtype=self.config.weight_dtype,
name="kv_norm",
7 changes: 5 additions & 2 deletions MaxText/layers/deepseek.py
@@ -45,6 +45,7 @@
Embed = embeddings.Embed
Attention = attentions.Attention
RMSNorm = normalizations.RMSNorm
rms_norm = normalizations.rms_norm
Quant = quantizations.AqtQuantization

# -----------------------------------------
@@ -54,7 +55,8 @@

def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
# Normalization
lnx_rms = models.RMSNorm(
lnx_rms = rms_norm(
features=inputs.shape[-1],
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="pre_self_attention_layer_norm",
@@ -105,7 +107,8 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, deco
intermediate_inputs = inputs + attention_lnx

# Normalization
hidden_states = models.RMSNorm(
hidden_states = models.rms_norm(
features=intermediate_inputs.shape[-1],
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="post_self_attention_layer_norm",
13 changes: 7 additions & 6 deletions MaxText/layers/gemma.py
@@ -30,6 +30,7 @@

Embed = embeddings.Embed
RMSNorm = normalizations.RMSNorm
rms_norm = normalizations.rms_norm
NdInitializer = initializers.NdInitializer
Attention = attentions.Attention
MlpBlock = linears.MlpBlock
@@ -76,9 +77,9 @@ def __call__(
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",))(
inputs
)
lnx = rms_norm(
features=inputs.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",)
)(inputs)

lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

@@ -117,9 +118,9 @@ def __call__(
)
attention_lnx += inputs
residual = attention_lnx
attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",))(
attention_lnx
)
attn_output = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",)
)(attention_lnx)

# MLP block.
mlp_lnx = MlpBlock(
41 changes: 21 additions & 20 deletions MaxText/layers/gemma2.py
@@ -30,6 +30,7 @@

Embed = embeddings.Embed
RMSNorm = normalizations.RMSNorm
rms_norm = normalizations.rms_norm
NdInitializer = initializers.NdInitializer
Attention = attentions.Attention
MlpBlock = linears.MlpBlock
@@ -75,8 +76,8 @@ def __call__(
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_local", kernel_axes=("norm",)
lnx = rms_norm(
features=inputs.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_local", kernel_axes=("norm",)
)(inputs)

lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
@@ -112,8 +113,8 @@ def __call__(
model_mode=model_mode,
)
if cfg.use_post_attn_norm:
attention_lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_local", kernel_axes=("norm",)
attention_lnx = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_local", kernel_axes=("norm",)
)(attention_lnx)

attention_lnx = nn.with_logical_constraint(
@@ -122,9 +123,9 @@ def __call__(
attention_lnx += inputs
residual = attention_lnx

attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm_local", kernel_axes=("norm",))(
attention_lnx
)
attn_output = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm_local", kernel_axes=("norm",)
)(attention_lnx)

# MLP block.
mlp_lnx = MlpBlock(
@@ -139,9 +140,9 @@
)(attn_output, deterministic=deterministic)

if cfg.use_post_ffw_norm:
mlp_lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm_local", kernel_axes=("norm",))(
mlp_lnx
)
mlp_lnx = rms_norm(
features=mlp_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm_local", kernel_axes=("norm",)
)(mlp_lnx)
mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

next_layer_addition = mlp_lnx + residual
@@ -160,8 +161,8 @@ def __call__(
inputs = nn.with_logical_constraint(layer_output, ("activation_batch", "activation_norm_length", "activation_embed"))

# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_global", kernel_axes=("norm",)
lnx = rms_norm(
features=inputs.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm_global", kernel_axes=("norm",)
)(inputs)

lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
@@ -196,8 +197,8 @@ def __call__(
model_mode=model_mode,
)
if cfg.use_post_attn_norm:
attention_lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_global", kernel_axes=("norm",)
attention_lnx = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm_global", kernel_axes=("norm",)
)(attention_lnx)

attention_lnx = nn.with_logical_constraint(
@@ -206,9 +207,9 @@ def __call__(
attention_lnx += inputs
residual = attention_lnx

attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm_global", kernel_axes=("norm",))(
attention_lnx
)
attn_output = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm_global", kernel_axes=("norm",)
)(attention_lnx)

# MLP block.
mlp_lnx = MlpBlock(
@@ -222,9 +223,9 @@
quant=self.quant,
)(attn_output, deterministic=deterministic)
if cfg.use_post_ffw_norm:
mlp_lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm_global", kernel_axes=("norm",))(
mlp_lnx
)
mlp_lnx = rms_norm(
features=mlp_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm_global", kernel_axes=("norm",)
)(mlp_lnx)

mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

21 changes: 12 additions & 9 deletions MaxText/layers/gemma3.py
@@ -33,6 +33,7 @@

Embed = embeddings.Embed
RMSNorm = normalizations.RMSNorm
rms_norm = normalizations.rms_norm
NdInitializer = initializers.NdInitializer
Attention = attentions.Attention
AttentionType = attentions.AttentionType
@@ -129,9 +130,9 @@ def __call__(
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",))(
inputs
)
lnx = rms_norm(
features=inputs.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_self_attention_norm", kernel_axes=("norm",)
)(inputs)

lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
query_pre_attn_scalar = get_query_pre_attn_scalar(cfg)
@@ -170,8 +171,8 @@ def __call__(
bidirectional_mask=bidirectional_mask,
)
if cfg.use_post_attn_norm:
attention_lnx = RMSNorm(
dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm", kernel_axes=("norm",)
attention_lnx = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_self_attention_norm", kernel_axes=("norm",)
)(attention_lnx)

attention_lnx = nn.with_logical_constraint(
@@ -180,9 +181,9 @@ def __call__(
attention_lnx += inputs
residual = attention_lnx

attn_output = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",))(
attention_lnx
)
attn_output = rms_norm(
features=attention_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="pre_ffw_norm", kernel_axes=("norm",)
)(attention_lnx)

# MLP block.
mlp_lnx = MlpBlock(
@@ -197,7 +198,9 @@
)(attn_output, deterministic=deterministic)

if cfg.use_post_ffw_norm:
mlp_lnx = RMSNorm(dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm", kernel_axes=("norm",))(mlp_lnx)
mlp_lnx = rms_norm(
features=mlp_lnx.shape[-1], dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="post_ffw_norm", kernel_axes=("norm",)
)(mlp_lnx)

mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
next_layer_addition = mlp_lnx + residual
10 changes: 10 additions & 0 deletions MaxText/layers/initializers.py
@@ -17,6 +17,7 @@
from typing import Callable, Tuple, Union

from flax import linen as nn
from flax import nnx
import jax
from MaxText import common_types

@@ -42,3 +43,12 @@ def init_fn(key, shape, dtype, in_axis, out_axis):
return fn(key, shape, dtype)

return init_fn

def variable_to_logically_partitioned(variable: nnx.VariableState):
metadata = variable.get_metadata()
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
variable.value,
variable.sharding, # type: ignore[arg-type]
mesh=metadata.get("mesh"),
rules=metadata.get("rules"),
)
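
This helper is what keeps logical partitioning intact when NNX-backed layers are bridged into Linen: an NNX variable's state (value plus sharding metadata) is re-wrapped as nn.LogicallyPartitioned so MaxText's existing partitioning rules and mesh still apply. dense_general forwards it as metadata_fn in the linears.py hunk further down. The toy below is a hedged sketch of what the function consumes and returns; the Toy module, the sharding= keyword, and the state indexing are assumptions for illustration, not part of this diff:

    from flax import nnx
    import jax.numpy as jnp

    class Toy(nnx.Module):  # hypothetical module, for illustration only
      def __init__(self):
        # Sharding metadata travels with the variable, as for RMSNorm's scale param.
        self.scale = nnx.Param(jnp.ones((8,)), sharding=("norm",))

    _, state = nnx.split(Toy())  # state leaves are nnx.VariableState objects
    boxed = variable_to_logically_partitioned(state["scale"])
    # boxed is an nn.LogicallyPartitioned box with names=("norm",) and mesh/rules left
    # as None here, which Linen's logical-partitioning machinery knows how to shard.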
19 changes: 5 additions & 14 deletions MaxText/layers/linears.py
@@ -41,6 +41,7 @@
bias_init = initializers.default_bias_init

RMSNorm = normalizations.RMSNorm
rms_norm = normalizations.rms_norm
Quant = quantizations.AqtQuantization
QTensor = aqt_tensor.QTensor

@@ -200,16 +201,6 @@ def __call__(self, inputs: Array) -> Array:
return output


def variable_to_logically_partitioned(variable: nnx.VariableState):
metadata = variable.get_metadata()
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
variable.value,
variable.sharding,
mesh=metadata.get("mesh"),
rules=metadata.get("rules"),
)


def dense_general(
*,
inputs_shape: tuple[int, ...] | None = None,
@@ -252,7 +243,7 @@ def dense_general(
use_bias=use_bias,
matmul_precision=matmul_precision,
name=name,
metadata_fn=variable_to_logically_partitioned,
metadata_fn=initializers.variable_to_logically_partitioned,
)
return module

@@ -285,9 +276,9 @@ class MlpBlock(nn.Module):
use_pre_norm: bool = False
quant: Optional[Quant] = None

def get_norm_layer(self):
def get_norm_layer(self, features: int):
if self.config.decoder_block in ("default", "llama2", "mistral", "mixtral", "gemma", "deepseek", "llama4"):
return RMSNorm
return functools.partial(rms_norm, features=features)
elif self.config.decoder_block == "gpt3":
from MaxText.layers import gpt3

@@ -301,7 +292,7 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
cfg = self.config

if self.use_pre_norm:
inputs = self.get_norm_layer()(
inputs = self.get_norm_layer(features=inputs.shape[-1])(
name="mlp_layer_norm",
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
7 changes: 5 additions & 2 deletions MaxText/layers/llama2.py
@@ -53,6 +53,7 @@
Embed = embeddings.Embed
Attention = attentions.Attention
RMSNorm = normalizations.RMSNorm
rms_norm = normalizations.rms_norm
Quant = quantizations.AqtQuantization

# -----------------------------------------
@@ -84,7 +85,8 @@ def __call__(

inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")
lnx_rms = models.RMSNorm(
lnx_rms = models.rms_norm(
features=inputs.shape[-1],
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="pre_self_attention_layer_norm",
@@ -139,7 +141,8 @@ def __call__(
intermediate_inputs = inputs + attention_lnx

# Fully Connected
hidden_states = models.RMSNorm(
hidden_states = models.rms_norm(
features=intermediate_inputs.shape[-1],
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="post_self_attention_layer_norm",