Skip to content

Add --lora_alpha and metadata handling for train_dreambooth_lora_hidream #11765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
d6e4a1a
update the test script.
ParagEkbote Jun 21, 2025
0dc1cdd
update the test case file.
ParagEkbote Jun 21, 2025
d52c1ca
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jun 22, 2025
85253ee
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
linoytsaban Jun 23, 2025
8d4ad65
make style.
ParagEkbote Jun 23, 2025
9e613a3
Merge branch 'Add-Lora-Alpha-For-HiDream-Lora' of https://github.com/…
ParagEkbote Jun 23, 2025
9a7f5c3
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jun 23, 2025
e7fc90c
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
linoytsaban Jun 23, 2025
b129892
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
linoytsaban Jun 24, 2025
dc2f277
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
linoytsaban Jun 30, 2025
ee2c220
remove not uswed param
ParagEkbote Jun 30, 2025
caf7dc9
Merge branch 'Add-Lora-Alpha-For-HiDream-Lora' of https://github.com/…
ParagEkbote Jun 30, 2025
2b21a85
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jun 30, 2025
dd3d25c
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jul 10, 2025
fc53122
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jul 11, 2025
6c8314d
Merge branch 'huggingface:main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jul 14, 2025
3b9dad8
fix test failures.
ParagEkbote Jul 14, 2025
6bafe8f
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
linoytsaban Jul 16, 2025
2ffd45d
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
ParagEkbote Jul 25, 2025
a66cf63
Merge branch 'main' into Add-Lora-Alpha-For-HiDream-Lora
linoytsaban Jul 31, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions examples/dreambooth/test_dreambooth_lora_hidream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
Expand All @@ -34,6 +37,7 @@

class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
Expand Down Expand Up @@ -175,6 +179,48 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
{"checkpoint-4", "checkpoint-6"},
)

def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))

# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}

metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)

loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)

def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
Expand All @@ -183,6 +229,7 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_m
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
Expand All @@ -203,6 +250,7 @@ def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_m
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
Expand Down
15 changes: 14 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_hidream.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
Expand Down Expand Up @@ -421,6 +422,13 @@ def parse_args(input_args=None):

parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

parser.add_argument(
"--lora_alpha",
type=int,
default=4,
help="LoRA alpha to be used for additional scaling.",
)

parser.add_argument(
"--with_prior_preservation",
default=False,
Expand Down Expand Up @@ -1164,7 +1172,7 @@ def main(args):
# now we will add new LoRA weights the transformer layers
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
Expand All @@ -1181,10 +1189,12 @@ def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None

modules_to_save = {}
for model in models:
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
model = unwrap_model(model)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")

Expand All @@ -1195,6 +1205,7 @@ def save_model_hook(models, weights, output_dir):
HiDreamImagePipeline.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
**_collate_lora_metadata(modules_to_save),
)

def load_model_hook(models, input_dir):
Expand Down Expand Up @@ -1488,6 +1499,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
modules_to_save = {}
tracker_name = "dreambooth-hidream-lora"
accelerator.init_trackers(tracker_name, config=vars(args))

Expand Down Expand Up @@ -1727,6 +1739,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer

HiDreamImagePipeline.save_lora_weights(
save_directory=args.output_dir,
Expand Down