Merge pull request #1423 from RossM/dream

Implement DREAM from http://arxiv.org/abs/2312.00210
pull/1428/head
d8ahazard 2023-12-28 08:58:32 -06:00 committed by GitHub
commit 66538199cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 88 additions and 0 deletions

View File

@ -71,6 +71,8 @@ class DreamboothConfig(BaseModel):
lr_warmup_steps: int = 500
max_token_length: int = 75
min_snr_gamma: float = 0.0
use_dream: bool = False
dream_detail_preservation: float = 0.5
mixed_precision: str = "fp16"
model_dir: str = ""
model_name: str = ""

View File

@ -1603,6 +1603,42 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# DREAM rectification step; see http://arxiv.org/abs/2312.00210 (DREAM), algorithm 3.
#
# An extra gradient-free UNet evaluation estimates the noise, and a scaled share of
# the estimation error (delta_noise) is folded into BOTH the UNet input
# (noisy_latents) and the regression target, so the subsequent training forward
# pass sees a self-consistent (input, target) pair.
#
# Only runs when DREAM is enabled and the UNet's input channel count equals
# `channels` (defined outside this block).
if args.use_dream and unet.config.in_channels == channels:
    with torch.no_grad():
        # Per-sample cumulative alpha, indexed with trailing None axes so it
        # broadcasts against the latent tensors.
        alpha_prod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
        sqrt_alpha_prod = alpha_prod ** 0.5
        sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5

        # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
        # Here (1 - alpha) ** d == sqrt(1 - alpha) ** (2 * d), so a
        # dream_detail_preservation of 0.5 corresponds to the paper's p = 1.
        dream_lambda = (1 - alpha_prod) ** args.dream_detail_preservation

        # Gradient-free prediction used only to estimate the current noise error.
        if args.model_type == "SDXL":
            with accelerator.autocast():
                model_pred = unet(
                    noisy_latents, timesteps, batch["input_ids"],
                    added_cond_kwargs=batch["unet_added_conditions"]
                ).sample
        else:
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if noise_scheduler.config.prediction_type == "epsilon":
            # The model output is the noise estimate itself.
            predicted_noise = model_pred
            delta_noise = (noise - predicted_noise).detach()
            delta_noise.mul_(dream_lambda)
            # Shift the UNet input, not the clean latents: noisy_latents is the
            # tensor this block passes to the UNet above.
            # NOTE(review): assumes the training forward pass that follows also
            # consumes noisy_latents and does not re-derive it from latents -- confirm.
            noisy_latents.add_(sqrt_one_minus_alpha_prod * delta_noise)
            target.add_(delta_noise)
        elif noise_scheduler.config.prediction_type == "v_prediction":
            # With v = sqrt(a) * eps - sqrt(1 - a) * x0 and
            # x_t = sqrt(a) * x0 + sqrt(1 - a) * eps, the noise is recovered as
            # eps = sqrt(1 - a) * x_t + sqrt(a) * v  (note the plus sign).
            # NOTE(review): assumes the scheduler's velocity follows the standard
            # definition above -- confirm against noise_scheduler.get_velocity.
            predicted_noise = sqrt_one_minus_alpha_prod * noisy_latents + sqrt_alpha_prod * model_pred
            delta_noise = (noise - predicted_noise).detach()
            delta_noise.mul_(dream_lambda)
            noisy_latents.add_(sqrt_one_minus_alpha_prod * delta_noise)
            # Replacing eps by eps + delta shifts the velocity target by sqrt(a) * delta.
            target.add_(sqrt_alpha_prod * delta_noise)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # Drop the temporaries (including this block's model_pred) before the
        # training forward pass that follows.
        del alpha_prod, sqrt_alpha_prod, sqrt_one_minus_alpha_prod, dream_lambda, model_pred, predicted_noise, delta_noise
