diff --git a/annotator/clipvision/__init__.py b/annotator/clipvision/__init__.py index bf47d6a..a6ed308 100644 --- a/annotator/clipvision/__init__.py +++ b/annotator/clipvision/__init__.py @@ -4,7 +4,7 @@ import torch from modules import devices from modules.modelloader import load_file_from_url from annotator.annotator_path import models_path -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor config_clip_g = { @@ -77,6 +77,10 @@ downloads = { } +clip_vision_h_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_h_uc.data') +clip_vision_h_uc = torch.load(clip_vision_h_uc)['uc'] + + class ClipVisionDetector: def __init__(self, config): assert config in downloads diff --git a/annotator/clipvision/clip_vision_h_uc.data b/annotator/clipvision/clip_vision_h_uc.data new file mode 100644 index 0000000..70c4a7b Binary files /dev/null and b/annotator/clipvision/clip_vision_h_uc.data differ diff --git a/scripts/controlmodel_ipadapter.py b/scripts/controlmodel_ipadapter.py index fbd4181..2b4bb67 100644 --- a/scripts/controlmodel_ipadapter.py +++ b/scripts/controlmodel_ipadapter.py @@ -200,8 +200,9 @@ class IPAdapterModel(torch.nn.Module): self.image_proj_model.cpu() if self.is_plus: + from annotator.clipvision import clip_vision_h_uc cond = self.image_proj_model(clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32)) - uncond = self.image_proj_model(torch.zeros_like(clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32))) + uncond = self.image_proj_model(clip_vision_h_uc.to(cond)) return cond, uncond clip_image_embeds = clip_vision_output['image_embeds'].to(device='cpu', dtype=torch.float32) diff --git a/scripts/controlnet_version.py b/scripts/controlnet_version.py index 0a82e8f..04d6a08 100644 --- a/scripts/controlnet_version.py +++ b/scripts/controlnet_version.py @@ -1,4 +1,4 @@ -version_flag = 'v1.1.406' +version_flag = 'v1.1.407' from scripts.logging import logger