Fix SDXL extraction

401c60e
d8ahazard 2023-08-25 09:56:31 -05:00
parent 60603eebb3
commit 18452eb5fb
4 changed files with 326 additions and 15 deletions

View File

@ -0,0 +1,98 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first entry is unused (due to attention_resolutions starting at 2); 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -0,0 +1,99 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
cond_stage_trainable: true # Note: different from the one we trained before
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first entry is unused (due to attention_resolutions starting at 2); 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -0,0 +1,100 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.13025
disable_first_stage_autocast: True
cond_stage_trainable: true # Note: different from the one we trained before
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
params:
num_idx: 1000
weighting_config:
target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
network_config:
target: sgm.modules.diffusionmodules.openaimodel.UNetModel
params:
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: True
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2]
num_res_blocks: 2
channel_mult: [1, 2, 4]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [1, 2, 10] # note: the first entry is unused (due to attention_resolutions starting at 2); 32, 16, 8 --> 64, 32, 16
context_dim: 2048
spatial_transformer_attn_type: softmax-xformers
legacy: False
unfreeze_model: True
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
params:
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
params:
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256 # multiplied by two
first_stage_config:
target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity

View File

@ -23,6 +23,7 @@ import traceback
from typing import Union
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
from dreambooth import shared
@ -47,9 +48,10 @@ def get_config_path(
model_version: str = "v1",
train_type: str = "default",
config_base_name: str = "training",
prediction_type: str = "epsilon"
prediction_type: str = ""
):
train_type = f"{train_type}" if not prediction_type == "v_prediction" else f"{train_type}-v"
if prediction_type != "":
train_type = f"{train_type}-{prediction_type}"
return os.path.join(
os.path.dirname(os.path.realpath(__file__)),
@ -59,24 +61,29 @@ def get_config_path(
)
def get_config_file(train_unfrozen=False, v2=False, prediction_type="epsilon"):
def get_config_file(train_unfrozen=False, model_type: str="v1x"):
config_base_name = "training"
model_versions = {
"v1": "v1",
"v2": "v2"
"v1x": "v1",
"v2x-512": "v2",
"v2x": "v2",
"SDXL": "SDXL",
}
model_pred_string = {
"v1x": "",
"v2x-512": "",
"v2x": "v",
"SDXL": "",
}
train_types = {
"default": "default",
"unfrozen": "unfrozen",
}
model_train_type = train_types["default"]
model_version_name = f"{model_versions['v1'] if not v2 else model_versions['v2']}"
if train_unfrozen:
model_train_type = train_types["unfrozen"]
model_train_type = train_types["default"] if not train_unfrozen else train_types["unfrozen"]
model_version_name = model_versions[model_type]
prediction_type = model_pred_string[model_type]
return get_config_path(model_version_name, model_train_type, config_base_name, prediction_type)
@ -129,9 +136,13 @@ def extract_checkpoint(
# modules.shared.status.update(status)
disable_safe_unpickle()
if image_size is None:
image_size = 512 if is_512 else 768
image_size = 512
if model_type == "v2x":
image_size = 768
if model_type == "SDXL":
image_size = 1024
to_safetensors = True
to_safetensors = False
if pipeline_class_name is not None:
library = importlib.import_module("diffusers")
class_obj = getattr(library, pipeline_class_name)
@ -140,7 +151,7 @@ def extract_checkpoint(
pipeline_class = None
if original_config_file is None:
original_config_file = get_config_file(train_unfrozen, v2=is_512 == False, prediction_type=prediction_type)
original_config_file = get_config_file(train_unfrozen, model_type)
print(f"Extracting config from {original_config_file}")
checkpoint_file = os.path.join(shared.models_path, checkpoint_file)
print(f"Extracting checkpoint from {checkpoint_file}")
@ -153,6 +164,8 @@ def extract_checkpoint(
db_config.model_type = model_type
db_config.resolution = image_size
db_config.save()
if model_type == "SDXL":
pipeline_class = StableDiffusionXLPipeline
try:
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path=checkpoint_file,
@ -179,12 +192,13 @@ def extract_checkpoint(
dump_path = db_config.get_pretrained_model_name_or_path()
if controlnet:
print("Saving controlnet model")
# only save the controlnet model
pipe.controlnet.save_pretrained(dump_path, safe_serialization=to_safetensors)
else:
try:
tmp_path = f"{dump_path}_tmp"
pipe.save_pretrained(dump_path, safe_serialization=True)
pipe.save_pretrained(dump_path, safe_serialization=False)
except:
print("Couldn't save the pipe")
traceback.print_exc()