if args.model_type == "SDXL":
with accelerator.autocast():
model_pred = unet(

View File

@ -296,6 +296,18 @@
data-step="0.1" id="min_snr_gamma" data-value="0.0"
data-label="Min SNR Gamma"></div>
</div>
<!-- DREAM toggle: checkbox input with id/name "use_dream" and the label "Use DREAM". -->
<div class="form-group">
<div class="form-check form-switch">
<input class="dbInput form-check-input" type="checkbox"
id="use_dream" name="use_dream">
<label class="form-check-label" for="use_dream">Use DREAM</label>
</div>
</div>
<!-- DREAM detail preservation: db-slider element configured through data-* attributes
     (min 0, max 1.0, step 0.01, initial value 0.5). -->
<div class="form-group">
<div class="dbInput db-slider" data-min="0" data-max="1.0"
data-step="0.01" id="dream_detail_preservation" data-value="0.5"
data-label="DREAM detail preservation"></div>
</div>
<div class="form-group">
<div class="dbInput db-slider" data-min="75" data-max="300"
data-step="75" id="max_token_length" data-value="75"

View File

@ -328,6 +328,8 @@ let db_titles = {
"Use 8bit Adam": "Enable this to save VRAM.",
"Use CPU Only (SLOW)": "Guess what - this will be incredibly slow, but it will work for < 8GB GPUs.",
"Use Concepts List": "Train multiple concepts from a JSON file or string.",
"Use DREAM": "Enable DREAM (http://arxiv.org/abs/2312.00210). This may provide better results, but trains slower.",
"DREAM detail preservation": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM.",
"Use EMA": "Enabling this will provide better results and editability, but cost more VRAM.",
"Use EMA for prediction": "",
"Use EMA Weights for Inference": "Enabling this will save the EMA unet weights as the 'normal' model weights and ignore the regular unet weights.",

View File

@ -641,6 +641,17 @@ def on_ui_tabs():
step=0.1,
visible=True,
)
# Checkbox that switches DREAM on or off; unchecked by default.
db_use_dream = gr.Checkbox(value=False, label="Use DREAM")
# Slider for the DREAM detail preservation value: range 0..1 in 0.01 steps,
# starting at 0.5.
db_dream_detail_preservation = gr.Slider(
    minimum=0,
    maximum=1,
    value=0.5,
    step=0.01,
    label="DREAM detail preservation",
    visible=True,
)
db_pad_tokens = gr.Checkbox(
label="Pad Tokens", value=True
)
@ -1348,6 +1359,8 @@ def on_ui_tabs():
db_tenc_weight_decay,
db_tenc_grad_clip_norm,
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_pad_tokens,
db_strict_tokens,
db_max_token_length,
@ -1481,6 +1494,8 @@ def on_ui_tabs():
db_lr_warmup_steps,
db_max_token_length,
db_min_snr_gamma,
db_use_dream,
db_dream_detail_preservation,
db_mixed_precision,
db_model_name,
db_model_path,

View File

@ -152,6 +152,15 @@
"max": 10,
"step": 0.1
},
"use_dream": {
"value": false
},
"dream_detail_preservation": {
"value": 0.5,
"min": 0.0,
"max": 1,
"step": 0.01
},
"max_token_length": {
"value": 75,
"min": 75,

View File

@ -38,6 +38,8 @@
"lr_warmup_steps": 500,
"max_token_length": 75,
"min_snr_gamma": 0.0,
"use_dream": false,
"dream_detail_preservation": 0.5,
"mixed_precision": "fp16",
"noise_scheduler": "DDPM",
"num_train_epochs": 200,

View File

@ -188,6 +188,16 @@
"title": "Enable exponential moving average (EMA).",
"description": "Whether or not to use exponential moving average (EMA) for model weights. Using EMA can help improve the stability and consistency of the generated images."
},
"use_dream": {
"label": "Use DREAM",
"title": "Enable Diffusion Rectification and Estimation-Adaptive Models (DREAM).",
"description": "Whether or not to use DREAM. DREAM performs an additional model evaluation at each training step, which increases training time but can help improve the stability and consistency of the generated images."
},
"dream_detail_preservation": {
"label": "DREAM detail preservation",
"title": "Select how much detail DREAM preserves.",
"description": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM."
},
"train_unet": {
"label": "Train UNET",
"title": "Train UNET as an additional module.",