Merge pull request #1423 from RossM/dream
Implement DREAM from http://arxiv.org/abs/2312.00210pull/1428/head
commit
66538199cd
|
|
@ -71,6 +71,8 @@ class DreamboothConfig(BaseModel):
|
|||
lr_warmup_steps: int = 500
|
||||
max_token_length: int = 75
|
||||
min_snr_gamma: float = 0.0
|
||||
use_dream: bool = False
|
||||
dream_detail_preservation: float = 0.5
|
||||
mixed_precision: str = "fp16"
|
||||
model_dir: str = ""
|
||||
model_name: str = ""
|
||||
|
|
|
|||
|
|
@ -1603,6 +1603,42 @@ def main(class_gen_method: str = "Native Diffusers", user: str = None) -> TrainR
|
|||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
# See http://arxiv.org/abs/2312.00210 (DREAM) algorithm 3
|
||||
if args.use_dream and unet.config.in_channels == channels:
|
||||
with torch.no_grad():
|
||||
alpha_prod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps,None,None,None]
|
||||
sqrt_alpha_prod = alpha_prod ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - alpha_prod) ** 0.5
|
||||
|
||||
# The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
|
||||
dream_lambda = (1 - alpha_prod) ** args.dream_detail_preservation
|
||||
|
||||
if args.model_type == "SDXL":
|
||||
with accelerator.autocast():
|
||||
model_pred = unet(
|
||||
noisy_latents, timesteps, batch["input_ids"],
|
||||
added_cond_kwargs=batch["unet_added_conditions"]
|
||||
).sample
|
||||
else:
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
predicted_noise = model_pred
|
||||
delta_noise = (noise - predicted_noise).detach()
|
||||
delta_noise.mul_(dream_lambda)
|
||||
latents.add_(sqrt_one_minus_alpha_prod * delta_noise)
|
||||
target.add_(delta_noise)
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
predicted_noise = sqrt_one_minus_alpha_prod * noisy_latents - sqrt_alpha_prod * model_pred
|
||||
delta_noise = (noise - predicted_noise).detach()
|
||||
delta_noise.mul_(dream_lambda)
|
||||
latents.add_(sqrt_one_minus_alpha_prod * delta_noise)
|
||||
target.add_(sqrt_alpha_prod * delta_noise)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
del alpha_prod, sqrt_alpha_prod, sqrt_one_minus_alpha_prod, dream_lambda, model_pred, predicted_noise, delta_noise
|
||||
|
||||
if args.model_type == "SDXL":
|
||||
with accelerator.autocast():
|
||||
model_pred = unet(
|
||||
|
|
|
|||
12
index.html
12
index.html
|
|
@ -296,6 +296,18 @@
|
|||
data-step="0.1" id="min_snr_gamma" data-value="0.0"
|
||||
data-label="Min SNR Gamma"></div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="form-check form-switch">
|
||||
<input class="dbInput form-check-input" type="checkbox"
|
||||
id="use_dream" name="use_dream">
|
||||
<label class="form-check-label" for="use_dream">Use DREAM</label>
|
||||
</div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="dbInput db-slider" data-min="0" data-max="1.0"
|
||||
data-step="0.01" id="dream_detail_preservation" data-value="0.5"
|
||||
data-label="DREAM detail preservation"></div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<div class="dbInput db-slider" data-min="75" data-max="300"
|
||||
data-step="75" id="max_token_length" data-value="75"
|
||||
|
|
|
|||
|
|
@ -328,6 +328,8 @@ let db_titles = {
|
|||
"Use 8bit Adam": "Enable this to save VRAM.",
|
||||
"Use CPU Only (SLOW)": "Guess what - this will be incredibly slow, but it will work for < 8GB GPUs.",
|
||||
"Use Concepts List": "Train multiple concepts from a JSON file or string.",
|
||||
"Use DREAM": "Enable DREAM (http://arxiv.org/abs/2312.00210). This may provide better results, but trains slower.",
|
||||
"DREAM detail preservation": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM.",
|
||||
"Use EMA": "Enabling this will provide better results and editability, but cost more VRAM.",
|
||||
"Use EMA for prediction": "",
|
||||
"Use EMA Weights for Inference": "Enabling this will save the EMA unet weights as the 'normal' model weights and ignore the regular unet weights.",
|
||||
|
|
|
|||
|
|
@ -641,6 +641,17 @@ def on_ui_tabs():
|
|||
step=0.1,
|
||||
visible=True,
|
||||
)
|
||||
db_use_dream = gr.Checkbox(
|
||||
label="Use DREAM", value=False
|
||||
)
|
||||
db_dream_detail_preservation = gr.Slider(
|
||||
label="DREAM detail preservation",
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
value=0.5,
|
||||
visible=True,
|
||||
)
|
||||
db_pad_tokens = gr.Checkbox(
|
||||
label="Pad Tokens", value=True
|
||||
)
|
||||
|
|
@ -1348,6 +1359,8 @@ def on_ui_tabs():
|
|||
db_tenc_weight_decay,
|
||||
db_tenc_grad_clip_norm,
|
||||
db_min_snr_gamma,
|
||||
db_use_dream,
|
||||
db_dream_detail_preservation,
|
||||
db_pad_tokens,
|
||||
db_strict_tokens,
|
||||
db_max_token_length,
|
||||
|
|
@ -1481,6 +1494,8 @@ def on_ui_tabs():
|
|||
db_lr_warmup_steps,
|
||||
db_max_token_length,
|
||||
db_min_snr_gamma,
|
||||
db_use_dream,
|
||||
db_dream_detail_preservation,
|
||||
db_mixed_precision,
|
||||
db_model_name,
|
||||
db_model_path,
|
||||
|
|
|
|||
|
|
@ -152,6 +152,15 @@
|
|||
"max": 10,
|
||||
"step": 0.1
|
||||
},
|
||||
"use_dream": {
|
||||
"value": false
|
||||
},
|
||||
"dream_detail_preservation": {
|
||||
"value": 0.5,
|
||||
"min": 0.0,
|
||||
"max": 1,
|
||||
"step": 0.01
|
||||
},
|
||||
"max_token_length": {
|
||||
"value": 75,
|
||||
"min": 75,
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@
|
|||
"lr_warmup_steps": 500,
|
||||
"max_token_length": 75,
|
||||
"min_snr_gamma": 0.0,
|
||||
"use_dream": false,
|
||||
"dream_detail_preservation": 0.5,
|
||||
"mixed_precision": "fp16",
|
||||
"noise_scheduler": "DDPM",
|
||||
"num_train_epochs": 200,
|
||||
|
|
|
|||
|
|
@ -188,6 +188,16 @@
|
|||
"title": "Enable exponential moving average (EMA).",
|
||||
"description": "Whether or not to use exponential moving average (EMA) for model weights. Using EMA can help improve the stability and consistency of the generated images.."
|
||||
},
|
||||
"use_dream": {
|
||||
"label": "Use DREAM",
|
||||
"title": "Enable Diffusion Rectification and Estimation-Adaptive Models (DREAM).",
|
||||
"description": "Whether or not to use DREAM. DREAM performs an additional model evaluation at each training step, which increases training time but can help improve the stability and consistency of the generated images."
|
||||
},
|
||||
"dream_detail_preservation": {
|
||||
"label": "DREAM detail preservation",
|
||||
"title": "Select how much detail DREAM preserves.",
|
||||
"description": "A factor that influences how DREAM trades off composition versus detail. Low values will improve composition but may result in loss of detail. High values preserve detail but may reduce the overall effect of DREAM."
|
||||
},
|
||||
"train_unet": {
|
||||
"label": "Train UNET",
|
||||
"title": "Train UNET as an additional module.",
|
||||
|
|
|
|||
Loading…
Reference in New Issue