
[wip][core] parallel loading of shards #12028

Draft · wants to merge 8 commits into main
2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file_model.py
@@ -62,7 +62,7 @@
if is_accelerate_available():
    from accelerate import dispatch_model, init_empty_weights

-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta


SINGLE_FILE_LOADABLE_CLASSES = {
2 changes: 1 addition & 1 deletion src/diffusers/loaders/single_file_utils.py
@@ -55,7 +55,7 @@
if is_accelerate_available():
    from accelerate import init_empty_weights

-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

3 changes: 2 additions & 1 deletion src/diffusers/loaders/transformer_flux.py
@@ -17,7 +17,8 @@
    ImageProjection,
    MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache

3 changes: 2 additions & 1 deletion src/diffusers/loaders/transformer_sd3.py
@@ -16,7 +16,8 @@

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache

3 changes: 2 additions & 1 deletion src/diffusers/loaders/unet.py
@@ -30,7 +30,8 @@
    IPAdapterPlusImageProjection,
    MultiIPAdapterImageProjection,
)
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
    USE_PEFT_BACKEND,
    _get_model_file,
125 changes: 125 additions & 0 deletions src/diffusers/models/model_loading_utils.py
@@ -20,6 +20,7 @@
import os
from array import array
from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -310,6 +311,130 @@ def load_model_dict_into_meta(
    return offload_index, state_dict_index


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved it here from modeling_utils.py.

"""
Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
parameters.

"""
if model_to_load.device.type == "meta":
return False

if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False

# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False

# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = next(iter(model_to_load.state_dict().keys()))
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

return False
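
Editorial aside: a hedged, runnable sketch of what the check above implies in practice. `ToyModel` and its `device` property are illustrative stand-ins (diffusers' `ModelMixin` exposes a comparable `device` property; a plain `nn.Module` does not), and the function is assumed to be in scope from the code above:

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):  # hypothetical stand-in for a diffusers model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def device(self):
        # check_support_param_buffer_assignment reads `.device`; ModelMixin
        # provides a similar property on real models.
        return next(self.parameters()).device


model = ToyModel()
fp32_sd = model.state_dict()
fp16_sd = {k: v.half() for k, v in fp32_sd.items()}

# Matching dtypes: the fast assignment path is allowed.
assert check_support_param_buffer_assignment(model, fp32_sd)
# Mismatched dtypes: fall back to the slower copy path.
assert not check_support_param_buffer_assignment(model, fp16_sd)
```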


def load_shard_file(args):
    (
        model,
        model_state_dict,
        shard_file,
        device_map,
        dtype,
        hf_quantizer,
        keep_in_fp32_modules,
        dduf_entries,
        loaded_keys,
        unexpected_keys,
        offload_index,
        offload_folder,
        state_dict_index,
        state_dict_folder,
        ignore_mismatched_sizes,
        low_cpu_mem_usage,
    ) = args
    assign_to_params_buffers = None
    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
    mismatched_keys = _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    )
    error_msgs = []
    if low_cpu_mem_usage:
        offload_index, state_dict_index = load_model_dict_into_meta(
            model,
            state_dict,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            unexpected_keys=unexpected_keys,
            offload_folder=offload_folder,
            offload_index=offload_index,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
        )
    else:
        if assign_to_params_buffers is None:
            assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
    return offload_index, state_dict_index, mismatched_keys, error_msgs
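
Editorial aside: the worker takes one packed tuple and returns a uniform 4-tuple, so the serial loop and the thread pool below can drive it identically and accumulate results the same way. A toy sketch of that contract, with a purely hypothetical worker:

```python
# Hypothetical worker mirroring load_shard_file's calling convention.
def toy_load_shard_file(args):
    (shard_name, strict) = args  # packed args, like the real worker
    mismatched_keys = []         # per-shard results
    error_msgs = []
    if strict and "corrupt" in shard_name:
        error_msgs.append(f"could not read {shard_name}")
    # (offload_index, state_dict_index, mismatched_keys, error_msgs)
    return None, None, mismatched_keys, error_msgs


args_list = [("shard-0", True), ("shard-1-corrupt", True)]
all_mismatched, all_errors = [], []
for args in args_list:  # executor.submit(toy_load_shard_file, args) works the same way
    _, _, _m, _e = toy_load_shard_file(args)
    all_mismatched += _m
    all_errors += _e
print(all_errors)  # ['could not read shard-1-corrupt']
```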


def load_shard_files_with_threadpool(args_list):
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

    # Do not spawn any more workers than needed
    num_workers = min(len(args_list), num_workers)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    mismatched_keys = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
            for future in as_completed(futures):
                result = future.result()
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys
                pbar.update(1)

    return offload_index, state_dict_index, mismatched_keys, error_msgs
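
Editorial aside: threads (rather than processes) are a sensible fit here because shard reads are dominated by file I/O and tensor copies that release the GIL; that is the expected mechanism, not something measured in this PR. A standalone sketch of the same fan-out/fan-in pattern with a stub worker and hypothetical shard names:

```python
import os
from concurrent.futures import ThreadPoolExecutor, as_completed


def load_one(shard_file):
    # Stand-in for load_shard_file: pretend to read the shard and
    # report how many tensors it held.
    return shard_file, 42


shards = [f"model-{i:05d}-of-00004.safetensors" for i in range(1, 5)]
num_workers = min(len(shards), int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")))

results = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(load_one, shard) for shard in shards]
    for future in as_completed(futures):  # completion order, not submission order
        results.append(future.result())

print(sorted(results))
```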


Member Author commented: Same. Moved it out of modeling_utils.py.

def _find_mismatched_keys(
    state_dict,
    model_state_dict,
    loaded_keys,
    ignore_mismatched_sizes,
):
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for checkpoint_key in loaded_keys:
            model_key = checkpoint_key
            # If the checkpoint is sharded, we may not have the key here.
            if checkpoint_key not in state_dict:
                continue

            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
                mismatched_keys.append(
                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                )
                del state_dict[checkpoint_key]
    return mismatched_keys
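
Editorial aside: a quick toy illustration of the contract, including the side effect that mismatched entries are deleted from the incoming `state_dict` so downstream loading never touches them (assumes the helper above is in scope):

```python
import torch

checkpoint_sd = {"proj.weight": torch.zeros(2, 3)}
model_sd = {"proj.weight": torch.zeros(3, 3)}

mismatched = _find_mismatched_keys(
    checkpoint_sd, model_sd, loaded_keys=["proj.weight"], ignore_mismatched_sizes=True
)
print(mismatched)      # [('proj.weight', torch.Size([2, 3]), torch.Size([3, 3]))]
print(checkpoint_sd)   # {} -- the mismatched key was dropped in place
```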


def _load_state_dict_into_model(
    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
105 changes: 46 additions & 59 deletions src/diffusers/models/modeling_utils.py
@@ -41,6 +41,7 @@
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
    CONFIG_NAME,
+    ENV_VARS_TRUE_VALUES,
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
@@ -69,9 +70,8 @@
    _expand_device_map,
    _fetch_index_file,
    _fetch_index_file_legacy,
-    _find_mismatched_keys,
-    _load_state_dict_into_model,
-    load_model_dict_into_meta,
+    load_shard_file,
+    load_shard_files_with_threadpool,
    load_state_dict,
)

@@ -208,34 +208,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
        return last_tuple[1].dtype


def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
"""
Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
parameters.

"""
if model_to_load.device.type == "meta":
return False

if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
return False

# Some models explicitly do not support param buffer assignment
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
logger.debug(
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
)
return False

# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
first_key = next(iter(model_to_load.state_dict().keys()))
if start_prefix + first_key in state_dict:
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

return False


@contextmanager
def no_init_weights():
"""
@@ -988,6 +960,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)

is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
if is_parallel_loading_enabled and not low_cpu_mem_usage:
raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")

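Editorial aside: `ENV_VARS_TRUE_VALUES` is the usual Hugging Face truthy set, assumed here to be `{"1", "ON", "YES", "TRUE"}`, so the gate above reduces to roughly:

```python
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed definition, per HF utils convention

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"
enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
print(enabled)  # True -- case-insensitive thanks to .upper()
```
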
        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
            logger.warning(
@@ -1323,6 +1299,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            dduf_entries=dduf_entries,
+            is_parallel_loading_enabled=is_parallel_loading_enabled,
        )
        loading_info = {
            "missing_keys": missing_keys,
@@ -1518,6 +1495,7 @@ def _load_pretrained_model(
        offload_state_dict: Optional[bool] = None,
        offload_folder: Optional[Union[str, os.PathLike]] = None,
        dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+        is_parallel_loading_enabled: Optional[bool] = False,
    ):
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
@@ -1531,6 +1509,9 @@
            for pat in cls._keys_to_ignore_on_load_unexpected:
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

+        mismatched_keys = []
+        error_msgs = []
+
        # Deal with offload
        if device_map is not None and "disk" in device_map.values():
            if offload_folder is None:
@@ -1566,37 +1547,43 @@
            # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
            resolved_model_file = [state_dict]

-        if len(resolved_model_file) > 1:
-            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
-        mismatched_keys = []
-        assign_to_params_buffers = None
-        error_msgs = []
+        # prepare the arguments.
+        args_list = [
+            (
+                model,
+                model_state_dict,
+                shard_file,
+                device_map,
+                dtype,
+                hf_quantizer,
+                keep_in_fp32_modules,
+                dduf_entries,
+                loaded_keys,
+                unexpected_keys,
+                offload_index,
+                offload_folder,
+                state_dict_index,
+                state_dict_folder,
+                ignore_mismatched_sizes,
+                low_cpu_mem_usage,
+            )
+            for shard_file in resolved_model_file
+        ]

-        for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-            mismatched_keys += _find_mismatched_keys(
-                state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
-            )
-
-            if low_cpu_mem_usage:
-                offload_index, state_dict_index = load_model_dict_into_meta(
-                    model,
-                    state_dict,
-                    device_map=device_map,
-                    dtype=dtype,
-                    hf_quantizer=hf_quantizer,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                    unexpected_keys=unexpected_keys,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_index=state_dict_index,
-                    state_dict_folder=state_dict_folder,
-                )
-            else:
-                if assign_to_params_buffers is None:
-                    assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-                error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+        if is_parallel_loading_enabled:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
+                args_list
+            )
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
+
+            for args in args_list:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys

        empty_device_cache()

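To close, a hedged end-to-end usage sketch of what this PR enables. The checkpoint id is only an example of a sharded diffusers model (a gated repo), and `low_cpu_mem_usage` must stay enabled per the guard added in `from_pretrained`:

```python
import os

# Opt in before loading; both variables are read by the code in this PR.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

import torch
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example sharded checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,  # required: parallel loading raises otherwise
)
```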