Sharding validator improvements #1810

Open
wants to merge 6 commits into main
16 changes: 8 additions & 8 deletions MaxText/layers/attentions.py
@@ -874,13 +874,12 @@ def cudnn_jax_flash_attention(
decoder_segment_ids: Array | None,
model_mode: str = MODEL_MODE_TRAIN,
) -> Array:
"""CUDNN Flash Attention with JAX SDPA API.
"""
"""CUDNN Flash Attention with JAX SDPA API."""
# These imports are only meant to work in a GPU build.
# pylint: disable=import-outside-toplevel
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention,
MaskType,
dot_product_attention,
MaskType,
)

_, _, _, head_dim = query.shape # pylint: disable=unused-variable
@@ -898,7 +897,7 @@ def cudnn_jax_flash_attention(
scale=1.0,
dropout_rate=self.dropout_rate,
qkv_layout="BTNH",
return_residual=True
return_residual=True,
)
else:
return dot_product_attention(
@@ -909,7 +908,7 @@ def cudnn_jax_flash_attention(
scale=1.0 / math.sqrt(head_dim),
dropout_rate=self.dropout_rate,
qkv_layout="BTNH",
return_residual=True
return_residual=True,
)

def compute_local_attention(
@@ -1124,8 +1123,9 @@ def normalize_cudnn_attention(self, local_outs, local_stats):
stat1 = local_stats[1].reshape((*local_stats[1].shape, 1))
global_stat = jnp.log(jnp.exp(stat0) + jnp.exp(stat1))
# transpose stat to have shape [b, t, n, 1] for elementwise multiplication
attn_out = local_outs[0].astype(jnp.float32) * jnp.exp(stat0 - global_stat).transpose((0, 2, 1, 3)) \
+ local_outs[1].astype(jnp.float32) * jnp.exp(stat1 - global_stat).transpose((0, 2, 1, 3))
attn_out = local_outs[0].astype(jnp.float32) * jnp.exp(stat0 - global_stat).transpose((0, 2, 1, 3)) + local_outs[
1
].astype(jnp.float32) * jnp.exp(stat1 - global_stat).transpose((0, 2, 1, 3))
return attn_out.astype(local_stats[0].dtype)

def normalize_attention(self, local_outs, local_maxes, local_sums):
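For context on the normalize_cudnn_attention change above: each attention chunk returns a partial output together with a log-sum-exp (LSE) statistic, and the two partial outputs are recombined by reweighting against the global LSE. The following is a minimal, self-contained JAX sketch of that recombination with simplified, assumed shapes ([b, t, n, h] outputs and [b, t, n, 1] stats); it is illustrative only and does not reproduce the exact MaxText tensor layouts or dtypes.

import jax.numpy as jnp

def combine_two_attention_chunks(out0, lse0, out1, lse1):
  """Recombine two partial attention outputs using their log-sum-exp stats."""
  # Global log-sum-exp over both chunks; jnp.logaddexp is the numerically
  # stable form of log(exp(lse0) + exp(lse1)) used in the diff above.
  global_lse = jnp.logaddexp(lse0, lse1)
  # Each chunk's softmax was normalized by exp(lse_i); rescale both chunks to
  # the shared normalizer exp(global_lse) and sum the results.
  return out0 * jnp.exp(lse0 - global_lse) + out1 * jnp.exp(lse1 - global_lse)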
12 changes: 4 additions & 8 deletions MaxText/layers/models.py
@@ -419,9 +419,10 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metdata_axis_name, mes

def get_pipeline_stage_module(self, decoder_blocks):
"""get pipeline stage module"""

def get_layer_to_pipeline(blocks, cfg):
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
return blocks[1] # return the sparse block
return blocks[1] # return the sparse block
else:
return blocks[0]

@@ -530,14 +531,9 @@ def __call__(
model_mode,
)
y = self.pipeline_module(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
partition_spec=partition_spec
y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec
)
else: # Not DeepSeek
else: # Not DeepSeek
y = self.pipeline_module(
y, decoder_segment_ids, decoder_positions, deterministic, model_mode, partition_spec=partition_spec
)
280 changes: 253 additions & 27 deletions MaxText/maxtext_utils.py
@@ -27,11 +27,14 @@

import numpy as np

from collections.abc import Iterable
from jax.experimental import mesh_utils
from jax.experimental.serialize_executable import deserialize_and_load
from jax.sharding import PartitionSpec as P

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

import optax

Expand Down Expand Up @@ -343,25 +346,51 @@ def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, co
return total_tflops, learnable_weight_tflops, causal_attention_tflops


def assert_params_sufficiently_sharded(params, mesh, tolerance):
"""Checks whether most params are sharded across sharding axis.
def get_mesh_axes_used_by_tensor_spec(tensor_sharding_spec):
"""
Extracts the mesh axis names that a tensor's PartitionSpec uses.

This function inspects a tensor's sharding specification (PartitionSpec) and
identifies which mesh axes are actively used for sharding. If a tensor is not
sharded (i.e., fully replicated), the resulting list will be empty.

Args:
tensor_sharding_spec: The PartitionSpec of a tensor, which defines how it is partitioned across the mesh.
Its entries may be mesh axis names (strings), iterables of axis names, or None.

Returns:
A flat list of mesh axis names used by the tensor's sharding spec.
Returns an empty list for unsharded (fully replicated) tensors.
"""
# Flatten the sharding spec, as it can contain nested iterables (e.g., ('data', 'mdl')).
tensor_sharding_spec = sum(
[
[axis] if isinstance(axis, str) else list(axis) if isinstance(axis, Iterable) else []
for axis in tensor_sharding_spec
],
[],
)
return tensor_sharding_spec
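
A quick illustration of the flattening behaviour described above (a hedged sketch; the spec below is an example, not a MaxText default):

from jax.sharding import PartitionSpec as P

# Dim 0 sharded jointly over ('fsdp', 'sequence'), dim 1 replicated,
# dim 2 sharded over 'tensor'. Flattening keeps only the axis names.
example_spec = P(("fsdp", "sequence"), None, "tensor")
# get_mesh_axes_used_by_tensor_spec(example_spec) would return
# ["fsdp", "sequence", "tensor"]; a fully replicated spec such as
# P(None, None) would return [].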


def _get_valid_target_axis(mesh):
"""
Returns mesh axes from config that are valid and have more than one shard.

This function determines whether the majority of parameters are distributed
across a specified sharding axes with an acceptable tolerance. It compares the
current distribution to a scenario where all parameters are fully sharded
across the 'fsdp', 'fsdp_transpose', 'sequence', and 'tensor' axes.
This function identifies which of the predefined potential sharding axes are
actually present in the current device mesh and are configured with a size
greater than one (i.e., are actually sharded).

Args:
params: params of the model state
mesh: mesh constructed from config
tolerance: float between 0.0 and 1.0 representing the allowed percentage of
non-sharded parameters.
mesh: The device mesh object, which contains information about the mesh topology, including axis names and their sizes.

Returns:
bool: True if the majority of parameters are sufficiently sharded
A set of strings, where each string is a mesh axis name that is both
pre-configured as a target for sharding and has more than one shard in the mesh.
"""
total_num_params = max_utils.calculate_num_params_from_pytree(params)
product_num_devices_for_weight_sharding = 1
for axis in [

target_sharding_axes_config = [
"fsdp",
"fsdp_transpose",
"sequence",
@@ -372,19 +401,152 @@ def assert_params_sufficiently_sharded(params, mesh, tolerance):
"tensor_sequence",
"stage",
"expert",
]:
product_num_devices_for_weight_sharding *= mesh.shape[axis]
total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params)
perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding
assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
"Number of parameters per chip must not be less than in the ideal sharded "
"scenario across `fsdp`, `fsdp_transpose`, `context`, `sequence`, `tensor`, `tensor_transpose`, "
"`tensor_sequence`, `stage`, `expert` axes."
)
unsharded_param_perc = total_num_params_per_chip / perfectly_sharded_params_per_chip - 1
assert unsharded_param_perc < tolerance, (
f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% "
f"of total parameters with a value of {unsharded_param_perc * 100}%."
]

# Filter the target axes to find those that exist in the current mesh
# and have a size greater than 1, meaning they are actually used for sharding.
return {axis for axis in target_sharding_axes_config if axis in mesh.axis_names and mesh.shape[axis] > 1}
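
A hedged usage sketch of the axis filtering above, using an illustrative 8-device mesh (the axis names and sizes are assumptions, not a recommended configuration):

from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Hypothetical 1x4x2 mesh over 8 devices.
devices = mesh_utils.create_device_mesh((1, 4, 2))
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))
# _get_valid_target_axis(mesh) would return {"fsdp", "tensor"}: both appear in
# the target list with size > 1, while "data" is not a target axis and the
# remaining target axes are absent from this mesh.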


def _analyze_sharding(params, mesh, valid_target_mesh_axes):
"""
Analyzes parameters to find those that are not sharded across every valid target mesh axis.

This function iterates through all parameters in a model, checking their
sharding specifications. It identifies parameters whose sharding spec does not
use all of the provided valid target axes, including parameters that are fully
replicated across those axes.

Args:
params: A PyTree of model parameters.
mesh: The device mesh object.
valid_target_mesh_axes: A set of mesh axis names that are considered valid targets for sharding.

Returns:
A tuple containing:
- unsharded_params_total_size (int): The total size (number of elements) of all parameters that are
not sharded across every target axis.
- problematic_tensors_details (list): A list of dictionaries, where each
dictionary contains details about one tensor that is missing at least one target axis.
"""
unsharded_params_total_size = 0 # Initialize a counter for the size of unsharded parameters.
problematic_tensors_details = [] # Initialize a list to store details of problematic tensors.

# Get a flattened list of all parameters (leaves) in the PyTree, along with their paths.
all_params_leaves = jtu.tree_leaves_with_path(params)

for path, p_leaf in all_params_leaves: # Iterate over each parameter leaf
param_name_str = jtu.keystr(path) # Convert the tree path to a readable string

# Check that sharding and spec exist and are valid
sharding = getattr(p_leaf, "sharding", None)
spec = getattr(sharding, "spec", None)
assert sharding is not None and spec is not None and isinstance(spec, P), (
f"Parameter '{param_name_str}' is missing a valid '.sharding.spec'."
"Expected 'p_leaf.sharding.spec' to be a non-null 'partitionspec'."
)

current_sharding_spec = p_leaf.sharding.spec # Extract the current tensor's sharding spec
# Identify axes used for sharding
mesh_axes_used = get_mesh_axes_used_by_tensor_spec(current_sharding_spec)
# Check if the parameter is sharded on all the valid target axes.
is_sharded_on_all_target_axis = all(axis in mesh_axes_used for axis in valid_target_mesh_axes)

# If the parameter is not sharded on all of the target axes, it's considered "problematic."
if not is_sharded_on_all_target_axis:
unsharded_params_total_size += p_leaf.size # Add to total unsharded parameter size
unsharded_axes = set(valid_target_mesh_axes) - set(mesh_axes_used)
# Add detailed info to list of problematic tensors
problematic_tensors_details.append(
{
"name": param_name_str, # Tensor name
"size": p_leaf.size, # tensor size
"spec": str(current_sharding_spec), # Tensor sharding spec as string
"available_axes": sorted(list(valid_target_mesh_axes)), # Axes that could be used for sharding
"unsharded_axes": sorted(list(unsharded_axes)), # Unsharded axes
}
)
# Return the total size of unsharded parameters and the list of problematic tensors.
return unsharded_params_total_size, problematic_tensors_details # Return results
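
To make the analysis concrete, here is a hedged, self-contained sketch of a toy parameter tree that _analyze_sharding would flag (the parameter names and shapes are invented for illustration):

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all local devices, with a single 'fsdp' axis.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("fsdp",))
params = {
    "w_sharded": jax.device_put(jnp.zeros((jax.device_count() * 4, 8)), NamedSharding(mesh, P("fsdp", None))),
    "w_replicated": jax.device_put(jnp.zeros((4, 4)), NamedSharding(mesh, P(None, None))),
}
# _analyze_sharding(params, mesh, {"fsdp"}) would count w_replicated's 16
# elements as unsharded and list it as a problematic tensor, while w_sharded
# passes because its spec uses the 'fsdp' axis.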


def _raise_if_unsharded_exceeds_tolerance(unsharded_size, total_size, tolerance, problematic_tensors_details):
"""
Raises an AssertionError if the percentage of unsharded parameters exceeds the given tolerance.

This function calculates the proportion of model parameters that are unsharded
and compares it against a specified tolerance. If the tolerance is exceeded,
it constructs and raises a detailed error message.

Args:
unsharded_size: The total size of parameters not sharded on target axes.
total_size: The total size of all parameters in the model.
tolerance: A float (e.g., 0.05 for 5%) representing the maximum allowed percentage of unsharded parameters.
problematic_tensors_details: A list of details about the unsharded tensors,
used to generate an informative error message.

Raises:
AssertionError: If the percentage of unsharded parameters is greater than the tolerance.
"""
# Calculate the percentage of unsharded parameters.
unsharded_param_perc = unsharded_size / total_size if total_size > 0 else 0.0

# If the percentage is over the tolerance, prepare and raise an error.
if unsharded_param_perc > tolerance:
# Sort the problematic tensors by size to show the largest ones first.
problematic_tensors_details.sort(key=lambda x: x["size"], reverse=True)

# Begin constructing the error message.
error_msg_lines = [
f"Unsharded parameter percentage ({unsharded_param_perc:.2%})" f"exceeds tolerance ({tolerance:.2%})."
]
# Add a header explaining the issue.
error_msg_lines.append(
"The following large tensors are replicated (unsharded) but could be sharded on at "
"least one of the available axes:"
)
# Add details for the top 5 largest problematic tensors.
for detail in problematic_tensors_details[:5]: # Show top 5 largest problematic tensors
error_msg_lines.append(
f" - Name: {detail['name']}(Size: {detail['size']}, Spec: {detail['spec']}) "
f"is unsharded on axis: {detail['unsharded_axes']}"
f"could be sharded on: {detail['available_axes']}"
)

# Raise the assertion error with the combined, formatted message.
raise AssertionError("\n".join(error_msg_lines))
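
A small worked example of the threshold check (numbers are hypothetical):

# 3.0e8 unsharded elements out of 1.2e10 total parameters is 2.5% unsharded.
unsharded_fraction = 3.0e8 / 1.2e10  # 0.025
# With tolerance=0.02 the helper above raises an AssertionError listing the
# largest offending tensors; with tolerance=0.05 it returns silently.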


def assert_params_sufficiently_sharded(params, mesh, tolerance):
"""
Asserts that the total size of replicated parameters is within a given tolerance.

This is the main function that orchestrates the sharding analysis. It determines
the total number of parameters, identifies valid sharding axes, analyzes the
sharding of all parameters, and then raises an error if the amount of
unsharded parameters exceeds the specified tolerance.

Args:
params: A PyTree of model parameters.
mesh: The device mesh object.
tolerance: A float representing the maximum allowed percentage of unsharded parameters.
"""
# Calculate the total number of parameter elements in the model (element count,
# to match the per-tensor sizes accumulated in _analyze_sharding).
total_num_params = max_utils.calculate_num_params_from_pytree(params)

# Get the set of valid mesh axes that can be used for sharding.
valid_target_mesh_axes = _get_valid_target_axis(mesh)
# If there are no valid axes to shard along, there's nothing to check, so we can exit.
if not valid_target_mesh_axes:
return # Exit early

# Analyze the parameters to find the total size of unsharded parameters
# and get details on which tensors are problematic.
unsharded_params_total_size, problematic_tensors_details = _analyze_sharding(params, mesh, valid_target_mesh_axes)

# Check if the amount of unsharded parameters is within the tolerance and
# raise an exception if it is not.
_raise_if_unsharded_exceeds_tolerance(
unsharded_params_total_size, total_num_params, tolerance, problematic_tensors_details
)
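
For reference, a hedged sketch of how the validator might be invoked after the train state is created (the state object and tolerance value here are assumptions, not the PR's actual call sites):

# Fail fast if more than 2% of parameter elements remain unsharded on the
# configured target axes.
assert_params_sufficiently_sharded(state.params, mesh, tolerance=0.02)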


@@ -848,3 +1010,67 @@ def schedule(step):
boundaries.append(warmup_steps + cos_steps + constant_zero_steps)

return optax.join_schedules(pieces, boundaries)


def get_formatted_sharding_annotations(params, mesh=None):
"""
Generates a readable string report of sharding annotations for all parameters.

This function iterates through a PyTree of model parameters and inspects the
sharding information attached to each parameter (leaf). It creates a
human-readable summary that is useful for debugging sharding configurations.

Args:
params: The PyTree of model parameters to inspect.
mesh: (Optional) The device mesh. If provided, its axis names and shape
are included in the report for additional context.

Returns:
A single string containing the formatted report of sharding annotations
for every parameter, with each entry on a new line.
"""
# Initialize a list to hold the lines of the report, starting with a title.
annotation_lines = ["Comprehensice Weight Sharding Annotations:"]

# If a mesh object is provided, add its details to the report header.
if mesh:
annotation_lines.append(f"Mesh axes: {mesh.axis_names}, Mesh shape: {mesh.shape}")
annotation_lines.append("-" * 30)

# Get a flattened list of all parameters (leaves) and their corresponding paths in the PyTree.
all_params_leaves = jtu.tree_leaves_with_path(params)

# Loop through each parameter leaf in the flattened list.
for path, p_leaf in all_params_leaves:
# Convert the parameter's path (a sequence of keys) into a readable string name.
param_name_str = jtu.keystr(path)
# Get the shape of the parameter as a string.
shape_str = str(p_leaf.shape)
# Set a default description for sharding, in case none is found.
sharding_desc = "N/A"

# Check if the parameter leaf has a 'sharding' attribute.
if hasattr(p_leaf, "sharding"):
# Case 1: Standard JAX sharding with a PartitionSpec.
if hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is not None:
# The spec is a tuple (PartitionSpec), format it for readability.
spec_parts = []
for item in p_leaf.sharding.spec:
# Represent None as "Replicated" to make it explicit.
spec_parts.append(str(item) if item is not None else "Replicated")
sharding_desc = f"PartitionSpec({', '.join(spec_parts)})"
# Case 2: The parameter is explicitly marked as fully replicated.
elif hasattr(p_leaf.sharding, "spec") and p_leaf.sharding.spec is None:
sharding_desc = "Fully Replicated (spec is None)"
# Case 3: A generic fallback if a sharding object exists but has no recognized spec attribute.
else:
# Print the string representation of the sharding object itself.
sharding_desc = str(p_leaf.sharding)
# Case 4: The parameter has no .sharding attribute at all.
else:
sharding_desc = "No .sharding attribute found"

# Append the formatted details for the current parameter to our list of lines.
annotation_lines.append(f" - Param: {param_name_str}\n" f" Shape: {shape_str}\n" f" Sharding: {sharding_desc}")
# Join all the collected lines into a single string, separated by newlines.
return "\n".join(annotation_lines)