IP2P model loading support.
This is the code to load the model and inference it with only a text prompt. This commit does not contain the nodes to properly use it with an image input. This supports both the original SD1 instructpix2pix model and the diffusers SDXL one.pull/3183/head^2
parent
96b4c757cf
commit
575acb69e4
|
|
@ -473,6 +473,40 @@ class SD_X4Upscaler(BaseModel):
|
|||
out['y'] = comfy.conds.CONDRegular(noise_level)
|
||||
return out
|
||||
|
||||
class IP2P:
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
if image.shape[1:] != noise.shape[1:]:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
out['c_concat'] = comfy.conds.CONDNoiseShape(self.process_ip2p_image_in(image))
|
||||
adm = self.encode_adm(**kwargs)
|
||||
if adm is not None:
|
||||
out['y'] = comfy.conds.CONDRegular(adm)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(IP2P, BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
class SDXL_instructpix2pix(IP2P, SDXL):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
# self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
|
||||
self.process_ip2p_image_in = lambda image: image
|
||||
|
||||
|
||||
class StableCascade_C(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||
|
|
|
|||
|
|
@ -182,9 +182,9 @@ def detect_unet_config(state_dict, key_prefix):
|
|||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config):
|
||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
for model_config in comfy.supported_models.models:
|
||||
if model_config.matches(unet_config):
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
return model_config(unet_config)
|
||||
|
||||
logging.error("no match {}".format(unet_config))
|
||||
|
|
@ -192,7 +192,7 @@ def model_config_from_unet_config(unet_config):
|
|||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return comfy.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
|
|
@ -321,6 +321,12 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
|
||||
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
|
||||
|
|
@ -351,7 +357,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
|||
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
|
|
|||
|
|
@ -334,6 +334,11 @@ class Stable_Zero123(supported_models_base.BASE):
|
|||
"num_head_channels": -1,
|
||||
}
|
||||
|
||||
required_keys = {
|
||||
"cc_projection.weight": None,
|
||||
"cc_projection.bias": None,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "cond_stage_model.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
|
@ -439,6 +444,33 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
|||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
|
||||
class SD15_instructpix2pix(SD15):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": False,
|
||||
"adm_in_channels": None,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SD15_instructpix2pix(self, device=device)
|
||||
|
||||
class SDXL_instructpix2pix(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816,
|
||||
"use_temporal_attention": False,
|
||||
"in_channels": 8,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXL_instructpix2pix(self, device=device)
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||
|
||||
models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p]
|
||||
models += [SVD_img2vid]
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ class BASE:
|
|||
"num_head_channels": 64,
|
||||
}
|
||||
|
||||
required_keys = {}
|
||||
|
||||
clip_prefix = []
|
||||
clip_vision_prefix = None
|
||||
noise_aug_config = None
|
||||
|
|
@ -28,10 +30,14 @@ class BASE:
|
|||
manual_cast_dtype = None
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config):
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
for k in s.unet_config:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
if state_dict is not None:
|
||||
for k in s.required_keys:
|
||||
if k not in state_dict:
|
||||
return False
|
||||
return True
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
|
|
|
|||
Loading…
Reference in New Issue