Elastic Training Fast Resume #1769

Draft · wants to merge 4 commits into main
11 changes: 10 additions & 1 deletion MaxText/configs/base.yml
@@ -736,4 +736,13 @@ projector_output_dim_for_vit: 4096
rope_theta_for_vit: 10000
vision_output_dim_for_vit: 4096
pixel_shuffle_ratio_for_vit: 0.5
projector_dropout_for_vit: 0.0
projector_dropout_for_vit: 0.0


## Elastic training flags
elastic_mode: "fast-resume"
elastic_reshard_check_period: 1
elastic_snapshot_period: 5
elastic_max_elastic_down_event_count: 100
elastic_max_reshard_retry_count: 3
elastic_wait_period: 30
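
The new elastic_* keys configure the elastic-training manager. elastic_mode selects between the existing "replica-resize" behavior and the new "fast-resume" mode; elastic_reshard_check_period and elastic_snapshot_period are step counts for how often the manager checks whether it can reshard and how often it snapshots params and opt_state; elastic_max_elastic_down_event_count and elastic_max_reshard_retry_count appear to bound how many slice-down events and reshard retries are tolerated; elastic_wait_period is the number of seconds "fast-resume" sleeps between slice-availability polls. The control flow these flags drive, condensed from the train_loop changes below as an editorial sketch (not part of the diff), looks roughly like this:

# Editorial sketch, condensed from the train_loop changes further down in this PR;
# names come from the diff, error handling and profiling are omitted.
while step < config.steps:
  # Snapshot params/opt_state; the manager is expected to apply elastic_snapshot_period internally.
  elastic_manager.maybe_snapshot(
      step=step,
      snapshot_jax_arrays={"params": state.params, "opt_state": state.opt_state},
      block=True,
  )

  # fast-resume only: block until every slice is healthy again, polling every
  # elastic_wait_period seconds, before attempting to reshard back up.
  if (config.elastic_mode == "fast-resume"
      and elastic_manager.good_slice_count < elastic_manager.total_slice_count):
    wait_for_all_slices(elastic_manager, config.elastic_wait_period)

  # Reshard onto the available slices when more become healthy; the full call
  # (presumably rate-limited by elastic_reshard_check_period) appears in the diff below.
  elastic_manager.maybe_reshard_up(
      step=step,
      snapshot_jax_arrays={"params": state.params, "opt_state": state.opt_state},
      elastic_handler=elastic_handler,
      handler_kwargs={"config": config, "elastic_manager": elastic_manager,
                      "checkpoint_manager": checkpoint_manager},
  )

  # ... the train step itself, profiling, and metrics follow; a jax.errors.JaxRuntimeError
  # triggers elastic_manager.maybe_reshard_down(...).
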
151 changes: 85 additions & 66 deletions MaxText/elastic_train.py
@@ -123,6 +123,7 @@ def elastic_handler(
)
with mesh:
data_iterator, _ = create_data_iterator(config, mesh)
input_data_shardings = maxtext_utils.get_input_data_sharding(config, mesh)

step, snapshot_jax_arrays, _ = elastic_manager.get_resharded_snapshot(mesh)

@@ -185,6 +186,7 @@ def elastic_handler(
learning_rate_schedule,
metric_logger,
writer,
input_data_shardings,
)


@@ -278,6 +280,50 @@ def train_loop(config, elastic_manager, state=None):
# the step is restored back to the latest snapshot when a slice is lost
while step < config.steps:
try:
elastic_manager.maybe_snapshot(
step=step,
snapshot_jax_arrays={
"params": state.params,
"opt_state": state.opt_state,
},
block=True,
)

if (config.elastic_mode == "fast-resume" and
elastic_manager.good_slice_count < elastic_manager.total_slice_count):
wait_for_all_slices(elastic_manager, config.elastic_wait_period)

ret = elastic_manager.maybe_reshard_up(
step=step,
snapshot_jax_arrays={
"params": state.params,
"opt_state": state.opt_state,
},
elastic_handler=elastic_handler,
handler_kwargs={
"config": config,
"elastic_manager": elastic_manager,
"checkpoint_manager": checkpoint_manager,
},
)

if ret is not None:
(
config,
step,
state,
mesh,
checkpoint_manager,
data_iterator,
p_train_step,
example_batch,
learning_rate_schedule,
metric_logger,
writer,
input_data_shardings,
) = ret
step += 1

if step == first_profiling_step or prof.should_activate_periodic_profile(step):
optional_postfix = f"step_{step}" if config.profile_periodically_period > 0 else ""
prof.activate(blocking_object=state, optional_postfix=optional_postfix)
@@ -320,50 +366,27 @@ def train_loop(config, elastic_manager, state=None):
if step == last_profiling_step or prof.should_deactivate_periodic_profile(step):
prof.deactivate(blocking_object=state)

elastic_manager.maybe_snapshot(
step=step,
snapshot_jax_arrays={
"params": state.params,
"opt_state": state.opt_state,
},
block=True,
)

ret = elastic_manager.maybe_reshard_up(
step=step,
snapshot_jax_arrays={
"params": state.params,
"opt_state": state.opt_state,
},
elastic_handler=elastic_handler,
handler_kwargs={
"config": config,
"elastic_manager": elastic_manager,
"checkpoint_manager": checkpoint_manager,
},
)
if ret is not None:
(
config,
step,
state,
mesh,
checkpoint_manager,
data_iterator,
p_train_step,
example_batch,
learning_rate_schedule,
metric_logger,
writer,
) = ret

if step == start_step:
max_utils.print_mem_stats("After params initialized")

step += 1

except jax.errors.JaxRuntimeError as error:
ret = elastic_manager.maybe_reshard_down(
(
config,
step,
state,
mesh,
checkpoint_manager,
data_iterator,
p_train_step,
example_batch,
learning_rate_schedule,
metric_logger,
writer,
input_data_shardings,
) = elastic_manager.maybe_reshard_down(
error=error,
elastic_handler=elastic_handler,
handler_kwargs={
@@ -372,20 +395,6 @@ def train_loop(config, elastic_manager, state=None):
"checkpoint_manager": checkpoint_manager,
},
)
if ret is not None:
(
config,
step,
state,
mesh,
checkpoint_manager,
data_iterator,
p_train_step,
example_batch,
learning_rate_schedule,
metric_logger,
writer,
) = ret

if checkpoint_manager is not None:
if (int(state.step) - 1) % config.checkpoint_period != 0:
@@ -422,16 +431,21 @@ def train_loop(config, elastic_manager, state=None):
return state


def wait_for_all_slices(elastic_manager: manager.Manager) -> None:
elastic_manager.good_slice_indices = elastic_manager.get_slice_availability()
while len(elastic_manager.good_slice_indices) < elastic_manager.total_slice_count:
def wait_for_all_slices(
elastic_manager: manager.Manager,
wait_period: int,
) -> set[int]:
good_slice_indices = elastic_manager.get_slice_availability()
while len(good_slice_indices) < elastic_manager.total_slice_count:
max_logging.log(
f"Only {elastic_manager.good_slice_count} slices out of {elastic_manager.total_slice_count} available. "
"Sleeping for 5 seconds."
f"Only {len(good_slice_indices)} slices out of {elastic_manager.total_slice_count} available. "
f"Sleeping for {wait_period} seconds."
)
time.sleep(5)
elastic_manager.good_slice_indices = elastic_manager.get_slice_availability()
time.sleep(wait_period)
good_slice_indices = elastic_manager.get_slice_availability()

max_logging.log("All slices are available")
return good_slice_indices


def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:
@@ -443,17 +457,11 @@ def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:
Returns:
The initialized elastic manager
"""
elastic_manager = manager.Manager(
devices,
reshard_check_period=1,
snapshot_period=5,
max_elastic_down_event_count=100,
max_reshard_retry_count=3,
)
elastic_manager = manager.Manager(devices)

# Do not start training until all slices are available
# TODO: b/408455557 - Migrate to pathwaysutils and make configurable
wait_for_all_slices(elastic_manager)
elastic_manager.good_slice_indices = wait_for_all_slices(elastic_manager, 30)

pyconfig.HyperParameters.global_batch_size_to_train_on = property(
lambda self: elastic_manager.scale_by_good_slices(self.get_keys()["global_batch_size_to_train_on"])
@@ -468,6 +476,14 @@ def elastic_initialize(devices: Sequence[jax.Device]) -> manager.Manager:

return elastic_manager

def elastic_configure(
config: pyconfig.HyperParameters,
elastic_manager: manager.Manager,
):
elastic_manager.reshard_check_period = config.elastic_reshard_check_period
elastic_manager.snapshot_period = config.elastic_snapshot_period
elastic_manager.max_elastic_down_event_count = config.elastic_max_elastic_down_event_count
elastic_manager.max_reshard_retry_count = config.elastic_max_reshard_retry_count

def main(argv: Sequence[str]) -> None:
pathwaysutils.initialize()
@@ -482,6 +498,9 @@ def main(argv: Sequence[str]) -> None:
elastic_manager = elastic_initialize(jax.devices())

config = pyconfig.initialize(argv)

elastic_configure(config, elastic_manager)

max_utils.print_system_information()
validate_train_config(config)
os.environ["TFDS_DATA_DIR"] = config.dataset_path or ""
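
Note the initialization order in main(): elastic_initialize must run before pyconfig.initialize because it patches HyperParameters.global_batch_size_to_train_on to scale by the number of good slices, while the new elastic_configure must run after it, because it reads the elastic_* values out of the parsed config. An editorial summary of that sequence (the comments are not part of the PR):

# Editorial summary of the startup ordering in main(); comments are editorial, not from the diff.
pathwaysutils.initialize()
elastic_manager = elastic_initialize(jax.devices())  # patches the global-batch-size property,
                                                     # so it has to precede pyconfig.initialize
config = pyconfig.initialize(argv)                   # parses elastic_* keys; missing ones get
                                                     # defaults via get_elastic_defaults (pyconfig.py)
elastic_configure(config, elastic_manager)           # copies the elastic_* values onto the manager
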
34 changes: 34 additions & 0 deletions MaxText/pyconfig.py
@@ -147,6 +147,37 @@ def validate_rope_type(rope_type: str) -> None:
raise ValueError(f"Invalid RoPE type was passed. Got: {rope_type}. Valid options: {valid_rope_types}")


def validate_elastic(
elastic_mode: str | None,
elastic_reshard_check_period: int | None,
):
modes = {
"replica-resize",
"fast-resume",
}

if elastic_mode not in modes:
raise ValueError(f"{elastic_mode=} must be in {modes}")

if elastic_mode == "fast-resume" and elastic_reshard_check_period not in {None, 1}:
raise ValueError(
"For {elastic_mode=}, {elastic_reshard_check_period=} must be None or 1"
)


def get_elastic_defaults(keys) -> dict[str, str | int | None]:
elastic_defaults = {
"elastic_mode": "replica-resize",
"elastic_reshard_check_period": 1,
"elastic_snapshot_period": 1,
"elastic_max_elastic_down_event_count": None,
"elastic_max_reshard_retry_count": None,
"elastic_wait_period": 30,
}

return {k: v for k, v in elastic_defaults.items() if k not in keys}


def validate_keys(keys):
validate_attention_kernel(keys["attention"])
validate_attention_type(keys["attention_type"])
@@ -160,6 +191,7 @@ def validate_keys(keys):
validate_model_call_mode(keys["model_call_mode"])
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
validate_rope_type(keys["rope_type"])
validate_elastic(keys["elastic_mode"], keys["elastic_reshard_check_period"])

assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
"enable_checkpointing"
@@ -592,6 +624,8 @@ def user_init(raw_keys):

raw_keys["decoder_block"] = DecoderBlockType(raw_keys["decoder_block"])

raw_keys |= get_elastic_defaults(raw_keys)

@staticmethod
def configure_gpt3_task(raw_keys):
"""dynamically configure gpt3 task based on training rules"""
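
Taken together, get_elastic_defaults back-fills any elastic_* keys a user config omits without overwriting values the user did set, and validate_elastic rejects inconsistent combinations, in particular a "fast-resume" run whose reshard-check period is not 1. A minimal, hypothetical usage sketch of the two helpers added above:

# Hypothetical standalone usage of the two helpers added to pyconfig.py in this PR.
keys = {"elastic_mode": "fast-resume"}  # the user sets only the mode
keys |= get_elastic_defaults(keys)      # fills elastic_reshard_check_period=1, elastic_snapshot_period=1, ...
                                        # without touching the user-supplied elastic_mode
validate_elastic(keys["elastic_mode"], keys["elastic_reshard_check_period"])  # passes: the period is 1

# Raises ValueError: "fast-resume" requires elastic_reshard_check_period of None or 1.
validate_elastic("fast-resume", 5)
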