mirror of https://github.com/vladmandic/automatic
Initial support for Chroma
parent
26800a1ef9
commit
4b3ce06916
|
|
@ -60,6 +60,8 @@ def guess_dct(dct: dict):
|
|||
if has(dct, 'model.diffusion_model.joint_blocks') and len(list(has(dct, 'model.diffusion_model.joint_blocks'))) == 38:
|
||||
return 'sd35-large'
|
||||
if has(dct, 'model.diffusion_model.double_blocks') and len(list(has(dct, 'model.diffusion_model.double_blocks'))) == 19:
|
||||
if has(dct, 'model.diffusion_model.distilled_guidance_layer'):
|
||||
return 'chroma'
|
||||
return 'flux-dev'
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"_class_name": "ChromaPipeline",
|
||||
"_diffusers_version": "0.34.0.dev0",
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"FlowMatchEulerDiscreteScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"T5Tokenizer"
|
||||
],
|
||||
"transformer": [
|
||||
"diffusers",
|
||||
"ChromaTransformer2DModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
{
|
||||
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
||||
"_diffusers_version": "0.34.0.dev0",
|
||||
"base_image_seq_len": 256,
|
||||
"base_shift": 0.5,
|
||||
"invert_sigmas": false,
|
||||
"max_image_seq_len": 4096,
|
||||
"max_shift": 1.15,
|
||||
"num_train_timesteps": 1000,
|
||||
"shift": 3.0,
|
||||
"shift_terminal": null,
|
||||
"stochastic_sampling": false,
|
||||
"time_shift_type": "exponential",
|
||||
"use_beta_sigmas": false,
|
||||
"use_dynamic_shifting": true,
|
||||
"use_exponential_sigmas": false,
|
||||
"use_karras_sigmas": false
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
{
|
||||
"_name_or_path": "google/t5-v1_1-xxl",
|
||||
"architectures": [
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"classifier_dropout": 0.0,
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "gelu_new",
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"is_gated_act": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.52.4",
|
||||
"use_cache": true,
|
||||
"vocab_size": 32128
|
||||
}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
{
|
||||
"<extra_id_0>": 32099,
|
||||
"<extra_id_10>": 32089,
|
||||
"<extra_id_11>": 32088,
|
||||
"<extra_id_12>": 32087,
|
||||
"<extra_id_13>": 32086,
|
||||
"<extra_id_14>": 32085,
|
||||
"<extra_id_15>": 32084,
|
||||
"<extra_id_16>": 32083,
|
||||
"<extra_id_17>": 32082,
|
||||
"<extra_id_18>": 32081,
|
||||
"<extra_id_19>": 32080,
|
||||
"<extra_id_1>": 32098,
|
||||
"<extra_id_20>": 32079,
|
||||
"<extra_id_21>": 32078,
|
||||
"<extra_id_22>": 32077,
|
||||
"<extra_id_23>": 32076,
|
||||
"<extra_id_24>": 32075,
|
||||
"<extra_id_25>": 32074,
|
||||
"<extra_id_26>": 32073,
|
||||
"<extra_id_27>": 32072,
|
||||
"<extra_id_28>": 32071,
|
||||
"<extra_id_29>": 32070,
|
||||
"<extra_id_2>": 32097,
|
||||
"<extra_id_30>": 32069,
|
||||
"<extra_id_31>": 32068,
|
||||
"<extra_id_32>": 32067,
|
||||
"<extra_id_33>": 32066,
|
||||
"<extra_id_34>": 32065,
|
||||
"<extra_id_35>": 32064,
|
||||
"<extra_id_36>": 32063,
|
||||
"<extra_id_37>": 32062,
|
||||
"<extra_id_38>": 32061,
|
||||
"<extra_id_39>": 32060,
|
||||
"<extra_id_3>": 32096,
|
||||
"<extra_id_40>": 32059,
|
||||
"<extra_id_41>": 32058,
|
||||
"<extra_id_42>": 32057,
|
||||
"<extra_id_43>": 32056,
|
||||
"<extra_id_44>": 32055,
|
||||
"<extra_id_45>": 32054,
|
||||
"<extra_id_46>": 32053,
|
||||
"<extra_id_47>": 32052,
|
||||
"<extra_id_48>": 32051,
|
||||
"<extra_id_49>": 32050,
|
||||
"<extra_id_4>": 32095,
|
||||
"<extra_id_50>": 32049,
|
||||
"<extra_id_51>": 32048,
|
||||
"<extra_id_52>": 32047,
|
||||
"<extra_id_53>": 32046,
|
||||
"<extra_id_54>": 32045,
|
||||
"<extra_id_55>": 32044,
|
||||
"<extra_id_56>": 32043,
|
||||
"<extra_id_57>": 32042,
|
||||
"<extra_id_58>": 32041,
|
||||
"<extra_id_59>": 32040,
|
||||
"<extra_id_5>": 32094,
|
||||
"<extra_id_60>": 32039,
|
||||
"<extra_id_61>": 32038,
|
||||
"<extra_id_62>": 32037,
|
||||
"<extra_id_63>": 32036,
|
||||
"<extra_id_64>": 32035,
|
||||
"<extra_id_65>": 32034,
|
||||
"<extra_id_66>": 32033,
|
||||
"<extra_id_67>": 32032,
|
||||
"<extra_id_68>": 32031,
|
||||
"<extra_id_69>": 32030,
|
||||
"<extra_id_6>": 32093,
|
||||
"<extra_id_70>": 32029,
|
||||
"<extra_id_71>": 32028,
|
||||
"<extra_id_72>": 32027,
|
||||
"<extra_id_73>": 32026,
|
||||
"<extra_id_74>": 32025,
|
||||
"<extra_id_75>": 32024,
|
||||
"<extra_id_76>": 32023,
|
||||
"<extra_id_77>": 32022,
|
||||
"<extra_id_78>": 32021,
|
||||
"<extra_id_79>": 32020,
|
||||
"<extra_id_7>": 32092,
|
||||
"<extra_id_80>": 32019,
|
||||
"<extra_id_81>": 32018,
|
||||
"<extra_id_82>": 32017,
|
||||
"<extra_id_83>": 32016,
|
||||
"<extra_id_84>": 32015,
|
||||
"<extra_id_85>": 32014,
|
||||
"<extra_id_86>": 32013,
|
||||
"<extra_id_87>": 32012,
|
||||
"<extra_id_88>": 32011,
|
||||
"<extra_id_89>": 32010,
|
||||
"<extra_id_8>": 32091,
|
||||
"<extra_id_90>": 32009,
|
||||
"<extra_id_91>": 32008,
|
||||
"<extra_id_92>": 32007,
|
||||
"<extra_id_93>": 32006,
|
||||
"<extra_id_94>": 32005,
|
||||
"<extra_id_95>": 32004,
|
||||
"<extra_id_96>": 32003,
|
||||
"<extra_id_97>": 32002,
|
||||
"<extra_id_98>": 32001,
|
||||
"<extra_id_99>": 32000,
|
||||
"<extra_id_9>": 32090
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
{
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
Binary file not shown.
|
|
@ -0,0 +1,940 @@
|
|||
{
|
||||
"add_prefix_space": true,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32000": {
|
||||
"content": "<extra_id_99>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32001": {
|
||||
"content": "<extra_id_98>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32002": {
|
||||
"content": "<extra_id_97>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32003": {
|
||||
"content": "<extra_id_96>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32004": {
|
||||
"content": "<extra_id_95>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32005": {
|
||||
"content": "<extra_id_94>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32006": {
|
||||
"content": "<extra_id_93>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32007": {
|
||||
"content": "<extra_id_92>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32008": {
|
||||
"content": "<extra_id_91>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32009": {
|
||||
"content": "<extra_id_90>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32010": {
|
||||
"content": "<extra_id_89>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32011": {
|
||||
"content": "<extra_id_88>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32012": {
|
||||
"content": "<extra_id_87>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32013": {
|
||||
"content": "<extra_id_86>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32014": {
|
||||
"content": "<extra_id_85>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32015": {
|
||||
"content": "<extra_id_84>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32016": {
|
||||
"content": "<extra_id_83>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32017": {
|
||||
"content": "<extra_id_82>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32018": {
|
||||
"content": "<extra_id_81>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32019": {
|
||||
"content": "<extra_id_80>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32020": {
|
||||
"content": "<extra_id_79>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32021": {
|
||||
"content": "<extra_id_78>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32022": {
|
||||
"content": "<extra_id_77>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32023": {
|
||||
"content": "<extra_id_76>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32024": {
|
||||
"content": "<extra_id_75>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32025": {
|
||||
"content": "<extra_id_74>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32026": {
|
||||
"content": "<extra_id_73>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32027": {
|
||||
"content": "<extra_id_72>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32028": {
|
||||
"content": "<extra_id_71>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32029": {
|
||||
"content": "<extra_id_70>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32030": {
|
||||
"content": "<extra_id_69>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32031": {
|
||||
"content": "<extra_id_68>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32032": {
|
||||
"content": "<extra_id_67>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32033": {
|
||||
"content": "<extra_id_66>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32034": {
|
||||
"content": "<extra_id_65>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32035": {
|
||||
"content": "<extra_id_64>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32036": {
|
||||
"content": "<extra_id_63>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32037": {
|
||||
"content": "<extra_id_62>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32038": {
|
||||
"content": "<extra_id_61>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32039": {
|
||||
"content": "<extra_id_60>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32040": {
|
||||
"content": "<extra_id_59>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32041": {
|
||||
"content": "<extra_id_58>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32042": {
|
||||
"content": "<extra_id_57>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32043": {
|
||||
"content": "<extra_id_56>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32044": {
|
||||
"content": "<extra_id_55>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32045": {
|
||||
"content": "<extra_id_54>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32046": {
|
||||
"content": "<extra_id_53>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32047": {
|
||||
"content": "<extra_id_52>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32048": {
|
||||
"content": "<extra_id_51>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32049": {
|
||||
"content": "<extra_id_50>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32050": {
|
||||
"content": "<extra_id_49>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32051": {
|
||||
"content": "<extra_id_48>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32052": {
|
||||
"content": "<extra_id_47>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32053": {
|
||||
"content": "<extra_id_46>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32054": {
|
||||
"content": "<extra_id_45>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32055": {
|
||||
"content": "<extra_id_44>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32056": {
|
||||
"content": "<extra_id_43>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32057": {
|
||||
"content": "<extra_id_42>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32058": {
|
||||
"content": "<extra_id_41>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32059": {
|
||||
"content": "<extra_id_40>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32060": {
|
||||
"content": "<extra_id_39>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32061": {
|
||||
"content": "<extra_id_38>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32062": {
|
||||
"content": "<extra_id_37>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32063": {
|
||||
"content": "<extra_id_36>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32064": {
|
||||
"content": "<extra_id_35>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32065": {
|
||||
"content": "<extra_id_34>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32066": {
|
||||
"content": "<extra_id_33>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32067": {
|
||||
"content": "<extra_id_32>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32068": {
|
||||
"content": "<extra_id_31>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32069": {
|
||||
"content": "<extra_id_30>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32070": {
|
||||
"content": "<extra_id_29>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32071": {
|
||||
"content": "<extra_id_28>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32072": {
|
||||
"content": "<extra_id_27>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32073": {
|
||||
"content": "<extra_id_26>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32074": {
|
||||
"content": "<extra_id_25>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32075": {
|
||||
"content": "<extra_id_24>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32076": {
|
||||
"content": "<extra_id_23>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32077": {
|
||||
"content": "<extra_id_22>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32078": {
|
||||
"content": "<extra_id_21>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32079": {
|
||||
"content": "<extra_id_20>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32080": {
|
||||
"content": "<extra_id_19>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32081": {
|
||||
"content": "<extra_id_18>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32082": {
|
||||
"content": "<extra_id_17>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32083": {
|
||||
"content": "<extra_id_16>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32084": {
|
||||
"content": "<extra_id_15>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32085": {
|
||||
"content": "<extra_id_14>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32086": {
|
||||
"content": "<extra_id_13>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32087": {
|
||||
"content": "<extra_id_12>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32088": {
|
||||
"content": "<extra_id_11>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32089": {
|
||||
"content": "<extra_id_10>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32090": {
|
||||
"content": "<extra_id_9>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32091": {
|
||||
"content": "<extra_id_8>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32092": {
|
||||
"content": "<extra_id_7>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32093": {
|
||||
"content": "<extra_id_6>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32094": {
|
||||
"content": "<extra_id_5>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32095": {
|
||||
"content": "<extra_id_4>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32096": {
|
||||
"content": "<extra_id_3>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32097": {
|
||||
"content": "<extra_id_2>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32098": {
|
||||
"content": "<extra_id_1>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32099": {
|
||||
"content": "<extra_id_0>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
"<extra_id_0>",
|
||||
"<extra_id_1>",
|
||||
"<extra_id_2>",
|
||||
"<extra_id_3>",
|
||||
"<extra_id_4>",
|
||||
"<extra_id_5>",
|
||||
"<extra_id_6>",
|
||||
"<extra_id_7>",
|
||||
"<extra_id_8>",
|
||||
"<extra_id_9>",
|
||||
"<extra_id_10>",
|
||||
"<extra_id_11>",
|
||||
"<extra_id_12>",
|
||||
"<extra_id_13>",
|
||||
"<extra_id_14>",
|
||||
"<extra_id_15>",
|
||||
"<extra_id_16>",
|
||||
"<extra_id_17>",
|
||||
"<extra_id_18>",
|
||||
"<extra_id_19>",
|
||||
"<extra_id_20>",
|
||||
"<extra_id_21>",
|
||||
"<extra_id_22>",
|
||||
"<extra_id_23>",
|
||||
"<extra_id_24>",
|
||||
"<extra_id_25>",
|
||||
"<extra_id_26>",
|
||||
"<extra_id_27>",
|
||||
"<extra_id_28>",
|
||||
"<extra_id_29>",
|
||||
"<extra_id_30>",
|
||||
"<extra_id_31>",
|
||||
"<extra_id_32>",
|
||||
"<extra_id_33>",
|
||||
"<extra_id_34>",
|
||||
"<extra_id_35>",
|
||||
"<extra_id_36>",
|
||||
"<extra_id_37>",
|
||||
"<extra_id_38>",
|
||||
"<extra_id_39>",
|
||||
"<extra_id_40>",
|
||||
"<extra_id_41>",
|
||||
"<extra_id_42>",
|
||||
"<extra_id_43>",
|
||||
"<extra_id_44>",
|
||||
"<extra_id_45>",
|
||||
"<extra_id_46>",
|
||||
"<extra_id_47>",
|
||||
"<extra_id_48>",
|
||||
"<extra_id_49>",
|
||||
"<extra_id_50>",
|
||||
"<extra_id_51>",
|
||||
"<extra_id_52>",
|
||||
"<extra_id_53>",
|
||||
"<extra_id_54>",
|
||||
"<extra_id_55>",
|
||||
"<extra_id_56>",
|
||||
"<extra_id_57>",
|
||||
"<extra_id_58>",
|
||||
"<extra_id_59>",
|
||||
"<extra_id_60>",
|
||||
"<extra_id_61>",
|
||||
"<extra_id_62>",
|
||||
"<extra_id_63>",
|
||||
"<extra_id_64>",
|
||||
"<extra_id_65>",
|
||||
"<extra_id_66>",
|
||||
"<extra_id_67>",
|
||||
"<extra_id_68>",
|
||||
"<extra_id_69>",
|
||||
"<extra_id_70>",
|
||||
"<extra_id_71>",
|
||||
"<extra_id_72>",
|
||||
"<extra_id_73>",
|
||||
"<extra_id_74>",
|
||||
"<extra_id_75>",
|
||||
"<extra_id_76>",
|
||||
"<extra_id_77>",
|
||||
"<extra_id_78>",
|
||||
"<extra_id_79>",
|
||||
"<extra_id_80>",
|
||||
"<extra_id_81>",
|
||||
"<extra_id_82>",
|
||||
"<extra_id_83>",
|
||||
"<extra_id_84>",
|
||||
"<extra_id_85>",
|
||||
"<extra_id_86>",
|
||||
"<extra_id_87>",
|
||||
"<extra_id_88>",
|
||||
"<extra_id_89>",
|
||||
"<extra_id_90>",
|
||||
"<extra_id_91>",
|
||||
"<extra_id_92>",
|
||||
"<extra_id_93>",
|
||||
"<extra_id_94>",
|
||||
"<extra_id_95>",
|
||||
"<extra_id_96>",
|
||||
"<extra_id_97>",
|
||||
"<extra_id_98>",
|
||||
"<extra_id_99>"
|
||||
],
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"eos_token": "</s>",
|
||||
"extra_ids": 100,
|
||||
"legacy": true,
|
||||
"model_max_length": 512,
|
||||
"pad_token": "<pad>",
|
||||
"sp_model_kwargs": {},
|
||||
"tokenizer_class": "T5Tokenizer",
|
||||
"unk_token": "<unk>"
|
||||
}
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
{
|
||||
"_class_name": "ChromaTransformer2DModel",
|
||||
"_diffusers_version": "0.34.0.dev0",
|
||||
"approximator_hidden_dim": 5120,
|
||||
"approximator_in_factor": 16,
|
||||
"approximator_layers": 5,
|
||||
"attention_head_dim": 128,
|
||||
"axes_dims_rope": [
|
||||
16,
|
||||
56,
|
||||
56
|
||||
],
|
||||
"in_channels": 64,
|
||||
"joint_attention_dim": 4096,
|
||||
"num_attention_heads": 24,
|
||||
"num_layers": 19,
|
||||
"num_single_layers": 38,
|
||||
"out_channels": null,
|
||||
"patch_size": 1
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,37 @@
|
|||
{
|
||||
"_class_name": "AutoencoderKL",
|
||||
"_diffusers_version": "0.34.0.dev0",
|
||||
"act_fn": "silu",
|
||||
"block_out_channels": [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"down_block_types": [
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D"
|
||||
],
|
||||
"force_upcast": true,
|
||||
"in_channels": 3,
|
||||
"latent_channels": 16,
|
||||
"latents_mean": null,
|
||||
"latents_std": null,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_add_attention": true,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 3,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 0.3611,
|
||||
"shift_factor": 0.1159,
|
||||
"up_block_types": [
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D"
|
||||
],
|
||||
"use_post_quant_conv": false,
|
||||
"use_quant_conv": false
|
||||
}
|
||||
|
|
@ -107,6 +107,9 @@ def make_meta(fn, maxrank, rank_ratio):
|
|||
elif shared.sd_model_type == "f1":
|
||||
meta["model_spec.architecture"] = "flux-1-dev/lora"
|
||||
meta["ss_base_model_version"] = "flux1"
|
||||
elif shared.sd_model_type == "chroma":
|
||||
meta["model_spec.architecture"] = "chroma/lora"
|
||||
meta["ss_base_model_version"] = "chroma"
|
||||
elif shared.sd_model_type == "sc":
|
||||
meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
|
||||
return meta
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ def load_network(name, network_on_disk) -> network.Network:
|
|||
net = network.Network(name, network_on_disk)
|
||||
net.mtime = os.path.getmtime(network_on_disk.filename)
|
||||
sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
|
||||
if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
|
||||
if shared.sd_model_type in ['f1', 'chroma']: # if kohya flux lora, convert state_dict
|
||||
sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
|
||||
assign_network_names_to_compvis_modules(shared.sd_model)
|
||||
keys_failed_to_match = {}
|
||||
|
|
|
|||
|
|
@ -546,7 +546,7 @@ def check_diffusers():
|
|||
t_start = time.time()
|
||||
if args.skip_all or args.skip_git or args.experimental:
|
||||
return
|
||||
sha = '8adc6003ba4dbf5b61bb4f1ce571e9e55e145a99' # diffusers commit hash
|
||||
sha = '900653c814cfa431877ae46548fe496506bda2ad' # diffusers commit hash
|
||||
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
|
||||
minor = int(pkg.version.split('.')[1] if pkg is not None else 0)
|
||||
cur = opts.get('diffusers_version', '') if minor > 0 else ''
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ def map_model_name(name: str):
|
|||
return 'xl'
|
||||
if name == 'sd3':
|
||||
return 'v3'
|
||||
if name == 'f1':
|
||||
if name in ['f1', 'chroma']:
|
||||
return 'fx'
|
||||
return name
|
||||
|
||||
|
|
|
|||
|
|
@ -107,6 +107,9 @@ def make_meta(fn, maxrank, rank_ratio):
|
|||
elif shared.sd_model_type == "f1":
|
||||
meta["model_spec.architecture"] = "flux-1-dev/lora"
|
||||
meta["ss_base_model_version"] = "flux1"
|
||||
elif shared.sd_model_type == "chroma":
|
||||
meta["model_spec.architecture"] = "chroma/lora"
|
||||
meta["ss_base_model_version"] = "chroma"
|
||||
elif shared.sd_model_type == "sc":
|
||||
meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
|
||||
return meta
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ def load_safetensors(name, network_on_disk) -> Union[network.Network, None]:
|
|||
net = network.Network(name, network_on_disk)
|
||||
net.mtime = os.path.getmtime(network_on_disk.filename)
|
||||
state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
|
||||
if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
|
||||
if shared.sd_model_type in ['f1', 'chroma']: # if kohya flux lora, convert state_dict
|
||||
state_dict = lora_convert._convert_kohya_flux_lora_to_diffusers(state_dict) or state_dict # pylint: disable=protected-access
|
||||
if shared.sd_model_type == 'sd3': # if kohya flux lora, convert state_dict
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class SdVersion(enum.Enum):
|
|||
SC = 5
|
||||
F1 = 6
|
||||
HV = 7
|
||||
CHROMA = 8
|
||||
|
||||
|
||||
class NetworkOnDisk:
|
||||
|
|
@ -59,6 +60,8 @@ class NetworkOnDisk:
|
|||
return 'f1'
|
||||
if base.startswith("hunyuan_video"):
|
||||
return 'hv'
|
||||
if base.startswith("chroma"):
|
||||
return 'chroma'
|
||||
|
||||
if arch.startswith("stable-diffusion-v1"):
|
||||
return 'sd1'
|
||||
|
|
@ -70,6 +73,8 @@ class NetworkOnDisk:
|
|||
return 'f1'
|
||||
if arch.startswith("hunyuan-video"):
|
||||
return 'hv'
|
||||
if arch.startswith("chroma"):
|
||||
return 'chroma'
|
||||
|
||||
if "v1-5" in str(self.metadata.get('ss_sd_model_name', "")):
|
||||
return 'sd1'
|
||||
|
|
@ -79,6 +84,8 @@ class NetworkOnDisk:
|
|||
return 'f1'
|
||||
if 'xl' in self.name.lower():
|
||||
return 'xl'
|
||||
if 'chroma' in self.name.lower():
|
||||
return 'chroma'
|
||||
|
||||
return ''
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,339 @@
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
import diffusers
|
||||
import transformers
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download, auth_check
|
||||
from modules import shared, errors, devices, modelloader, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
|
||||
|
||||
|
||||
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
|
||||
|
||||
|
||||
def load_chroma_quanto(checkpoint_info):
|
||||
transformer, text_encoder = None, None
|
||||
quanto = model_quant.load_quanto('Load model: type=Chroma')
|
||||
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
|
||||
try:
|
||||
quantization_map = os.path.join(repo_path, "transformer", "quantization_map.json")
|
||||
debug(f'Load model: type=Chroma quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="transformer"')
|
||||
if not os.path.exists(quantization_map):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info.name)
|
||||
quantization_map = hf_hub_download(repo_id, subfolder='transformer', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
|
||||
with open(quantization_map, "r", encoding='utf8') as f:
|
||||
quantization_map = json.load(f)
|
||||
state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
|
||||
dtype = state_dict['context_embedder.bias'].dtype
|
||||
with torch.device("meta"):
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
|
||||
quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
|
||||
if shared.opts.diffusers_eval:
|
||||
transformer.eval()
|
||||
transformer_dtype = transformer.dtype
|
||||
if transformer_dtype != devices.dtype:
|
||||
try:
|
||||
transformer = transformer.to(dtype=devices.dtype)
|
||||
except Exception:
|
||||
shared.log.error(f"Load model: type=Chroma Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load Quanto transformer: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'Chroma Quanto:')
|
||||
|
||||
try:
|
||||
quantization_map = os.path.join(repo_path, "text_encoder", "quantization_map.json")
|
||||
debug(f'Load model: type=Chroma quantization map="{quantization_map}" repo="{checkpoint_info.name}" component="text_encoder"')
|
||||
if not os.path.exists(quantization_map):
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info.name)
|
||||
quantization_map = hf_hub_download(repo_id, subfolder='text_encoder', filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
|
||||
with open(quantization_map, "r", encoding='utf8') as f:
|
||||
quantization_map = json.load(f)
|
||||
with open(os.path.join(repo_path, "text_encoder", "config.json"), encoding='utf8') as f:
|
||||
t5_config = transformers.T5Config(**json.load(f))
|
||||
state_dict = load_file(os.path.join(repo_path, "text_encoder", "model.safetensors"))
|
||||
dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype
|
||||
with torch.device("meta"):
|
||||
text_encoder = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
|
||||
quanto.requantize(text_encoder, state_dict, quantization_map, device=torch.device("cpu"))
|
||||
if shared.opts.diffusers_eval:
|
||||
text_encoder.eval()
|
||||
text_encoder_dtype = text_encoder.dtype
|
||||
if text_encoder_dtype != devices.dtype:
|
||||
try:
|
||||
text_encoder = text_encoder.to(dtype=devices.dtype)
|
||||
except Exception:
|
||||
shared.log.error(f"Load model: type=Chroma Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_dtype}")
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load Quanto text encoder: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'Chroma Quanto:')
|
||||
|
||||
return transformer, text_encoder
|
||||
|
||||
|
||||
def load_chroma_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
|
||||
transformer, text_encoder = None, None
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
model_quant.load_bnb('Load model: type=Chroma')
|
||||
quant = model_quant.get_quant(repo_path)
|
||||
try:
|
||||
# we ignore the distilled guidance layer because it degrades quality too much
|
||||
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
|
||||
if quant == 'fp8':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["distilled_guidance_layer"], bnb_4bit_compute_dtype=devices.dtype)
|
||||
debug(f'Quantization: {quantization_config}')
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
elif quant == 'fp4':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["distilled_guidance_layer"], bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'fp4')
|
||||
debug(f'Quantization: {quantization_config}')
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
elif quant == 'nf4':
|
||||
quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["distilled_guidance_layer"], bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type= 'nf4')
|
||||
debug(f'Quantization: {quantization_config}')
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
|
||||
else:
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load BnB transformer: {e}")
|
||||
transformer, text_encoder = None, None
|
||||
if debug:
|
||||
errors.display(e, 'Chroma:')
|
||||
return transformer, text_encoder
|
||||
|
||||
|
||||
def load_quants(kwargs, pretrained_model_name_or_path, cache_dir, allow_quant):
|
||||
try:
|
||||
if 'transformer' not in kwargs and model_quant.check_nunchaku('Transformer'):
|
||||
raise NotImplementedError('Nunchaku does not support Chroma Model yet. See https://github.com/mit-han-lab/nunchaku/issues/167')
|
||||
elif 'transformer' not in kwargs and model_quant.check_quant('Transformer'):
|
||||
quant_args = model_quant.create_config(allow=allow_quant, module='Transformer')
|
||||
if quant_args:
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
kwargs['transformer'] = diffusers.ChromaTransformer2DModel.from_single_file(pretrained_model_name_or_path, cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
|
||||
pass
|
||||
else:
|
||||
kwargs['transformer'] = diffusers.ChromaTransformer2DModel.from_pretrained(pretrained_model_name_or_path, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
|
||||
if 'text_encoder' not in kwargs and model_quant.check_nunchaku('TE'):
|
||||
import nunchaku
|
||||
nunchaku_precision = nunchaku.utils.get_precision()
|
||||
nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
|
||||
shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
|
||||
kwargs['text_encoder'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
|
||||
elif 'text_encoder' not in kwargs and model_quant.check_quant('TE'):
|
||||
quant_args = model_quant.create_config(allow=allow_quant, module='TE')
|
||||
if quant_args:
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
kwargs['text_encoder'] = transformers.T5EncoderModel.from_single_file(pretrained_model_name_or_path, cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
|
||||
else:
|
||||
kwargs['text_encoder'] = transformers.T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
|
||||
except Exception as e:
|
||||
shared.log.error(f'Quantization: {e}')
|
||||
errors.display(e, 'Quantization:')
|
||||
return kwargs
|
||||
|
||||
|
||||
def load_transformer(file_path): # triggered by opts.sd_unet change
|
||||
if file_path is None or not os.path.exists(file_path):
|
||||
return None
|
||||
transformer = None
|
||||
quant = model_quant.get_quant(file_path)
|
||||
diffusers_load_config = {
|
||||
"low_cpu_mem_usage": True,
|
||||
"torch_dtype": devices.dtype,
|
||||
"cache_dir": shared.opts.hfcache_dir,
|
||||
}
|
||||
if quant is not None and quant != 'none':
|
||||
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
|
||||
if 'gguf' in file_path.lower():
|
||||
from modules import ggml
|
||||
_transformer = ggml.load_gguf(file_path, cls=diffusers.ChromaTransformer2DModel, compute_dtype=devices.dtype)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif quant == 'qint8' or quant == 'qint4':
|
||||
_transformer, _text_encoder = load_chroma_quanto(file_path)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4':
|
||||
_transformer, _text_encoder = load_chroma_bnb(file_path, diffusers_load_config)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
elif 'nf4' in quant: # TODO chroma: loader for civitai nf4 models
|
||||
from modules.model_chroma_nf4 import load_chroma_nf4
|
||||
_transformer, _text_encoder = load_chroma_nf4(file_path, prequantized=True)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
else:
|
||||
quant_args = model_quant.create_bnb_config({})
|
||||
if quant_args:
|
||||
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
|
||||
from modules.model_chroma_nf4 import load_chroma_nf4
|
||||
transformer, _text_encoder = load_chroma_nf4(file_path, prequantized=False)
|
||||
if transformer is not None:
|
||||
return transformer
|
||||
quant_args = model_quant.create_config(module='Transformer')
|
||||
if quant_args:
|
||||
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=torchao dtype={devices.dtype}')
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args)
|
||||
if transformer is not None:
|
||||
return transformer
|
||||
shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=none dtype={devices.dtype}')
|
||||
# TODO chroma transformer from-single-file with quant
|
||||
# shared.log.warning('Load module: type=UNet/Transformer does not support load-time quantization')
|
||||
transformer = diffusers.ChromaTransformer2DModel.from_single_file(file_path, **diffusers_load_config)
|
||||
if transformer is None:
|
||||
shared.log.error('Failed to load UNet model')
|
||||
shared.opts.sd_unet = 'Default'
|
||||
return transformer
|
||||
|
||||
|
||||
def load_chroma(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
|
||||
fn = checkpoint_info.path
|
||||
repo_id = sd_models.path_to_repo(checkpoint_info.name)
|
||||
login = modelloader.hf_login()
|
||||
try:
|
||||
auth_check(repo_id)
|
||||
except Exception as e:
|
||||
repo_id = None
|
||||
if not os.path.exists(fn):
|
||||
shared.log.error(f'Load model: repo="{repo_id}" login={login} {e}')
|
||||
return None
|
||||
|
||||
prequantized = model_quant.get_quant(checkpoint_info.path)
|
||||
shared.log.debug(f'Load model: type=Chroma model="{checkpoint_info.name}" repo={repo_id or "none"} unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
|
||||
debug(f'Load model: type=Chroma config={diffusers_load_config}')
|
||||
|
||||
transformer = None
|
||||
text_encoder = None
|
||||
vae = None
|
||||
|
||||
# unload current model
|
||||
sd_models.unload_model_weights()
|
||||
shared.sd_model = None
|
||||
devices.torch_gc(force=True)
|
||||
|
||||
if shared.opts.teacache_enabled:
|
||||
from modules import teacache
|
||||
shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.ChromaTransformer2DModel.__name__}')
|
||||
diffusers.ChromaTransformer2DModel.forward = teacache.teacache_chroma_forward # patch must be done before transformer is loaded
|
||||
|
||||
# load overrides if any
|
||||
if shared.opts.sd_unet != 'Default':
|
||||
try:
|
||||
debug(f'Load model: type=Chroma unet="{shared.opts.sd_unet}"')
|
||||
transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
|
||||
if transformer is None:
|
||||
shared.opts.sd_unet = 'Default'
|
||||
sd_unet.failed_unet.append(shared.opts.sd_unet)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load UNet: {e}")
|
||||
shared.opts.sd_unet = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'Chroma UNet:')
|
||||
if shared.opts.sd_text_encoder != 'Default':
|
||||
try:
|
||||
debug(f'Load model: type=Chroma te="{shared.opts.sd_text_encoder}"')
|
||||
from modules.model_te import load_t5
|
||||
text_encoder = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load T5: {e}")
|
||||
shared.opts.sd_text_encoder = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'Chroma T5:')
|
||||
if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
|
||||
try:
|
||||
debug(f'Load model: type=Chroma vae="{shared.opts.sd_vae}"')
|
||||
from modules import sd_vae
|
||||
# vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
|
||||
vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
|
||||
if os.path.exists(vae_file):
|
||||
vae_config = os.path.join('configs', 'chroma', 'vae', 'config.json')
|
||||
vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load VAE: {e}")
|
||||
shared.opts.sd_vae = 'Default'
|
||||
if debug:
|
||||
errors.display(e, 'Chroma VAE:')
|
||||
|
||||
# load quantized components if any
|
||||
if prequantized == 'nf4':
|
||||
try:
|
||||
from modules.model_chroma_nf4 import load_chroma_nf4
|
||||
_transformer, _text_encoder = load_chroma_nf4(checkpoint_info)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
if _text_encoder is not None:
|
||||
text_encoder = _text_encoder
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load NF4 components: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'Chroma NF4:')
|
||||
if prequantized == 'qint8' or prequantized == 'qint4':
|
||||
try:
|
||||
_transformer, _text_encoder = load_chroma_quanto(checkpoint_info)
|
||||
if _transformer is not None:
|
||||
transformer = _transformer
|
||||
if _text_encoder is not None:
|
||||
text_encoder = _text_encoder
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma failed to load Quanto components: {e}")
|
||||
if debug:
|
||||
errors.display(e, 'Chroma Quanto:')
|
||||
|
||||
# initialize pipeline with pre-loaded components
|
||||
kwargs = {}
|
||||
if transformer is not None:
|
||||
kwargs['transformer'] = transformer
|
||||
sd_unet.loaded_unet = shared.opts.sd_unet
|
||||
if text_encoder is not None:
|
||||
kwargs['text_encoder'] = text_encoder
|
||||
model_te.loaded_te = shared.opts.sd_text_encoder
|
||||
if vae is not None:
|
||||
kwargs['vae'] = vae
|
||||
|
||||
# Todo: atm only ChromaPipeline is implemented in diffusers.
|
||||
# Need to add ChromaFillPipeline, ChromaControlPipeline, ChromaImg2ImgPipeline etc when available.
|
||||
# Chroma will support inpainting *after* its training has finished:
|
||||
# https://huggingface.co/lodestones/Chroma/discussions/28#6826dd2ed86f53ff983add5c
|
||||
cls = diffusers.ChromaPipeline
|
||||
shared.log.debug(f'Load model: type=Chroma cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
|
||||
for c in kwargs:
|
||||
if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
|
||||
shared.log.debug(f'Load model: type=Chroma component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
|
||||
if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
|
||||
try:
|
||||
kwargs[c] = kwargs[c].to(dtype=devices.dtype)
|
||||
shared.log.warning(f'Load model: type=Chroma component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
|
||||
if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
|
||||
kwargs = load_quants(kwargs, repo_id or fn, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
|
||||
# kwargs = model_quant.create_config(kwargs, allow_quant)
|
||||
if fn.endswith('.safetensors') and os.path.isfile(fn):
|
||||
pipe = diffusers.ChromaPipeline.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
|
||||
else:
|
||||
pipe = cls.from_pretrained(repo_id or fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
|
||||
|
||||
if shared.opts.teacache_enabled and model_quant.check_nunchaku('Transformer'):
|
||||
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
|
||||
apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
|
||||
|
||||
# release memory
|
||||
transformer = None
|
||||
text_encoder = None
|
||||
vae = None
|
||||
for k in kwargs.keys():
|
||||
kwargs[k] = None
|
||||
sd_hijack_te.init_hijack(pipe)
|
||||
devices.torch_gc(force=True)
|
||||
return pipe
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
Copied from: https://github.com/huggingface/diffusers/issues/9165
|
||||
+ adjusted for Chroma by skipping distilled guidance layer
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.quantizers.quantizers_utils import get_module_from_name
|
||||
from huggingface_hub import hf_hub_download
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from diffusers.loaders.single_file_utils import convert_chroma_transformer_checkpoint_to_diffusers
|
||||
import safetensors.torch
|
||||
from modules import shared, devices, model_quant
|
||||
|
||||
|
||||
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
|
||||
|
||||
|
||||
def _replace_with_bnb_linear(
|
||||
model,
|
||||
method="nf4",
|
||||
has_been_replaced=False,
|
||||
):
|
||||
"""
|
||||
Private method that wraps the recursion for module replacement.
|
||||
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
|
||||
"""
|
||||
bnb = model_quant.load_bnb('Load model: type=Chroma')
|
||||
for name, module in model.named_children():
|
||||
# we ignore the distilled guidance layer because it degrades quality too much
|
||||
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
|
||||
if "distilled_guidance_layer" in name:
|
||||
continue
|
||||
if isinstance(module, nn.Linear):
|
||||
with init_empty_weights():
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
|
||||
if method == "llm_int8":
|
||||
model._modules[name] = bnb.nn.Linear8bitLt( # pylint: disable=protected-access
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
)
|
||||
has_been_replaced = True
|
||||
else:
|
||||
model._modules[name] = bnb.nn.Linear4bit( # pylint: disable=protected-access
|
||||
in_features,
|
||||
out_features,
|
||||
module.bias is not None,
|
||||
compute_dtype=devices.dtype,
|
||||
compress_statistics=False,
|
||||
quant_type="nf4",
|
||||
)
|
||||
has_been_replaced = True
|
||||
# Store the module class in case we need to transpose the weight later
|
||||
model._modules[name].source_cls = type(module) # pylint: disable=protected-access
|
||||
# Force requires grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False) # pylint: disable=protected-access
|
||||
|
||||
if len(list(module.children())) > 0:
|
||||
_, has_been_replaced = _replace_with_bnb_linear(
|
||||
module,
|
||||
has_been_replaced=has_been_replaced,
|
||||
)
|
||||
# Remove the last key for recursion
|
||||
return model, has_been_replaced
|
||||
|
||||
|
||||
def check_quantized_param(
|
||||
model,
|
||||
param_name: str,
|
||||
) -> bool:
|
||||
bnb = model_quant.load_bnb('Load model: type=Chroma')
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): # pylint: disable=protected-access
|
||||
# Add here check for loaded components' dtypes once serialization is implemented
|
||||
return True
|
||||
elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
|
||||
# bias could be loaded by regular set_module_tensor_to_device() from accelerate,
|
||||
# but it would wrongly use uninitialized weight there.
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def create_quantized_param(
|
||||
model,
|
||||
param_value: "torch.Tensor",
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict=None,
|
||||
unexpected_keys=None,
|
||||
pre_quantized=False
|
||||
):
|
||||
bnb = model_quant.load_bnb('Load model: type=Chroma')
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
|
||||
if tensor_name not in module._parameters: # pylint: disable=protected-access
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
if tensor_name == "bias":
|
||||
if param_value is None:
|
||||
new_value = old_value.to(target_device)
|
||||
else:
|
||||
new_value = param_value.to(target_device)
|
||||
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
|
||||
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
|
||||
return
|
||||
|
||||
if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): # pylint: disable=protected-access
|
||||
raise ValueError("this function only loads `Linear4bit components`")
|
||||
if (
|
||||
old_value.device == torch.device("meta")
|
||||
and target_device not in ["meta", torch.device("meta")]
|
||||
and param_value is None
|
||||
):
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
|
||||
|
||||
if pre_quantized:
|
||||
if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
|
||||
raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
|
||||
quantized_stats = {}
|
||||
for k, v in state_dict.items():
|
||||
# `startswith` to counter for edge cases where `param_name`
|
||||
# substring can be present in multiple places in the `state_dict`
|
||||
if param_name + "." in k and k.startswith(param_name):
|
||||
quantized_stats[k] = v
|
||||
if unexpected_keys is not None and k in unexpected_keys:
|
||||
unexpected_keys.remove(k)
|
||||
new_value = bnb.nn.Params4bit.from_prequantized(
|
||||
data=param_value,
|
||||
quantized_stats=quantized_stats,
|
||||
requires_grad=False,
|
||||
device=target_device,
|
||||
)
|
||||
else:
|
||||
new_value = param_value.to("cpu")
|
||||
kwargs = old_value.__dict__
|
||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
|
||||
module._parameters[tensor_name] = new_value # pylint: disable=protected-access
|
||||
|
||||
|
||||
def load_chroma_nf4(checkpoint_info, prequantized: bool = True):
|
||||
transformer = None
|
||||
text_encoder = None
|
||||
if isinstance(checkpoint_info, str):
|
||||
repo_path = checkpoint_info
|
||||
else:
|
||||
repo_path = checkpoint_info.path
|
||||
if os.path.exists(repo_path) and os.path.isfile(repo_path):
|
||||
ckpt_path = repo_path
|
||||
elif os.path.exists(repo_path) and os.path.isdir(repo_path) and os.path.exists(os.path.join(repo_path, "diffusion_pytorch_model.safetensors")):
|
||||
ckpt_path = os.path.join(repo_path, "diffusion_pytorch_model.safetensors")
|
||||
else:
|
||||
ckpt_path = hf_hub_download(repo_path, filename="diffusion_pytorch_model.safetensors", cache_dir=shared.opts.diffusers_dir)
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
|
||||
try:
|
||||
converted_state_dict = convert_chroma_transformer_checkpoint_to_diffusers(original_state_dict)
|
||||
except Exception as e:
|
||||
shared.log.error(f"Load model: type=Chroma Failed to convert UNET: {e}")
|
||||
if debug:
|
||||
from modules import errors
|
||||
errors.display(e, 'Chroma convert:')
|
||||
converted_state_dict = original_state_dict
|
||||
|
||||
with init_empty_weights():
|
||||
from diffusers import ChromaTransformer2DModel
|
||||
config = ChromaTransformer2DModel.load_config(os.path.join('configs', 'chroma'), subfolder="transformer")
|
||||
transformer = ChromaTransformer2DModel.from_config(config).to(devices.dtype)
|
||||
expected_state_dict_keys = list(transformer.state_dict().keys())
|
||||
|
||||
_replace_with_bnb_linear(transformer, "nf4")
|
||||
|
||||
try:
|
||||
for param_name, param in converted_state_dict.items():
|
||||
if param_name not in expected_state_dict_keys:
|
||||
continue
|
||||
is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
|
||||
if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
|
||||
param = param.to(devices.dtype)
|
||||
if not check_quantized_param(transformer, param_name):
|
||||
set_module_tensor_to_device(transformer, param_name, device=0, value=param)
|
||||
else:
|
||||
create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
|
||||
except Exception as e:
|
||||
transformer, text_encoder = None
|
||||
shared.log.error(f"Load model: type=Chroma failed to load UNET: {e}")
|
||||
if debug:
|
||||
from modules import errors
|
||||
errors.display(e, 'Chroma:')
|
||||
|
||||
del original_state_dict
|
||||
devices.torch_gc(force=True)
|
||||
return transformer, text_encoder
|
||||
|
|
@ -427,8 +427,12 @@ def optimum_quanto_model(model, op=None, sd_model=None, weights=None, activation
|
|||
from modules import devices, shared
|
||||
quanto = load_quanto('Quantize model: type=Optimum Quanto')
|
||||
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
|
||||
if sd_model is not None and "Flux" in sd_model.__class__.__name__: # LayerNorm is not supported
|
||||
if sd_model is not None and "Flux" in sd_model.__class__.__name__ or "Chroma" in sd_model.__class__.__name__: # LayerNorm is not supported
|
||||
exclude_list = ["transformer_blocks.*.norm1.norm", "transformer_blocks.*.norm2", "transformer_blocks.*.norm1_context.norm", "transformer_blocks.*.norm2_context", "single_transformer_blocks.*.norm.norm", "norm_out.norm"]
|
||||
if "Chroma" in sd_model.__class__.__name__:
|
||||
# we ignore the distilled guidance layer because it degrades quality too much
|
||||
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
|
||||
exclude_list.append("distilled_guidance_layer.*")
|
||||
else:
|
||||
exclude_list = None
|
||||
weights = getattr(quanto, weights) if weights is not None else getattr(quanto, shared.opts.optimum_quanto_weights_type)
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ def get_model_type(pipe):
|
|||
model_type = 'sc'
|
||||
elif "AuraFlow" in name:
|
||||
model_type = 'auraflow'
|
||||
elif 'Chroma' in name:
|
||||
model_type = 'chroma'
|
||||
elif "Flux" in name or "Flex1" in name or "Flex2" in name:
|
||||
model_type = 'f1'
|
||||
elif "Lumina2" in name:
|
||||
|
|
|
|||
|
|
@ -150,6 +150,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
'StableDiffusion' in model.__class__.__name__ or
|
||||
'StableCascade' in model.__class__.__name__ or
|
||||
'Flux' in model.__class__.__name__ or
|
||||
'Chroma' in model.__class__.__name__ or
|
||||
'HiDreamImagePipeline' in model.__class__.__name__ # hidream-e1 has different embeds
|
||||
):
|
||||
try:
|
||||
|
|
@ -174,15 +175,14 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
args['prompt_embeds_llama3'] = prompt_embeds[1]
|
||||
elif hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None:
|
||||
args['prompt_embeds'] = prompt_parser_diffusers.embedder('prompt_embeds')
|
||||
if prompt_parser_diffusers.embedder is not None:
|
||||
if 'StableCascade' in model.__class__.__name__:
|
||||
args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0)
|
||||
elif 'XL' in model.__class__.__name__:
|
||||
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
|
||||
elif 'StableDiffusion3' in model.__class__.__name__:
|
||||
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
|
||||
elif 'Flux' in model.__class__.__name__:
|
||||
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
|
||||
if 'StableCascade' in model.__class__.__name__:
|
||||
args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0)
|
||||
elif 'XL' in model.__class__.__name__:
|
||||
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
|
||||
elif 'StableDiffusion3' in model.__class__.__name__:
|
||||
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
|
||||
elif 'Flux' in model.__class__.__name__:
|
||||
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
|
||||
else:
|
||||
args['prompt'] = prompts
|
||||
if 'negative_prompt' in possible:
|
||||
|
|
@ -193,13 +193,12 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
|
|||
args['negative_prompt_embeds_llama3'] = negative_prompt_embeds[1]
|
||||
elif hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'negative_prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None:
|
||||
args['negative_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_prompt_embeds')
|
||||
if prompt_parser_diffusers.embedder is not None:
|
||||
if 'StableCascade' in model.__class__.__name__:
|
||||
args['negative_prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('negative_pooleds').unsqueeze(0)
|
||||
elif 'XL' in model.__class__.__name__:
|
||||
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
|
||||
elif 'StableDiffusion3' in model.__class__.__name__:
|
||||
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
|
||||
if 'StableCascade' in model.__class__.__name__:
|
||||
args['negative_prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('negative_pooleds').unsqueeze(0)
|
||||
elif 'XL' in model.__class__.__name__:
|
||||
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
|
||||
elif 'StableDiffusion3' in model.__class__.__name__:
|
||||
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
|
||||
else:
|
||||
if 'PixArtSigmaPipeline' in model.__class__.__name__: # pixart-sigma pipeline throws list-of-list for negative prompt
|
||||
args['negative_prompt'] = negative_prompts[0]
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
|
||||
from transformers import PreTrainedTokenizer
|
||||
from modules import shared, prompt_parser, devices, sd_models
|
||||
from modules.prompt_parser_xhinker import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl_2p, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1
|
||||
from modules.prompt_parser_xhinker import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl_2p, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1, get_weighted_text_embeddings_chroma
|
||||
|
||||
debug_enabled = os.environ.get('SD_PROMPT_DEBUG', None)
|
||||
debug = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
|
||||
|
|
@ -27,6 +27,7 @@ def prompt_compatible(pipe = None):
|
|||
'DemoFusion' not in pipe.__class__.__name__ and
|
||||
'StableCascade' not in pipe.__class__.__name__ and
|
||||
'Flux' not in pipe.__class__.__name__ and
|
||||
'Chroma' not in pipe.__class__.__name__ and
|
||||
'HiDreamImage' not in pipe.__class__.__name__
|
||||
):
|
||||
shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
|
||||
|
|
@ -510,6 +511,10 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
|
|||
prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=1)
|
||||
return prompt_embeds, pooled_prompt_embeds, None, None # no negative support
|
||||
|
||||
if "Chroma" in pipe.__class__.__name__: # does not use clip and has no pooled embeds
|
||||
prompt_embeds, _, _, negative_prompt_embeds, _, _ = pipe.encode_prompt(prompt=prompt, negative_prompt=neg_prompt, device=device, num_images_per_prompt=1)
|
||||
return prompt_embeds, None, negative_prompt_embeds, None
|
||||
|
||||
if "HiDreamImage" in pipe.__class__.__name__: # clip is only used for the pooled embeds
|
||||
prompt_embeds_t5, negative_prompt_embeds_t5, prompt_embeds_llama3, negative_prompt_embeds_llama3, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
|
||||
prompt=prompt, prompt_2=prompt_2, prompt_3=prompt_3, prompt_4=prompt_4,
|
||||
|
|
@ -662,6 +667,8 @@ def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", cl
|
|||
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sd3(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, use_t5_encoder=bool(pipe.text_encoder_3))
|
||||
elif 'Flux' in pipe.__class__.__name__:
|
||||
prompt_embed, positive_pooled = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt, prompt2=prompt_2, device=devices.device)
|
||||
elif 'Chroma' in pipe.__class__.__name__:
|
||||
prompt_embed, negative_embed = get_weighted_text_embeddings_chroma(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, device=devices.device)
|
||||
elif 'XL' in pipe.__class__.__name__:
|
||||
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sdxl_2p(pipe=pipe, prompt=prompt, prompt_2=prompt_2, neg_prompt=neg_prompt, neg_prompt_2=neg_prompt_2)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -17,11 +17,13 @@
|
|||
## -----------------------------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPTokenizer, T5Tokenizer
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers import ChromaPipeline
|
||||
from modules.prompt_parser import parse_prompt_attention # use built-in A1111 parser
|
||||
|
||||
|
||||
|
|
@ -1424,3 +1426,75 @@ def get_weighted_text_embeddings_flux1(
|
|||
t5_prompt_embeds = t5_prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device)
|
||||
|
||||
return t5_prompt_embeds, prompt_embeds
|
||||
|
||||
|
||||
def get_weighted_text_embeddings_chroma(
|
||||
pipe: ChromaPipeline,
|
||||
prompt: str = "",
|
||||
neg_prompt: str = "",
|
||||
device=None
|
||||
):
|
||||
"""
|
||||
This function can process long prompt with weights for Chroma model
|
||||
|
||||
Args:
|
||||
pipe (ChromaPipeline)
|
||||
prompt (str)
|
||||
neg_prompt (str)
|
||||
device (torch.device, optional): Device to run the embeddings on.
|
||||
Returns:
|
||||
prompt_embeds (T5 prompt embeds as torch.Tensor)
|
||||
neg_prompt_embeds (T5 prompt embeds as torch.Tensor)
|
||||
"""
|
||||
if device is None:
|
||||
device = pipe.text_encoder.device
|
||||
|
||||
# prompt
|
||||
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights_t5(
|
||||
pipe.tokenizer, prompt
|
||||
)
|
||||
|
||||
prompt_tokens = torch.tensor([prompt_tokens], dtype=torch.long)
|
||||
|
||||
t5_prompt_embeds = pipe.text_encoder(prompt_tokens.to(device))[0].squeeze(0)
|
||||
t5_prompt_embeds = t5_prompt_embeds.to(device=device)
|
||||
|
||||
# add weight to t5 prompt
|
||||
for z in range(len(prompt_weights)):
|
||||
if prompt_weights[z] != 1.0:
|
||||
t5_prompt_embeds[z] = t5_prompt_embeds[z] * prompt_weights[z]
|
||||
t5_prompt_embeds = t5_prompt_embeds.unsqueeze(0)
|
||||
t5_prompt_embeds = t5_prompt_embeds.to(dtype=pipe.text_encoder.dtype, device=device)
|
||||
|
||||
# negative prompt
|
||||
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights_t5(
|
||||
pipe.tokenizer, neg_prompt
|
||||
)
|
||||
|
||||
neg_prompt_tokens = torch.tensor([neg_prompt_tokens], dtype=torch.long)
|
||||
|
||||
t5_neg_prompt_embeds = pipe.text_encoder(neg_prompt_tokens.to(device))[0].squeeze(0)
|
||||
t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(device=device)
|
||||
|
||||
# add weight to neg t5 embeddings
|
||||
for z in range(len(neg_prompt_weights)):
|
||||
if neg_prompt_weights[z] != 1.0:
|
||||
t5_neg_prompt_embeds[z] = t5_neg_prompt_embeds[z] * neg_prompt_weights[z]
|
||||
t5_neg_prompt_embeds = t5_neg_prompt_embeds.unsqueeze(0)
|
||||
t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(dtype=pipe.text_encoder.dtype, device=device)
|
||||
|
||||
def pad_prompt_embeds_to_same_size(prompt_embeds_a, prompt_embeds_b):
|
||||
size_a = prompt_embeds_a.size(1)
|
||||
size_b = prompt_embeds_b.size(1)
|
||||
|
||||
if size_a < size_b:
|
||||
pad_size = size_b - size_a
|
||||
prompt_embeds_a = F.pad(prompt_embeds_a, (0, 0, 0, pad_size)) # Pad dim=1
|
||||
elif size_b < size_a:
|
||||
pad_size = size_a - size_b
|
||||
prompt_embeds_b = F.pad(prompt_embeds_b, (0, 0, 0, pad_size)) # Pad dim=1
|
||||
|
||||
return prompt_embeds_a, prompt_embeds_b
|
||||
|
||||
# chroma needs positive and negative prompt embeddings to have the same length (for now)
|
||||
return pad_prompt_embeds_to_same_size(t5_prompt_embeds, t5_neg_prompt_embeds)
|
||||
|
|
|
|||
|
|
@ -92,6 +92,8 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
|
|||
guess = 'Stable Diffusion 3'
|
||||
if 'hidream' in f.lower():
|
||||
guess = 'HiDream'
|
||||
if 'chroma' in f.lower():
|
||||
guess = 'Chroma'
|
||||
if 'flux' in f.lower() or 'flex.1' in f.lower() or 'lodestones' in f.lower():
|
||||
guess = 'FLUX'
|
||||
if size > 11000 and size < 16000:
|
||||
|
|
|
|||
|
|
@ -329,6 +329,9 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
|
|||
elif model_type in ['FLEX']:
|
||||
from modules.model_flex import load_flex
|
||||
sd_model = load_flex(checkpoint_info, diffusers_load_config)
|
||||
elif model_type in ['Chroma']:
|
||||
from modules.model_chroma import load_chroma
|
||||
sd_model = load_chroma(checkpoint_info, diffusers_load_config)
|
||||
elif model_type in ['Lumina 2']:
|
||||
from modules.model_lumina import load_lumina2
|
||||
sd_model = load_lumina2(checkpoint_info, diffusers_load_config)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from modules.timer import process as process_timer
|
|||
|
||||
debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
|
||||
debug_move = log.trace if debug else lambda *args, **kwargs: None
|
||||
offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4']
|
||||
offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4', 'chroma']
|
||||
offload_post = ['h1']
|
||||
offload_hook_instance = None
|
||||
balanced_offload_exclude = ['OmniGenPipeline', 'CogView4Pipeline']
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ samplers = all_samplers
|
|||
samplers_for_img2img = all_samplers
|
||||
samplers_map = {}
|
||||
loaded_config = None
|
||||
flow_models = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana', 'CogView4', 'HiDream']
|
||||
flow_models = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana', 'CogView4', 'HiDream', 'Chroma']
|
||||
flow_models += ['Hunyuan', 'LTX', 'Mochi']
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t
|
|||
|
||||
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
|
||||
approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
|
||||
flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1']
|
||||
flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'chroma']
|
||||
warned = False
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ hf_decode_endpoints = {
|
|||
'sd': 'https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'sdxl': 'https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'f1': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'chroma': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'h1': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'hunyuanvideo': 'https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
}
|
||||
|
|
@ -19,6 +20,7 @@ hf_encode_endpoints = {
|
|||
'sd': 'https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'sdxl': 'https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'f1': 'https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
'chroma': 'https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud',
|
||||
}
|
||||
dtypes = {
|
||||
"float16": torch.float16,
|
||||
|
|
@ -48,7 +50,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
|
|||
params = {}
|
||||
try:
|
||||
latent = latent_copy[i]
|
||||
if model_type != 'f1':
|
||||
if model_type not in ['f1', 'chroma']:
|
||||
latent = latent.unsqueeze(0)
|
||||
params = {
|
||||
"input_tensor_type": "binary",
|
||||
|
|
@ -74,7 +76,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
|
|||
params["output_type"] = "pt"
|
||||
params["output_tensor_type"] = "binary"
|
||||
headers["Accept"] = "tensor/binary"
|
||||
if (model_type == 'f1' or model_type == 'h1') and (width > 0) and (height > 0):
|
||||
if (model_type in ['f1', 'h1', 'chroma']) and (width > 0) and (height > 0):
|
||||
params['width'] = width
|
||||
params['height'] = height
|
||||
if shared.sd_model.vae is not None and shared.sd_model.vae.config is not None:
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def get_model(model_type = 'decoder', variant = None):
|
|||
cls = shared.sd_model_type
|
||||
if cls in {'ldm', 'pixartalpha'}:
|
||||
cls = 'sd'
|
||||
elif cls in {'h1', 'lumina2'}:
|
||||
elif cls in {'h1', 'lumina2', 'chroma'}:
|
||||
cls = 'f1'
|
||||
elif cls == 'pixartsigma':
|
||||
cls = 'sdxl'
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ pipelines = {
|
|||
'DeepFloyd IF': getattr(diffusers, 'IFPipeline', None),
|
||||
'FLUX': getattr(diffusers, 'FluxPipeline', None),
|
||||
'FLEX': getattr(diffusers, 'AutoPipelineForText2Image', None),
|
||||
'Chroma': getattr(diffusers, 'ChromaPipeline', None),
|
||||
'Sana': getattr(diffusers, 'SanaPipeline', None),
|
||||
'Lumina-Next': getattr(diffusers, 'LuminaText2ImgPipeline', None),
|
||||
'Lumina 2': getattr(diffusers, 'Lumina2Pipeline', None),
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from .teacache_lumina2 import teacache_lumina2_forward
|
|||
from .teacache_ltx import teacache_ltx_forward
|
||||
from .teacache_mochi import teacache_mochi_forward
|
||||
from .teacache_cogvideox import teacache_cog_forward
|
||||
from .teacache_chroma import teacache_chroma_forward
|
||||
|
||||
|
||||
supported_models = ['Flux', 'CogVideoX', 'Mochi', 'LTX', 'HiDream', 'Lumina2']
|
||||
supported_models = ['Flux', 'Chroma', 'CogVideoX', 'Mochi', 'LTX', 'HiDream', 'Lumina2']
|
||||
|
||||
|
||||
def apply_teacache(p):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,325 @@
|
|||
from typing import Any, Dict, Optional, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
||||
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def teacache_chroma_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
controlnet_blocks_repeat: bool = False,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`ChromaTransformer2DModel`] forward method.
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
|
||||
input_vec = self.time_text_embed(timestep)
|
||||
pooled_temb = self.distilled_guidance_layer(input_vec)
|
||||
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
||||
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
||||
|
||||
if self.enable_teacache:
|
||||
inp = hidden_states.clone()
|
||||
input_vec_ = input_vec.clone()
|
||||
modulated_inp, _gate_msa, _shift_mlp, _scale_mlp, _gate_mlp = self.transformer_blocks[0].norm1(inp, emb=input_vec_)
|
||||
if self.cnt == 0 or self.cnt == self.num_steps-1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.cnt += 1
|
||||
if self.cnt == self.num_steps:
|
||||
self.cnt = 0
|
||||
|
||||
if self.enable_teacache:
|
||||
if not should_calc:
|
||||
hidden_states += self.previous_residual
|
||||
else:
|
||||
ori_hidden_states = hidden_states.clone()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
img_offset = 3 * len(self.single_transformer_blocks)
|
||||
txt_offset = img_offset + 6 * len(self.transformer_blocks)
|
||||
img_modulation = img_offset + 6 * index_block
|
||||
text_modulation = txt_offset + 6 * index_block
|
||||
temb = torch.cat(
|
||||
(
|
||||
pooled_temb[:, img_modulation : img_modulation + 6],
|
||||
pooled_temb[:, text_modulation : text_modulation + 6],
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward4(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward4(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
# For Xlabs ControlNet.
|
||||
if controlnet_blocks_repeat:
|
||||
hidden_states = (
|
||||
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
start_idx = 3 * index_block
|
||||
temb = pooled_temb[:, start_idx : start_idx + 3]
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward2(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward2(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
self.previous_residual = hidden_states - ori_hidden_states
|
||||
else:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
img_offset = 3 * len(self.single_transformer_blocks)
|
||||
txt_offset = img_offset + 6 * len(self.transformer_blocks)
|
||||
img_modulation = img_offset + 6 * index_block
|
||||
text_modulation = txt_offset + 6 * index_block
|
||||
temb = torch.cat(
|
||||
(
|
||||
pooled_temb[:, img_modulation : img_modulation + 6],
|
||||
pooled_temb[:, text_modulation : text_modulation + 6],
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward1(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward1(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
# For Xlabs ControlNet.
|
||||
if controlnet_blocks_repeat:
|
||||
hidden_states = (
|
||||
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
start_idx = 3 * index_block
|
||||
temb = pooled_temb[:, start_idx : start_idx + 3]
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward3(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward3(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
temb = pooled_temb[:, -2:]
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
Loading…
Reference in New Issue