Initial support for Chroma

pull/3992/head
Enes Sadık Özbek 2025-06-18 00:38:17 +00:00
parent 26800a1ef9
commit 4b3ce06916
No known key found for this signature in database
GPG Key ID: 85CC9787265BBC7F
35 changed files with 3333 additions and 29 deletions

View File

@ -60,6 +60,8 @@ def guess_dct(dct: dict):
if has(dct, 'model.diffusion_model.joint_blocks') and len(list(has(dct, 'model.diffusion_model.joint_blocks'))) == 38:
return 'sd35-large'
if has(dct, 'model.diffusion_model.double_blocks') and len(list(has(dct, 'model.diffusion_model.double_blocks'))) == 19:
if has(dct, 'model.diffusion_model.distilled_guidance_layer'):
return 'chroma'
return 'flux-dev'
return None

View File

@ -0,0 +1,24 @@
{
"_class_name": "ChromaPipeline",
"_diffusers_version": "0.34.0.dev0",
"scheduler": [
"diffusers",
"FlowMatchEulerDiscreteScheduler"
],
"text_encoder": [
"transformers",
"T5EncoderModel"
],
"tokenizer": [
"transformers",
"T5Tokenizer"
],
"transformer": [
"diffusers",
"ChromaTransformer2DModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}

View File

@ -0,0 +1,18 @@
{
"_class_name": "FlowMatchEulerDiscreteScheduler",
"_diffusers_version": "0.34.0.dev0",
"base_image_seq_len": 256,
"base_shift": 0.5,
"invert_sigmas": false,
"max_image_seq_len": 4096,
"max_shift": 1.15,
"num_train_timesteps": 1000,
"shift": 3.0,
"shift_terminal": null,
"stochastic_sampling": false,
"time_shift_type": "exponential",
"use_beta_sigmas": false,
"use_dynamic_shifting": true,
"use_exponential_sigmas": false,
"use_karras_sigmas": false
}

View File

@ -0,0 +1,32 @@
{
"_name_or_path": "google/t5-v1_1-xxl",
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.52.4",
"use_cache": true,
"vocab_size": 32128
}

View File

@ -0,0 +1,102 @@
{
"<extra_id_0>": 32099,
"<extra_id_10>": 32089,
"<extra_id_11>": 32088,
"<extra_id_12>": 32087,
"<extra_id_13>": 32086,
"<extra_id_14>": 32085,
"<extra_id_15>": 32084,
"<extra_id_16>": 32083,
"<extra_id_17>": 32082,
"<extra_id_18>": 32081,
"<extra_id_19>": 32080,
"<extra_id_1>": 32098,
"<extra_id_20>": 32079,
"<extra_id_21>": 32078,
"<extra_id_22>": 32077,
"<extra_id_23>": 32076,
"<extra_id_24>": 32075,
"<extra_id_25>": 32074,
"<extra_id_26>": 32073,
"<extra_id_27>": 32072,
"<extra_id_28>": 32071,
"<extra_id_29>": 32070,
"<extra_id_2>": 32097,
"<extra_id_30>": 32069,
"<extra_id_31>": 32068,
"<extra_id_32>": 32067,
"<extra_id_33>": 32066,
"<extra_id_34>": 32065,
"<extra_id_35>": 32064,
"<extra_id_36>": 32063,
"<extra_id_37>": 32062,
"<extra_id_38>": 32061,
"<extra_id_39>": 32060,
"<extra_id_3>": 32096,
"<extra_id_40>": 32059,
"<extra_id_41>": 32058,
"<extra_id_42>": 32057,
"<extra_id_43>": 32056,
"<extra_id_44>": 32055,
"<extra_id_45>": 32054,
"<extra_id_46>": 32053,
"<extra_id_47>": 32052,
"<extra_id_48>": 32051,
"<extra_id_49>": 32050,
"<extra_id_4>": 32095,
"<extra_id_50>": 32049,
"<extra_id_51>": 32048,
"<extra_id_52>": 32047,
"<extra_id_53>": 32046,
"<extra_id_54>": 32045,
"<extra_id_55>": 32044,
"<extra_id_56>": 32043,
"<extra_id_57>": 32042,
"<extra_id_58>": 32041,
"<extra_id_59>": 32040,
"<extra_id_5>": 32094,
"<extra_id_60>": 32039,
"<extra_id_61>": 32038,
"<extra_id_62>": 32037,
"<extra_id_63>": 32036,
"<extra_id_64>": 32035,
"<extra_id_65>": 32034,
"<extra_id_66>": 32033,
"<extra_id_67>": 32032,
"<extra_id_68>": 32031,
"<extra_id_69>": 32030,
"<extra_id_6>": 32093,
"<extra_id_70>": 32029,
"<extra_id_71>": 32028,
"<extra_id_72>": 32027,
"<extra_id_73>": 32026,
"<extra_id_74>": 32025,
"<extra_id_75>": 32024,
"<extra_id_76>": 32023,
"<extra_id_77>": 32022,
"<extra_id_78>": 32021,
"<extra_id_79>": 32020,
"<extra_id_7>": 32092,
"<extra_id_80>": 32019,
"<extra_id_81>": 32018,
"<extra_id_82>": 32017,
"<extra_id_83>": 32016,
"<extra_id_84>": 32015,
"<extra_id_85>": 32014,
"<extra_id_86>": 32013,
"<extra_id_87>": 32012,
"<extra_id_88>": 32011,
"<extra_id_89>": 32010,
"<extra_id_8>": 32091,
"<extra_id_90>": 32009,
"<extra_id_91>": 32008,
"<extra_id_92>": 32007,
"<extra_id_93>": 32006,
"<extra_id_94>": 32005,
"<extra_id_95>": 32004,
"<extra_id_96>": 32003,
"<extra_id_97>": 32002,
"<extra_id_98>": 32001,
"<extra_id_99>": 32000,
"<extra_id_9>": 32090
}

View File

@ -0,0 +1,125 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
}

Binary file not shown.

View File

@ -0,0 +1,940 @@
{
"add_prefix_space": true,
"added_tokens_decoder": {
"0": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "</s>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "<unk>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32000": {
"content": "<extra_id_99>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32001": {
"content": "<extra_id_98>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32002": {
"content": "<extra_id_97>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<extra_id_96>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32004": {
"content": "<extra_id_95>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32005": {
"content": "<extra_id_94>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32006": {
"content": "<extra_id_93>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32007": {
"content": "<extra_id_92>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32008": {
"content": "<extra_id_91>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32009": {
"content": "<extra_id_90>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32010": {
"content": "<extra_id_89>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32011": {
"content": "<extra_id_88>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32012": {
"content": "<extra_id_87>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32013": {
"content": "<extra_id_86>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32014": {
"content": "<extra_id_85>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32015": {
"content": "<extra_id_84>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32016": {
"content": "<extra_id_83>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32017": {
"content": "<extra_id_82>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32018": {
"content": "<extra_id_81>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32019": {
"content": "<extra_id_80>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32020": {
"content": "<extra_id_79>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32021": {
"content": "<extra_id_78>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32022": {
"content": "<extra_id_77>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32023": {
"content": "<extra_id_76>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32024": {
"content": "<extra_id_75>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32025": {
"content": "<extra_id_74>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32026": {
"content": "<extra_id_73>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32027": {
"content": "<extra_id_72>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32028": {
"content": "<extra_id_71>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32029": {
"content": "<extra_id_70>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32030": {
"content": "<extra_id_69>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32031": {
"content": "<extra_id_68>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32032": {
"content": "<extra_id_67>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32033": {
"content": "<extra_id_66>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32034": {
"content": "<extra_id_65>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32035": {
"content": "<extra_id_64>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32036": {
"content": "<extra_id_63>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32037": {
"content": "<extra_id_62>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32038": {
"content": "<extra_id_61>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32039": {
"content": "<extra_id_60>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32040": {
"content": "<extra_id_59>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32041": {
"content": "<extra_id_58>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32042": {
"content": "<extra_id_57>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32043": {
"content": "<extra_id_56>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32044": {
"content": "<extra_id_55>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32045": {
"content": "<extra_id_54>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32046": {
"content": "<extra_id_53>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32047": {
"content": "<extra_id_52>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32048": {
"content": "<extra_id_51>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32049": {
"content": "<extra_id_50>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32050": {
"content": "<extra_id_49>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32051": {
"content": "<extra_id_48>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32052": {
"content": "<extra_id_47>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32053": {
"content": "<extra_id_46>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32054": {
"content": "<extra_id_45>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32055": {
"content": "<extra_id_44>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32056": {
"content": "<extra_id_43>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32057": {
"content": "<extra_id_42>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32058": {
"content": "<extra_id_41>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32059": {
"content": "<extra_id_40>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32060": {
"content": "<extra_id_39>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32061": {
"content": "<extra_id_38>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32062": {
"content": "<extra_id_37>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32063": {
"content": "<extra_id_36>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32064": {
"content": "<extra_id_35>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32065": {
"content": "<extra_id_34>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32066": {
"content": "<extra_id_33>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32067": {
"content": "<extra_id_32>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32068": {
"content": "<extra_id_31>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32069": {
"content": "<extra_id_30>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32070": {
"content": "<extra_id_29>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32071": {
"content": "<extra_id_28>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32072": {
"content": "<extra_id_27>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32073": {
"content": "<extra_id_26>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32074": {
"content": "<extra_id_25>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32075": {
"content": "<extra_id_24>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32076": {
"content": "<extra_id_23>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32077": {
"content": "<extra_id_22>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32078": {
"content": "<extra_id_21>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32079": {
"content": "<extra_id_20>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32080": {
"content": "<extra_id_19>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32081": {
"content": "<extra_id_18>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32082": {
"content": "<extra_id_17>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32083": {
"content": "<extra_id_16>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32084": {
"content": "<extra_id_15>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32085": {
"content": "<extra_id_14>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32086": {
"content": "<extra_id_13>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32087": {
"content": "<extra_id_12>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32088": {
"content": "<extra_id_11>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32089": {
"content": "<extra_id_10>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32090": {
"content": "<extra_id_9>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32091": {
"content": "<extra_id_8>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32092": {
"content": "<extra_id_7>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32093": {
"content": "<extra_id_6>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32094": {
"content": "<extra_id_5>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32095": {
"content": "<extra_id_4>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32096": {
"content": "<extra_id_3>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32097": {
"content": "<extra_id_2>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32098": {
"content": "<extra_id_1>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"32099": {
"content": "<extra_id_0>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>"
],
"clean_up_tokenization_spaces": true,
"eos_token": "</s>",
"extra_ids": 100,
"legacy": true,
"model_max_length": 512,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "T5Tokenizer",
"unk_token": "<unk>"
}

View File

@ -0,0 +1,20 @@
{
"_class_name": "ChromaTransformer2DModel",
"_diffusers_version": "0.34.0.dev0",
"approximator_hidden_dim": 5120,
"approximator_in_factor": 16,
"approximator_layers": 5,
"attention_head_dim": 128,
"axes_dims_rope": [
16,
56,
56
],
"in_channels": 64,
"joint_attention_dim": 4096,
"num_attention_heads": 24,
"num_layers": 19,
"num_single_layers": 38,
"out_channels": null,
"patch_size": 1
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,37 @@
{
"_class_name": "AutoencoderKL",
"_diffusers_version": "0.34.0.dev0",
"act_fn": "silu",
"block_out_channels": [
128,
256,
512,
512
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
"force_upcast": true,
"in_channels": 3,
"latent_channels": 16,
"latents_mean": null,
"latents_std": null,
"layers_per_block": 2,
"mid_block_add_attention": true,
"norm_num_groups": 32,
"out_channels": 3,
"sample_size": 1024,
"scaling_factor": 0.3611,
"shift_factor": 0.1159,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
],
"use_post_quant_conv": false,
"use_quant_conv": false
}

View File

@ -107,6 +107,9 @@ def make_meta(fn, maxrank, rank_ratio):
elif shared.sd_model_type == "f1":
meta["model_spec.architecture"] = "flux-1-dev/lora"
meta["ss_base_model_version"] = "flux1"
elif shared.sd_model_type == "chroma":
meta["model_spec.architecture"] = "chroma/lora"
meta["ss_base_model_version"] = "chroma"
elif shared.sd_model_type == "sc":
meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
return meta

View File

@ -139,7 +139,7 @@ def load_network(name, network_on_disk) -> network.Network:
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
if shared.sd_model_type in ['f1', 'chroma']: # if kohya flux lora, convert state_dict
sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
assign_network_names_to_compvis_modules(shared.sd_model)
keys_failed_to_match = {}

View File

@ -546,7 +546,7 @@ def check_diffusers():
t_start = time.time()
if args.skip_all or args.skip_git or args.experimental:
return
sha = '8adc6003ba4dbf5b61bb4f1ce571e9e55e145a99' # diffusers commit hash
sha = '900653c814cfa431877ae46548fe496506bda2ad' # diffusers commit hash
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
minor = int(pkg.version.split('.')[1] if pkg is not None else 0)
cur = opts.get('diffusers_version', '') if minor > 0 else ''

View File

@ -101,7 +101,7 @@ def map_model_name(name: str):
return 'xl'
if name == 'sd3':
return 'v3'
if name == 'f1':
if name in ['f1', 'chroma']:
return 'fx'
return name

View File

@ -107,6 +107,9 @@ def make_meta(fn, maxrank, rank_ratio):
elif shared.sd_model_type == "f1":
meta["model_spec.architecture"] = "flux-1-dev/lora"
meta["ss_base_model_version"] = "flux1"
elif shared.sd_model_type == "chroma":
meta["model_spec.architecture"] = "chroma/lora"
meta["ss_base_model_version"] = "chroma"
elif shared.sd_model_type == "sc":
meta["model_spec.architecture"] = "stable-cascade-v1-prior/lora"
return meta

View File

@ -85,7 +85,7 @@ def load_safetensors(name, network_on_disk) -> Union[network.Network, None]:
net = network.Network(name, network_on_disk)
net.mtime = os.path.getmtime(network_on_disk.filename)
state_dict = sd_models.read_state_dict(network_on_disk.filename, what='network')
if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
if shared.sd_model_type in ['f1', 'chroma']: # if kohya flux lora, convert state_dict
state_dict = lora_convert._convert_kohya_flux_lora_to_diffusers(state_dict) or state_dict # pylint: disable=protected-access
if shared.sd_model_type == 'sd3': # if kohya flux lora, convert state_dict
try:

View File

@ -18,6 +18,7 @@ class SdVersion(enum.Enum):
SC = 5
F1 = 6
HV = 7
CHROMA = 8
class NetworkOnDisk:
@ -59,6 +60,8 @@ class NetworkOnDisk:
return 'f1'
if base.startswith("hunyuan_video"):
return 'hv'
if base.startswith("chroma"):
return 'chroma'
if arch.startswith("stable-diffusion-v1"):
return 'sd1'
@ -70,6 +73,8 @@ class NetworkOnDisk:
return 'f1'
if arch.startswith("hunyuan-video"):
return 'hv'
if arch.startswith("chroma"):
return 'chroma'
if "v1-5" in str(self.metadata.get('ss_sd_model_name', "")):
return 'sd1'
@ -79,6 +84,8 @@ class NetworkOnDisk:
return 'f1'
if 'xl' in self.name.lower():
return 'xl'
if 'chroma' in self.name.lower():
return 'chroma'
return ''

339
modules/model_chroma.py Normal file
View File

@ -0,0 +1,339 @@
import os
import json
import torch
import diffusers
import transformers
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download, auth_check
from modules import shared, errors, devices, modelloader, sd_models, sd_unet, model_te, model_quant, sd_hijack_te
debug = shared.log.trace if os.environ.get('SD_LOAD_DEBUG', None) is not None else lambda *args, **kwargs: None
def _load_quanto_quantization_map(repo_path, repo_name, component):
    """Return the parsed quantization_map.json for *component* ('transformer' or 'text_encoder').

    Looks for the map next to the local checkpoint first; falls back to
    downloading it from the hub repo when not present locally.
    """
    quantization_map = os.path.join(repo_path, component, "quantization_map.json")
    debug(f'Load model: type=Chroma quantization map="{quantization_map}" repo="{repo_name}" component="{component}"')
    if not os.path.exists(quantization_map):
        repo_id = sd_models.path_to_repo(repo_name)
        quantization_map = hf_hub_download(repo_id, subfolder=component, filename='quantization_map.json', cache_dir=shared.opts.diffusers_dir)
    with open(quantization_map, "r", encoding='utf8') as f:
        return json.load(f)


def load_chroma_quanto(checkpoint_info):
    """Load a Quanto-requantized Chroma transformer and T5 text encoder.

    Args:
        checkpoint_info: local repo path as ``str``, or a checkpoint-info
            object exposing ``.path`` and ``.name``.

    Returns:
        ``(transformer, text_encoder)`` tuple; either element is ``None``
        when its load failed (failures are logged, never raised).
    """
    transformer, text_encoder = None, None
    quanto = model_quant.load_quanto('Load model: type=Chroma')
    if isinstance(checkpoint_info, str):
        # plain path: use it as the display/repo name too, since a str has no .name attribute
        repo_path, repo_name = checkpoint_info, checkpoint_info
    else:
        repo_path, repo_name = checkpoint_info.path, checkpoint_info.name
    try:
        quantization_map = _load_quanto_quantization_map(repo_path, repo_name, "transformer")
        state_dict = load_file(os.path.join(repo_path, "transformer", "diffusion_pytorch_model.safetensors"))
        dtype = state_dict['context_embedder.bias'].dtype  # probe checkpoint dtype from a known weight
        with torch.device("meta"):  # build skeleton without allocating real weights
            transformer = diffusers.ChromaTransformer2DModel.from_config(os.path.join(repo_path, "transformer", "config.json")).to(dtype=dtype)
        quanto.requantize(transformer, state_dict, quantization_map, device=torch.device("cpu"))
        if shared.opts.diffusers_eval:
            transformer.eval()
        transformer_dtype = transformer.dtype
        if transformer_dtype != devices.dtype:
            try:
                transformer = transformer.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=Chroma Failed to cast transformer to {devices.dtype}, set dtype to {transformer_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=Chroma failed to load Quanto transformer: {e}")
        # `debug` itself is always truthy (function or lambda); gate on the actual env flag
        if os.environ.get('SD_LOAD_DEBUG', None) is not None:
            errors.display(e, 'Chroma Quanto:')
    try:
        quantization_map = _load_quanto_quantization_map(repo_path, repo_name, "text_encoder")
        with open(os.path.join(repo_path, "text_encoder", "config.json"), encoding='utf8') as f:
            t5_config = transformers.T5Config(**json.load(f))
        state_dict = load_file(os.path.join(repo_path, "text_encoder", "model.safetensors"))
        dtype = state_dict['encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight'].dtype  # probe checkpoint dtype
        with torch.device("meta"):
            text_encoder = transformers.T5EncoderModel(t5_config).to(dtype=dtype)
        quanto.requantize(text_encoder, state_dict, quantization_map, device=torch.device("cpu"))
        if shared.opts.diffusers_eval:
            text_encoder.eval()
        text_encoder_dtype = text_encoder.dtype
        if text_encoder_dtype != devices.dtype:
            try:
                text_encoder = text_encoder.to(dtype=devices.dtype)
            except Exception:
                shared.log.error(f"Load model: type=Chroma Failed to cast text encoder to {devices.dtype}, set dtype to {text_encoder_dtype}")
    except Exception as e:
        shared.log.error(f"Load model: type=Chroma failed to load Quanto text encoder: {e}")
        if os.environ.get('SD_LOAD_DEBUG', None) is not None:
            errors.display(e, 'Chroma Quanto:')
    return transformer, text_encoder
def load_chroma_bnb(checkpoint_info, diffusers_load_config): # pylint: disable=unused-argument
    """Load the Chroma transformer from a single-file checkpoint, optionally BnB-quantized.

    Args:
        checkpoint_info: local checkpoint path as ``str``, or a checkpoint-info
            object exposing ``.path``.
        diffusers_load_config: extra kwargs forwarded to ``from_single_file``.

    Returns:
        ``(transformer, text_encoder)`` — text_encoder is always ``None`` here;
        both are ``None`` on failure (failures are logged, never raised).
    """
    transformer, text_encoder = None, None
    repo_path = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path
    model_quant.load_bnb('Load model: type=Chroma')
    quant = model_quant.get_quant(repo_path)
    try:
        # we ignore the distilled guidance layer because it degrades quality too much
        # see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
        if quant == 'fp8':
            quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["distilled_guidance_layer"], bnb_4bit_compute_dtype=devices.dtype)
        elif quant in ('fp4', 'nf4'):
            # fp4/nf4 differ only in the 4-bit quant type
            quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True, llm_int8_skip_modules=["distilled_guidance_layer"], bnb_4bit_compute_dtype=devices.dtype, bnb_4bit_quant_type=quant)
        else:
            quantization_config = None
        if quantization_config is not None:
            debug(f'Quantization: {quantization_config}')
            transformer = diffusers.ChromaTransformer2DModel.from_single_file(repo_path, **diffusers_load_config, quantization_config=quantization_config)
        else:
            transformer = diffusers.ChromaTransformer2DModel.from_single_file(repo_path, **diffusers_load_config)
    except Exception as e:
        shared.log.error(f"Load model: type=Chroma failed to load BnB transformer: {e}")
        transformer, text_encoder = None, None
        # `debug` itself is always truthy (function or lambda); gate on the actual env flag
        if os.environ.get('SD_LOAD_DEBUG', None) is not None:
            errors.display(e, 'Chroma:')
    return transformer, text_encoder
def load_quants(kwargs, pretrained_model_name_or_path, cache_dir, allow_quant):
    """Populate pipeline `kwargs` with quantized transformer/text-encoder components.

    Only fills a slot when it is not already pre-loaded in `kwargs`.
    Errors are logged and swallowed so the caller can fall back to a plain load.

    Args:
        kwargs: partial pipeline component dict, mutated and returned.
        pretrained_model_name_or_path: local file path or hub repo id.
        cache_dir: huggingface cache directory.
        allow_quant: whether load-time quantization is permitted.

    Returns:
        The (possibly extended) `kwargs` dict.
    """
    try:
        if 'transformer' not in kwargs and model_quant.check_nunchaku('Transformer'):
            # raised deliberately so the shared error path below logs it
            raise NotImplementedError('Nunchaku does not support Chroma Model yet. See https://github.com/mit-han-lab/nunchaku/issues/167')
        elif 'transformer' not in kwargs and model_quant.check_quant('Transformer'):
            quant_args = model_quant.create_config(allow=allow_quant, module='Transformer')
            if quant_args:
                if os.path.isfile(pretrained_model_name_or_path):
                    kwargs['transformer'] = diffusers.ChromaTransformer2DModel.from_single_file(pretrained_model_name_or_path, cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
                else:
                    kwargs['transformer'] = diffusers.ChromaTransformer2DModel.from_pretrained(pretrained_model_name_or_path, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
        if 'text_encoder' not in kwargs and model_quant.check_nunchaku('TE'):
            import nunchaku
            nunchaku_precision = nunchaku.utils.get_precision()
            # chroma reuses the flux t5-xxl text encoder, so the flux awq weights apply
            nunchaku_repo = 'mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors'
            shared.log.debug(f'Load module: quant=Nunchaku module=t5 repo="{nunchaku_repo}" precision={nunchaku_precision}')
            kwargs['text_encoder'] = nunchaku.NunchakuT5EncoderModel.from_pretrained(nunchaku_repo, torch_dtype=devices.dtype)
        elif 'text_encoder' not in kwargs and model_quant.check_quant('TE'):
            quant_args = model_quant.create_config(allow=allow_quant, module='TE')
            if quant_args:
                if os.path.isfile(pretrained_model_name_or_path):
                    kwargs['text_encoder'] = transformers.T5EncoderModel.from_single_file(pretrained_model_name_or_path, cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
                else:
                    kwargs['text_encoder'] = transformers.T5EncoderModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", cache_dir=cache_dir, torch_dtype=devices.dtype, **quant_args)
    except Exception as e:
        shared.log.error(f'Quantization: {e}')
        errors.display(e, 'Quantization:')
    return kwargs
def load_transformer(file_path): # triggered by opts.sd_unet change
    """Load a standalone Chroma transformer (UNet override) from a local file.

    Handles prequantized checkpoints (gguf / quanto qint / bnb fp8-fp4-nf4),
    on-the-fly quantization (bnb, then torchao), and a plain load fallback.

    Args:
        file_path: path to the transformer checkpoint file.

    Returns:
        Loaded ChromaTransformer2DModel, or None on failure (in which case
        `shared.opts.sd_unet` is reset to 'Default').
    """
    if file_path is None or not os.path.exists(file_path):
        return None
    transformer = None
    quant = model_quant.get_quant(file_path)  # detected pre-quantization type, e.g. 'gguf', 'nf4', 'qint8'
    diffusers_load_config = {
        "low_cpu_mem_usage": True,
        "torch_dtype": devices.dtype,
        "cache_dir": shared.opts.hfcache_dir,
    }
    if quant is not None and quant != 'none':
        # checkpoint is already quantized: dispatch to the matching loader
        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} prequant={quant} dtype={devices.dtype}')
        if 'gguf' in file_path.lower():
            from modules import ggml
            _transformer = ggml.load_gguf(file_path, cls=diffusers.ChromaTransformer2DModel, compute_dtype=devices.dtype)
            if _transformer is not None:
                transformer = _transformer
        elif quant == 'qint8' or quant == 'qint4':
            _transformer, _text_encoder = load_chroma_quanto(file_path)
            if _transformer is not None:
                transformer = _transformer
        elif quant == 'fp8' or quant == 'fp4' or quant == 'nf4':
            _transformer, _text_encoder = load_chroma_bnb(file_path, diffusers_load_config)
            if _transformer is not None:
                transformer = _transformer
        elif 'nf4' in quant: # TODO chroma: loader for civitai nf4 models
            # NOTE(review): unreachable when quant is exactly 'nf4' (handled by the branch
            # above); only hit for composite quant strings containing 'nf4' — confirm intent
            from modules.model_chroma_nf4 import load_chroma_nf4
            _transformer, _text_encoder = load_chroma_nf4(file_path, prequantized=True)
            if _transformer is not None:
                transformer = _transformer
    else:
        # unquantized checkpoint: attempt load-time bnb quantization first
        quant_args = model_quant.create_bnb_config({})
        if quant_args:
            shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=bnb dtype={devices.dtype}')
            from modules.model_chroma_nf4 import load_chroma_nf4
            transformer, _text_encoder = load_chroma_nf4(file_path, prequantized=False)
            if transformer is not None:
                return transformer
        # then torchao (or whichever quantizer model_quant is configured for)
        quant_args = model_quant.create_config(module='Transformer')
        if quant_args:
            shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=torchao dtype={devices.dtype}')
            transformer = diffusers.ChromaTransformer2DModel.from_single_file(file_path, **diffusers_load_config, **quant_args)
            if transformer is not None:
                return transformer
        # final fallback: plain unquantized load
        shared.log.info(f'Load module: type=UNet/Transformer file="{file_path}" offload={shared.opts.diffusers_offload_mode} quant=none dtype={devices.dtype}')
        # TODO chroma transformer from-single-file with quant
        # shared.log.warning('Load module: type=UNet/Transformer does not support load-time quantization')
        transformer = diffusers.ChromaTransformer2DModel.from_single_file(file_path, **diffusers_load_config)
    if transformer is None:
        shared.log.error('Failed to load UNet model')
        shared.opts.sd_unet = 'Default'
    return transformer
def load_chroma(checkpoint_info, diffusers_load_config): # triggered by opts.sd_checkpoint change
    """Load a complete Chroma pipeline from a checkpoint.

    Resolves optional user overrides (UNet/transformer, T5 text encoder, VAE),
    loads prequantized components (nf4/quanto) when detected, then assembles a
    `diffusers.ChromaPipeline` from the pre-loaded parts.

    Args:
        checkpoint_info: checkpoint descriptor exposing `.path` and `.name`.
        diffusers_load_config: kwargs forwarded to the diffusers loaders.

    Returns:
        Initialized ChromaPipeline, or None when the hub repo is inaccessible
        and no local file exists.
    """
    fn = checkpoint_info.path
    repo_id = sd_models.path_to_repo(checkpoint_info.name)
    login = modelloader.hf_login()
    try:
        auth_check(repo_id)
    except Exception as e:
        repo_id = None # hub repo not accessible: fall back to the local file only
        if not os.path.exists(fn):
            shared.log.error(f'Load model: repo="{repo_id}" login={login} {e}')
            return None
    prequantized = model_quant.get_quant(checkpoint_info.path)
    shared.log.debug(f'Load model: type=Chroma model="{checkpoint_info.name}" repo={repo_id or "none"} unet="{shared.opts.sd_unet}" te="{shared.opts.sd_text_encoder}" vae="{shared.opts.sd_vae}" quant={prequantized} offload={shared.opts.diffusers_offload_mode} dtype={devices.dtype}')
    debug(f'Load model: type=Chroma config={diffusers_load_config}')
    transformer = None
    text_encoder = None
    vae = None
    # unload current model
    sd_models.unload_model_weights()
    shared.sd_model = None
    devices.torch_gc(force=True)
    if shared.opts.teacache_enabled:
        from modules import teacache
        shared.log.debug(f'Transformers cache: type=teacache patch=forward cls={diffusers.ChromaTransformer2DModel.__name__}')
        diffusers.ChromaTransformer2DModel.forward = teacache.teacache_chroma_forward # patch must be done before transformer is loaded
    # load overrides if any
    if shared.opts.sd_unet != 'Default':
        try:
            debug(f'Load model: type=Chroma unet="{shared.opts.sd_unet}"')
            transformer = load_transformer(sd_unet.unet_dict[shared.opts.sd_unet])
            if transformer is None:
                # fix: record the *failing* unet name before resetting the option,
                # otherwise failed_unet would record the literal 'Default'
                sd_unet.failed_unet.append(shared.opts.sd_unet)
                shared.opts.sd_unet = 'Default'
        except Exception as e:
            shared.log.error(f"Load model: type=Chroma failed to load UNet: {e}")
            shared.opts.sd_unet = 'Default'
            if debug:
                errors.display(e, 'Chroma UNet:')
    if shared.opts.sd_text_encoder != 'Default':
        try:
            debug(f'Load model: type=Chroma te="{shared.opts.sd_text_encoder}"')
            from modules.model_te import load_t5
            text_encoder = load_t5(name=shared.opts.sd_text_encoder, cache_dir=shared.opts.diffusers_dir)
        except Exception as e:
            shared.log.error(f"Load model: type=Chroma failed to load T5: {e}")
            shared.opts.sd_text_encoder = 'Default'
            if debug:
                errors.display(e, 'Chroma T5:')
    if shared.opts.sd_vae != 'Default' and shared.opts.sd_vae != 'Automatic':
        try:
            debug(f'Load model: type=Chroma vae="{shared.opts.sd_vae}"')
            from modules import sd_vae
            # vae = sd_vae.load_vae_diffusers(None, sd_vae.vae_dict[shared.opts.sd_vae], 'override')
            vae_file = sd_vae.vae_dict[shared.opts.sd_vae]
            if os.path.exists(vae_file):
                vae_config = os.path.join('configs', 'chroma', 'vae', 'config.json')
                vae = diffusers.AutoencoderKL.from_single_file(vae_file, config=vae_config, **diffusers_load_config)
        except Exception as e:
            shared.log.error(f"Load model: type=Chroma failed to load VAE: {e}")
            shared.opts.sd_vae = 'Default'
            if debug:
                errors.display(e, 'Chroma VAE:')
    # load quantized components if any
    if prequantized == 'nf4':
        try:
            from modules.model_chroma_nf4 import load_chroma_nf4
            _transformer, _text_encoder = load_chroma_nf4(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder = _text_encoder
        except Exception as e:
            shared.log.error(f"Load model: type=Chroma failed to load NF4 components: {e}")
            if debug:
                errors.display(e, 'Chroma NF4:')
    if prequantized == 'qint8' or prequantized == 'qint4':
        try:
            _transformer, _text_encoder = load_chroma_quanto(checkpoint_info)
            if _transformer is not None:
                transformer = _transformer
            if _text_encoder is not None:
                text_encoder = _text_encoder
        except Exception as e:
            shared.log.error(f"Load model: type=Chroma failed to load Quanto components: {e}")
            if debug:
                errors.display(e, 'Chroma Quanto:')
    # initialize pipeline with pre-loaded components
    kwargs = {}
    if transformer is not None:
        kwargs['transformer'] = transformer
        sd_unet.loaded_unet = shared.opts.sd_unet
    if text_encoder is not None:
        kwargs['text_encoder'] = text_encoder
        model_te.loaded_te = shared.opts.sd_text_encoder
    if vae is not None:
        kwargs['vae'] = vae
    # Todo: atm only ChromaPipeline is implemented in diffusers.
    # Need to add ChromaFillPipeline, ChromaControlPipeline, ChromaImg2ImgPipeline etc when available.
    # Chroma will support inpainting *after* its training has finished:
    # https://huggingface.co/lodestones/Chroma/discussions/28#6826dd2ed86f53ff983add5c
    cls = diffusers.ChromaPipeline
    shared.log.debug(f'Load model: type=Chroma cls={cls.__name__} preloaded={list(kwargs)} revision={diffusers_load_config.get("revision", None)}')
    for c in kwargs:
        # recast quantized/gguf components accidentally left in fp32
        if getattr(kwargs[c], 'quantization_method', None) is not None or getattr(kwargs[c], 'gguf', None) is not None:
            shared.log.debug(f'Load model: type=Chroma component={c} dtype={kwargs[c].dtype} quant={getattr(kwargs[c], "quantization_method", None) or getattr(kwargs[c], "gguf", None)}')
            if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
                try:
                    kwargs[c] = kwargs[c].to(dtype=devices.dtype)
                    shared.log.warning(f'Load model: type=Chroma component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype} recast')
                except Exception:
                    pass # best-effort: keep fp32 component if the cast is not possible
    allow_quant = 'gguf' not in (sd_unet.loaded_unet or '') and (prequantized is None or prequantized == 'none')
    if (fn is None) or (not os.path.exists(fn) or os.path.isdir(fn)):
        kwargs = load_quants(kwargs, repo_id or fn, cache_dir=shared.opts.diffusers_dir, allow_quant=allow_quant)
        # kwargs = model_quant.create_config(kwargs, allow_quant)
    if fn is not None and fn.endswith('.safetensors') and os.path.isfile(fn): # fix: guard against fn=None before endswith
        pipe = diffusers.ChromaPipeline.from_single_file(fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
    else:
        pipe = cls.from_pretrained(repo_id or fn, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
    if shared.opts.teacache_enabled and model_quant.check_nunchaku('Transformer'):
        from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
        apply_cache_on_pipe(pipe, residual_diff_threshold=0.12)
    # release memory
    transformer = None
    text_encoder = None
    vae = None
    for k in kwargs.keys():
        kwargs[k] = None
    sd_hijack_te.init_hijack(pipe)
    devices.torch_gc(force=True)
    return pipe

202
modules/model_chroma_nf4.py Normal file
View File

@ -0,0 +1,202 @@
"""
Copied from: https://github.com/huggingface/diffusers/issues/9165
+ adjusted for Chroma by skipping distilled guidance layer
"""
import os
import torch
import torch.nn as nn
from transformers.quantizers.quantizers_utils import get_module_from_name
from huggingface_hub import hf_hub_download
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from diffusers.loaders.single_file_utils import convert_chroma_transformer_checkpoint_to_diffusers
import safetensors.torch
from modules import shared, devices, model_quant
debug = os.environ.get('SD_LOAD_DEBUG', None) is not None
def _replace_with_bnb_linear(
    model,
    method="nf4",
    has_been_replaced=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Recursively replaces every `nn.Linear` in *model* with the matching
    bitsandbytes quantized linear layer, skipping the distilled guidance layer.

    Args:
        model: module tree converted in place.
        method: 'llm_int8' for 8-bit Linear8bitLt, anything else for 4-bit nf4 Linear4bit.
        has_been_replaced: recursion accumulator.

    Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
    """
    bnb = model_quant.load_bnb('Load model: type=Chroma')
    for name, module in model.named_children():
        # we ignore the distilled guidance layer because it degrades quality too much
        # see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
        if "distilled_guidance_layer" in name:
            continue
        if isinstance(module, nn.Linear):
            with init_empty_weights():
                in_features = module.in_features
                out_features = module.out_features
                if method == "llm_int8":
                    model._modules[name] = bnb.nn.Linear8bitLt( # pylint: disable=protected-access
                        in_features,
                        out_features,
                        module.bias is not None,
                        has_fp16_weights=False,
                        threshold=6.0,
                    )
                    has_been_replaced = True
                else:
                    model._modules[name] = bnb.nn.Linear4bit( # pylint: disable=protected-access
                        in_features,
                        out_features,
                        module.bias is not None,
                        compute_dtype=devices.dtype,
                        compress_statistics=False,
                        quant_type="nf4",
                    )
                    has_been_replaced = True
                # Store the module class in case we need to transpose the weight later
                model._modules[name].source_cls = type(module) # pylint: disable=protected-access
                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False) # pylint: disable=protected-access
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bnb_linear(
                module,
                method=method, # fix: propagate the requested method; previously the recursion silently reset nested modules to the 'nf4' default
                has_been_replaced=has_been_replaced,
            )
    # Remove the last key for recursion
    return model, has_been_replaced
def check_quantized_param(
    model,
    param_name: str,
) -> bool:
    """Return True when *param_name* must be loaded through the quantized path."""
    bnb = model_quant.load_bnb('Load model: type=Chroma')
    module, tensor_name = get_module_from_name(model, param_name)
    params = module._parameters # pylint: disable=protected-access
    if isinstance(params.get(tensor_name, None), bnb.nn.Params4bit):
        # Add here check for loaded components' dtypes once serialization is implemented
        return True
    # bias of a Linear4bit could be set by accelerate's regular set_module_tensor_to_device(),
    # but that would wrongly use the uninitialized weight — route it through the quantized path too
    return isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias"
def create_quantized_param(
    model,
    param_value: "torch.Tensor",
    param_name: str,
    target_device: "torch.device",
    state_dict=None,
    unexpected_keys=None,
    pre_quantized=False
):
    """Set a single parameter on a bnb-quantized model, quantizing on the fly when needed.

    Args:
        model: model whose parameter is being populated.
        param_value: tensor value from the checkpoint (may be None for bias).
        param_name: fully qualified parameter name, e.g. 'blocks.0.attn.to_q.weight'.
        target_device: device the final parameter is placed on.
        state_dict: full checkpoint state dict; required when pre_quantized is True
            to locate the `quant_state.bitsandbytes__*` companion tensors.
        unexpected_keys: optional list; consumed quant-state keys are removed from it.
        pre_quantized: True when the checkpoint already stores bnb 4-bit data.

    Raises:
        ValueError: when the parameter does not exist, is not a Params4bit slot,
            is a meta tensor with no value supplied, or quant-state keys are missing.
    """
    bnb = model_quant.load_bnb('Load model: type=Chroma')
    module, tensor_name = get_module_from_name(model, param_name)
    if tensor_name not in module._parameters: # pylint: disable=protected-access
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    old_value = getattr(module, tensor_name)
    if tensor_name == "bias":
        # bias stays a regular (unquantized) Parameter even on quantized layers
        if param_value is None:
            new_value = old_value.to(target_device)
        else:
            new_value = param_value.to(target_device)
        new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
        module._parameters[tensor_name] = new_value # pylint: disable=protected-access
        return
    if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): # pylint: disable=protected-access
        raise ValueError("this function only loads `Linear4bit components`")
    if (
        old_value.device == torch.device("meta")
        and target_device not in ["meta", torch.device("meta")]
        and param_value is None
    ):
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
    if pre_quantized:
        # checkpoint already holds packed 4-bit data plus its quantization state tensors
        if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (param_name + ".quant_state.bitsandbytes__nf4" not in state_dict):
            raise ValueError(f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components.")
        quantized_stats = {}
        for k, v in state_dict.items():
            # `startswith` to counter for edge cases where `param_name`
            # substring can be present in multiple places in the `state_dict`
            if param_name + "." in k and k.startswith(param_name):
                quantized_stats[k] = v
                if unexpected_keys is not None and k in unexpected_keys:
                    unexpected_keys.remove(k)
        new_value = bnb.nn.Params4bit.from_prequantized(
            data=param_value,
            quantized_stats=quantized_stats,
            requires_grad=False,
            device=target_device,
        )
    else:
        # quantize on the fly: Params4bit packs the fp tensor when moved to target_device
        new_value = param_value.to("cpu")
        kwargs = old_value.__dict__
        new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
    module._parameters[tensor_name] = new_value # pylint: disable=protected-access
def load_chroma_nf4(checkpoint_info, prequantized: bool = True):
    """Load a Chroma transformer and quantize it to bnb NF4 while streaming the state dict.

    Args:
        checkpoint_info: repo/file path string, or an object exposing `.path`.
        prequantized: True when the checkpoint already stores bnb 4-bit tensors.

    Returns:
        (transformer, text_encoder) tuple; `text_encoder` is always None here,
        and both are None when loading fails.
    """
    transformer = None
    text_encoder = None
    repo_path = checkpoint_info if isinstance(checkpoint_info, str) else checkpoint_info.path
    # resolve the checkpoint: local file, local diffusers folder, or hub download
    if os.path.exists(repo_path) and os.path.isfile(repo_path):
        ckpt_path = repo_path
    elif os.path.exists(repo_path) and os.path.isdir(repo_path) and os.path.exists(os.path.join(repo_path, "diffusion_pytorch_model.safetensors")):
        ckpt_path = os.path.join(repo_path, "diffusion_pytorch_model.safetensors")
    else:
        ckpt_path = hf_hub_download(repo_path, filename="diffusion_pytorch_model.safetensors", cache_dir=shared.opts.diffusers_dir)
    original_state_dict = safetensors.torch.load_file(ckpt_path)
    try:
        converted_state_dict = convert_chroma_transformer_checkpoint_to_diffusers(original_state_dict)
    except Exception as e:
        shared.log.error(f"Load model: type=Chroma Failed to convert UNET: {e}")
        if debug:
            from modules import errors
            errors.display(e, 'Chroma convert:')
        converted_state_dict = original_state_dict # assume checkpoint is already in diffusers layout
    with init_empty_weights():
        # instantiate the transformer on the meta device; weights are filled below
        from diffusers import ChromaTransformer2DModel
        config = ChromaTransformer2DModel.load_config(os.path.join('configs', 'chroma'), subfolder="transformer")
        transformer = ChromaTransformer2DModel.from_config(config).to(devices.dtype)
        expected_state_dict_keys = list(transformer.state_dict().keys())
    _replace_with_bnb_linear(transformer, "nf4")
    try:
        for param_name, param in converted_state_dict.items():
            if param_name not in expected_state_dict_keys:
                continue
            is_param_float8_e4m3fn = hasattr(torch, "float8_e4m3fn") and param.dtype == torch.float8_e4m3fn
            if torch.is_floating_point(param) and not is_param_float8_e4m3fn:
                param = param.to(devices.dtype)
            # NOTE(review): device=0 assumes cuda:0 — confirm against devices.device for cpu/multi-gpu setups
            if not check_quantized_param(transformer, param_name):
                set_module_tensor_to_device(transformer, param_name, device=0, value=param)
            else:
                create_quantized_param(transformer, param, param_name, target_device=0, state_dict=original_state_dict, pre_quantized=prequantized)
    except Exception as e:
        transformer, text_encoder = None, None # fix: original `= None` raised TypeError on unpacking, masking the real error
        shared.log.error(f"Load model: type=Chroma failed to load UNET: {e}")
        if debug:
            from modules import errors
            errors.display(e, 'Chroma:')
    del original_state_dict # free the raw checkpoint before returning
    devices.torch_gc(force=True)
    return transformer, text_encoder

View File

@ -427,8 +427,12 @@ def optimum_quanto_model(model, op=None, sd_model=None, weights=None, activation
from modules import devices, shared
quanto = load_quanto('Quantize model: type=Optimum Quanto')
global quant_last_model_name, quant_last_model_device # pylint: disable=global-statement
if sd_model is not None and "Flux" in sd_model.__class__.__name__: # LayerNorm is not supported
if sd_model is not None and "Flux" in sd_model.__class__.__name__ or "Chroma" in sd_model.__class__.__name__: # LayerNorm is not supported
exclude_list = ["transformer_blocks.*.norm1.norm", "transformer_blocks.*.norm2", "transformer_blocks.*.norm1_context.norm", "transformer_blocks.*.norm2_context", "single_transformer_blocks.*.norm.norm", "norm_out.norm"]
if "Chroma" in sd_model.__class__.__name__:
# we ignore the distilled guidance layer because it degrades quality too much
# see: https://github.com/huggingface/diffusers/pull/11698#issuecomment-2969717180 for more details
exclude_list.append("distilled_guidance_layer.*")
else:
exclude_list = None
weights = getattr(quanto, weights) if weights is not None else getattr(quanto, shared.opts.optimum_quanto_weights_type)

View File

@ -27,6 +27,8 @@ def get_model_type(pipe):
model_type = 'sc'
elif "AuraFlow" in name:
model_type = 'auraflow'
elif 'Chroma' in name:
model_type = 'chroma'
elif "Flux" in name or "Flex1" in name or "Flex2" in name:
model_type = 'f1'
elif "Lumina2" in name:

View File

@ -150,6 +150,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
'StableDiffusion' in model.__class__.__name__ or
'StableCascade' in model.__class__.__name__ or
'Flux' in model.__class__.__name__ or
'Chroma' in model.__class__.__name__ or
'HiDreamImagePipeline' in model.__class__.__name__ # hidream-e1 has different embeds
):
try:
@ -174,15 +175,14 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
args['prompt_embeds_llama3'] = prompt_embeds[1]
elif hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None:
args['prompt_embeds'] = prompt_parser_diffusers.embedder('prompt_embeds')
if prompt_parser_diffusers.embedder is not None:
if 'StableCascade' in model.__class__.__name__:
args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'StableDiffusion3' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'Flux' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
if 'StableCascade' in model.__class__.__name__:
args['prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('positive_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'StableDiffusion3' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
elif 'Flux' in model.__class__.__name__:
args['pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('positive_pooleds')
else:
args['prompt'] = prompts
if 'negative_prompt' in possible:
@ -193,13 +193,12 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t
args['negative_prompt_embeds_llama3'] = negative_prompt_embeds[1]
elif hasattr(model, 'text_encoder') and hasattr(model, 'tokenizer') and 'negative_prompt_embeds' in possible and prompt_parser_diffusers.embedder is not None:
args['negative_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_prompt_embeds')
if prompt_parser_diffusers.embedder is not None:
if 'StableCascade' in model.__class__.__name__:
args['negative_prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('negative_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
elif 'StableDiffusion3' in model.__class__.__name__:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
if 'StableCascade' in model.__class__.__name__:
args['negative_prompt_embeds_pooled'] = prompt_parser_diffusers.embedder('negative_pooleds').unsqueeze(0)
elif 'XL' in model.__class__.__name__:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
elif 'StableDiffusion3' in model.__class__.__name__:
args['negative_pooled_prompt_embeds'] = prompt_parser_diffusers.embedder('negative_pooleds')
else:
if 'PixArtSigmaPipeline' in model.__class__.__name__: # pixart-sigma pipeline throws list-of-list for negative prompt
args['negative_prompt'] = negative_prompts[0]

View File

@ -7,7 +7,7 @@ import torch
from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
from transformers import PreTrainedTokenizer
from modules import shared, prompt_parser, devices, sd_models
from modules.prompt_parser_xhinker import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl_2p, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1
from modules.prompt_parser_xhinker import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl_2p, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1, get_weighted_text_embeddings_chroma
debug_enabled = os.environ.get('SD_PROMPT_DEBUG', None)
debug = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
@ -27,6 +27,7 @@ def prompt_compatible(pipe = None):
'DemoFusion' not in pipe.__class__.__name__ and
'StableCascade' not in pipe.__class__.__name__ and
'Flux' not in pipe.__class__.__name__ and
'Chroma' not in pipe.__class__.__name__ and
'HiDreamImage' not in pipe.__class__.__name__
):
shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
@ -510,6 +511,10 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=1)
return prompt_embeds, pooled_prompt_embeds, None, None # no negative support
if "Chroma" in pipe.__class__.__name__: # does not use clip and has no pooled embeds
prompt_embeds, _, _, negative_prompt_embeds, _, _ = pipe.encode_prompt(prompt=prompt, negative_prompt=neg_prompt, device=device, num_images_per_prompt=1)
return prompt_embeds, None, negative_prompt_embeds, None
if "HiDreamImage" in pipe.__class__.__name__: # clip is only used for the pooled embeds
prompt_embeds_t5, negative_prompt_embeds_t5, prompt_embeds_llama3, negative_prompt_embeds_llama3, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
prompt=prompt, prompt_2=prompt_2, prompt_3=prompt_3, prompt_4=prompt_4,
@ -662,6 +667,8 @@ def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", cl
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sd3(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, use_t5_encoder=bool(pipe.text_encoder_3))
elif 'Flux' in pipe.__class__.__name__:
prompt_embed, positive_pooled = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt, prompt2=prompt_2, device=devices.device)
elif 'Chroma' in pipe.__class__.__name__:
prompt_embed, negative_embed = get_weighted_text_embeddings_chroma(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, device=devices.device)
elif 'XL' in pipe.__class__.__name__:
prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sdxl_2p(pipe=pipe, prompt=prompt, prompt_2=prompt_2, neg_prompt=neg_prompt, neg_prompt_2=neg_prompt_2)
else:

View File

@ -17,11 +17,13 @@
## -----------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from transformers import CLIPTokenizer, T5Tokenizer
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from diffusers import StableDiffusion3Pipeline
from diffusers import FluxPipeline
from diffusers import ChromaPipeline
from modules.prompt_parser import parse_prompt_attention # use built-in A1111 parser
@ -1424,3 +1426,75 @@ def get_weighted_text_embeddings_flux1(
t5_prompt_embeds = t5_prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device)
return t5_prompt_embeds, prompt_embeds
def get_weighted_text_embeddings_chroma(
    pipe: ChromaPipeline,
    prompt: str = "",
    neg_prompt: str = "",
    device=None
):
    """
    This function can process long prompt with weights for Chroma model

    Args:
        pipe (ChromaPipeline)
        prompt (str)
        neg_prompt (str)
        device (torch.device, optional): Device to run the embeddings on.
    Returns:
        prompt_embeds (T5 prompt embeds as torch.Tensor)
        neg_prompt_embeds (T5 prompt embeds as torch.Tensor)
    """
    if device is None:
        device = pipe.text_encoder.device

    def _encode_weighted(text):
        # tokenize with attention weights, run through the T5 encoder, then scale weighted tokens
        tokens, weights = get_prompts_tokens_with_weights_t5(pipe.tokenizer, text)
        token_tensor = torch.tensor([tokens], dtype=torch.long)
        embeds = pipe.text_encoder(token_tensor.to(device))[0].squeeze(0)
        embeds = embeds.to(device=device)
        for z, w in enumerate(weights):
            if w != 1.0:
                embeds[z] = embeds[z] * w
        embeds = embeds.unsqueeze(0)
        return embeds.to(dtype=pipe.text_encoder.dtype, device=device)

    # positive and negative prompts go through identical processing
    t5_prompt_embeds = _encode_weighted(prompt)
    t5_neg_prompt_embeds = _encode_weighted(neg_prompt)

    def _pad_to_same_size(prompt_embeds_a, prompt_embeds_b):
        # zero-pad the shorter sequence along dim=1 so both have equal token length
        size_a = prompt_embeds_a.size(1)
        size_b = prompt_embeds_b.size(1)
        if size_a < size_b:
            prompt_embeds_a = F.pad(prompt_embeds_a, (0, 0, 0, size_b - size_a))
        elif size_b < size_a:
            prompt_embeds_b = F.pad(prompt_embeds_b, (0, 0, 0, size_a - size_b))
        return prompt_embeds_a, prompt_embeds_b

    # chroma needs positive and negative prompt embeddings to have the same length (for now)
    return _pad_to_same_size(t5_prompt_embeds, t5_neg_prompt_embeds)

View File

@ -92,6 +92,8 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
guess = 'Stable Diffusion 3'
if 'hidream' in f.lower():
guess = 'HiDream'
if 'chroma' in f.lower():
guess = 'Chroma'
if 'flux' in f.lower() or 'flex.1' in f.lower() or 'lodestones' in f.lower():
guess = 'FLUX'
if size > 11000 and size < 16000:

View File

@ -329,6 +329,9 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op='
elif model_type in ['FLEX']:
from modules.model_flex import load_flex
sd_model = load_flex(checkpoint_info, diffusers_load_config)
elif model_type in ['Chroma']:
from modules.model_chroma import load_chroma
sd_model = load_chroma(checkpoint_info, diffusers_load_config)
elif model_type in ['Lumina 2']:
from modules.model_lumina import load_lumina2
sd_model = load_lumina2(checkpoint_info, diffusers_load_config)

View File

@ -11,7 +11,7 @@ from modules.timer import process as process_timer
debug = os.environ.get('SD_MOVE_DEBUG', None) is not None
debug_move = log.trace if debug else lambda *args, **kwargs: None
offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4']
offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4', 'chroma']
offload_post = ['h1']
offload_hook_instance = None
balanced_offload_exclude = ['OmniGenPipeline', 'CogView4Pipeline']

View File

@ -12,7 +12,7 @@ samplers = all_samplers
samplers_for_img2img = all_samplers
samplers_map = {}
loaded_config = None
flow_models = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana', 'CogView4', 'HiDream']
flow_models = ['Flux', 'StableDiffusion3', 'Lumina', 'AuraFlow', 'Sana', 'CogView4', 'HiDream', 'Chroma']
flow_models += ['Hunyuan', 'LTX', 'Mochi']

View File

@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 }
flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1']
flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'chroma']
warned = False
queue_lock = threading.Lock()

View File

@ -12,6 +12,7 @@ hf_decode_endpoints = {
'sd': 'https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud',
'sdxl': 'https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud',
'f1': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
'chroma': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
'h1': 'https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud',
'hunyuanvideo': 'https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud',
}
@@ -19,6 +20,7 @@ hf_encode_endpoints = {
'sd': 'https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud',
'sdxl': 'https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud',
'f1': 'https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud',
'chroma': 'https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud',
}
dtypes = {
"float16": torch.float16,
@@ -48,7 +50,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
params = {}
try:
latent = latent_copy[i]
if model_type != 'f1':
if model_type not in ['f1', 'chroma']:
latent = latent.unsqueeze(0)
params = {
"input_tensor_type": "binary",
@@ -74,7 +76,7 @@ def remote_decode(latents: torch.Tensor, width: int = 0, height: int = 0, model_
params["output_type"] = "pt"
params["output_tensor_type"] = "binary"
headers["Accept"] = "tensor/binary"
if (model_type == 'f1' or model_type == 'h1') and (width > 0) and (height > 0):
if (model_type in ['f1', 'h1', 'chroma']) and (width > 0) and (height > 0):
params['width'] = width
params['height'] = height
if shared.sd_model.vae is not None and shared.sd_model.vae.config is not None:

View File

@@ -55,7 +55,7 @@ def get_model(model_type = 'decoder', variant = None):
cls = shared.sd_model_type
if cls in {'ldm', 'pixartalpha'}:
cls = 'sd'
elif cls in {'h1', 'lumina2'}:
elif cls in {'h1', 'lumina2', 'chroma'}:
cls = 'f1'
elif cls == 'pixartsigma':
cls = 'sdxl'

View File

@@ -27,6 +27,7 @@ pipelines = {
'DeepFloyd IF': getattr(diffusers, 'IFPipeline', None),
'FLUX': getattr(diffusers, 'FluxPipeline', None),
'FLEX': getattr(diffusers, 'AutoPipelineForText2Image', None),
'Chroma': getattr(diffusers, 'ChromaPipeline', None),
'Sana': getattr(diffusers, 'SanaPipeline', None),
'Lumina-Next': getattr(diffusers, 'LuminaText2ImgPipeline', None),
'Lumina 2': getattr(diffusers, 'Lumina2Pipeline', None),

View File

@@ -4,9 +4,10 @@ from .teacache_lumina2 import teacache_lumina2_forward
from .teacache_ltx import teacache_ltx_forward
from .teacache_mochi import teacache_mochi_forward
from .teacache_cogvideox import teacache_cog_forward
from .teacache_chroma import teacache_chroma_forward
supported_models = ['Flux', 'CogVideoX', 'Mochi', 'LTX', 'HiDream', 'Lumina2']
supported_models = ['Flux', 'Chroma', 'CogVideoX', 'Mochi', 'LTX', 'HiDream', 'Lumina2']
def apply_teacache(p):

View File

@@ -0,0 +1,325 @@
from typing import Any, Dict, Optional, Union
import torch
import numpy as np
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _chroma_checkpoint_forward(module):
    """Wrap *module* so it can be invoked by torch.utils.checkpoint, which passes positional arguments only."""
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward


def _chroma_run_transformer_blocks(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    pooled_temb: torch.Tensor,
    image_rotary_emb,
    joint_attention_kwargs: Optional[Dict[str, Any]],
    controlnet_block_samples,
    controlnet_single_block_samples,
    controlnet_blocks_repeat: bool,
) -> torch.Tensor:
    """Run the double-stream and single-stream transformer stacks once.

    Returns the image-token hidden states (text tokens stripped), ready for the final
    norm/projection. Factored out so the TeaCache and non-TeaCache paths share one
    implementation instead of two verbatim copies that could silently diverge.
    """
    use_checkpointing = torch.is_grad_enabled() and self.gradient_checkpointing
    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
    num_single = len(self.single_transformer_blocks)
    num_double = len(self.transformer_blocks)
    # pooled_temb layout (per the slicing below):
    # [3*num_single (single blocks) | 6*num_double (img mod) | 6*num_double (txt mod) | 2 (final norm)]
    img_offset = 3 * num_single
    txt_offset = img_offset + 6 * num_double
    for index_block, block in enumerate(self.transformer_blocks):
        img_modulation = img_offset + 6 * index_block
        text_modulation = txt_offset + 6 * index_block
        temb = torch.cat(
            (
                pooled_temb[:, img_modulation : img_modulation + 6],
                pooled_temb[:, text_modulation : text_modulation + 6],
            ),
            dim=1,
        )
        if use_checkpointing:
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                _chroma_checkpoint_forward(block),
                hidden_states,
                encoder_hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # controlnet residual
        if controlnet_block_samples is not None:
            interval_control = int(np.ceil(num_double / len(controlnet_block_samples)))
            if controlnet_blocks_repeat:
                # For Xlabs ControlNet.
                hidden_states = (
                    hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                )
            else:
                hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
    for index_block, block in enumerate(self.single_transformer_blocks):
        start_idx = 3 * index_block
        temb = pooled_temb[:, start_idx : start_idx + 3]
        if use_checkpointing:
            hidden_states = torch.utils.checkpoint.checkpoint(
                _chroma_checkpoint_forward(block),
                hidden_states,
                temb,
                image_rotary_emb,
                **ckpt_kwargs,
            )
        else:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )
        # controlnet residual is added to the image tokens only
        if controlnet_single_block_samples is not None:
            interval_control = int(np.ceil(num_single / len(controlnet_single_block_samples)))
            hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                + controlnet_single_block_samples[index_block // interval_control]
            )
    # drop the text tokens, keep image tokens
    return hidden_states[:, encoder_hidden_states.shape[1] :, ...]


def teacache_chroma_forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_ids: torch.Tensor = None,
    txt_ids: torch.Tensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    controlnet_single_block_samples=None,
    return_dict: bool = True,
    controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    """
    TeaCache-enabled replacement for the [`ChromaTransformer2DModel`] forward method.

    When `self.enable_teacache` is set, the relative L1 change of the first block's
    modulated input is accumulated across denoising steps; if it stays below
    `self.rel_l1_thresh`, the transformer stacks are skipped and the cached residual
    from the last full pass is reused.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        timestep ( `torch.LongTensor`):
            Used to indicate denoising step.
        block_controlnet_hidden_states: (`list` of `torch.Tensor`):
            A list of tensors that if specified are added to the residuals of transformer blocks.
        joint_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.
    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    if joint_attention_kwargs is not None:
        joint_attention_kwargs = joint_attention_kwargs.copy()
        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
            )
    hidden_states = self.x_embedder(hidden_states)
    timestep = timestep.to(hidden_states.dtype) * 1000  # mirrors upstream forward: embedding uses the 0-1000 scale
    input_vec = self.time_text_embed(timestep)
    # distilled guidance layer emits the per-block modulation vectors (see layout note in the helper)
    pooled_temb = self.distilled_guidance_layer(input_vec)
    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
    if txt_ids.ndim == 3:
        logger.warning(
            "Passing `txt_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        txt_ids = txt_ids[0]
    if img_ids.ndim == 3:
        logger.warning(
            "Passing `img_ids` 3d torch.Tensor is deprecated."
            "Please remove the batch dimension and pass it as a 2d torch Tensor"
        )
        img_ids = img_ids[0]
    ids = torch.cat((txt_ids, img_ids), dim=0)
    image_rotary_emb = self.pos_embed(ids)
    if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
        ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
    if self.enable_teacache:
        inp = hidden_states.clone()
        input_vec_ = input_vec.clone()
        # NOTE(review): norm1 of the first double block is probed with the raw time embedding
        # (Flux-style TeaCache probe); confirm Chroma's norm1 accepts an emb of this shape.
        modulated_inp, _gate_msa, _shift_mlp, _scale_mlp, _gate_mlp = self.transformer_blocks[0].norm1(inp, emb=input_vec_)
        if self.cnt == 0 or self.cnt == self.num_steps - 1:
            # always compute on the first and last steps to anchor output quality
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            # polynomial rescaling of the raw relative-L1 signal (TeaCache-fitted coefficients)
            coefficients = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.cnt += 1
        if self.cnt == self.num_steps:
            self.cnt = 0  # wrap the step counter for the next image
        if should_calc:
            ori_hidden_states = hidden_states.clone()
            hidden_states = _chroma_run_transformer_blocks(
                self,
                hidden_states,
                encoder_hidden_states,
                pooled_temb,
                image_rotary_emb,
                joint_attention_kwargs,
                controlnet_block_samples,
                controlnet_single_block_samples,
                controlnet_blocks_repeat,
            )
            self.previous_residual = hidden_states - ori_hidden_states
        else:
            # skip the transformer stacks: replay the cached residual on top of the fresh embedding
            hidden_states += self.previous_residual
    else:
        hidden_states = _chroma_run_transformer_blocks(
            self,
            hidden_states,
            encoder_hidden_states,
            pooled_temb,
            image_rotary_emb,
            joint_attention_kwargs,
            controlnet_block_samples,
            controlnet_single_block_samples,
            controlnet_blocks_repeat,
        )
    # last two modulation entries drive the final norm
    temb = pooled_temb[:, -2:]
    hidden_states = self.norm_out(hidden_states, temb)
    output = self.proj_out(hidden_states)
    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)