better model type guess (#374)

pull/401/head v1.13.0
Chengsong Zhang 2023-12-19 21:31:21 -06:00 committed by GitHub
parent 83dc4d0f0c
commit 55ecff3656
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 9 deletions

View File

@ -208,9 +208,7 @@ AnimateDiff in img2img batch will be available in [v1.10.0](https://github.com/c
[Download](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main/lora) and use them like any other LoRA you use (example: download motion lora to `stable-diffusion-webui/models/Lora` and add `<lora:mm_sd15_v2_lora_PanLeft:0.8>` to your positive prompt). **Motion LoRA only supports V2 motion modules**.
### V3
V3 has identical state dict keys to V1 but slightly different inference logic (GroupNorm is not hacked for V3). This extension identifies V3 via checking "v3" and "sd15" are substrings of the model filename (for example, both `v3_sd15_mm.ckpt` and `mm_sd15_v3.safetensors` contain `v3` and `sd15`). You should NOT change the filename of the official V3 motion module (either from my link or from the official link), and you should make sure that filenames of V3 community models contain both `v3` and `sd15`; filenames of V1 community models cannot contain `v3` and `sd15` at the same time. Other motion modules are identified by guessing from the state dict, so they are not affected.
You may optionally use [adapter](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) for V3, in the same way as the way you use LoRA. You MUST use [my link](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) instead of the [official link](https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt?download=true). The official adapter won't work for A1111 due to state dict incompatibility.
V3 has identical state dict keys to V1 but slightly different inference logic (GroupNorm is not hacked for V3). You may optionally use [adapter](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) for V3, in the same way you use LoRA. You MUST use [my link](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) instead of the [official link](https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt?download=true). The official adapter won't work for A1111 due to state dict incompatibility.
### SDXL
[AnimateDiffXL](https://github.com/guoyww/AnimateDiff/tree/sdxl) and [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) have identical architecture to AnimateDiff-SD1.5. The only two differences are

View File

@ -21,14 +21,14 @@ class MotionModuleType(Enum):
@staticmethod
def get_mm_type(state_dict: dict, filename: str = None):
def get_mm_type(state_dict: dict[str, torch.Tensor]):
keys = list(state_dict.keys())
if any(["mid_block" in k for k in keys]):
return MotionModuleType.AnimateDiffV2
elif any(["temporal_attentions" in k for k in keys]):
return MotionModuleType.HotShotXL
elif any(["down_blocks.3" in k for k in keys]):
if "v3" in filename and "sd15" in filename:
if 32 in next((state_dict[key] for key in state_dict if 'pe' in key), None).shape:
return MotionModuleType.AnimateDiffV3
else:
return MotionModuleType.AnimateDiffV1
@ -52,7 +52,7 @@ class MotionWrapper(nn.Module):
self.is_adxl = mm_type == MotionModuleType.AnimateDiffXL
self.is_xl = self.is_hotshot or self.is_adxl
max_len = 32 if (self.is_v2 or self.is_adxl or self.is_v3) else 24
in_channels = (320, 640, 1280) if (self.is_hotshot or self.is_adxl) else (320, 640, 1280, 1280)
in_channels = (320, 640, 1280) if (self.is_xl) else (320, 640, 1280, 1280)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
for c in in_channels:
@ -65,6 +65,10 @@ class MotionWrapper(nn.Module):
self.mm_hash = mm_hash
def enable_gn_hack(self):
return not (self.is_adxl or self.is_v3)
class MotionModule(nn.Module):
def __init__(self, in_channels, num_mm, max_len, is_hotshot=False):
super().__init__()

View File

@ -39,7 +39,7 @@ class AnimateDiffMM:
logger.info(f"Loading motion module {model_name} from {model_path}")
model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
mm_state_dict = sd_models.read_state_dict(model_path)
model_type = MotionModuleType.get_mm_type(mm_state_dict, model_name)
model_type = MotionModuleType.get_mm_type(mm_state_dict)
logger.info(f"Guessed {model_name} architecture: {model_type}")
self.mm = MotionWrapper(model_name, model_hash, model_type)
missed_keys = self.mm.load_state_dict(mm_state_dict)
@ -67,7 +67,7 @@ class AnimateDiffMM:
if self.mm.is_v2:
logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
elif not (self.mm.is_adxl or self.mm.is_v3):
elif self.mm.enable_gn_hack():
logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
if self.mm.is_hotshot:
from sgm.modules.diffusionmodules.util import GroupNorm32
@ -137,7 +137,7 @@ class AnimateDiffMM:
if self.mm.is_v2:
logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
unet.middle_block.pop(-2)
elif not (self.mm.is_adxl or self.mm.is_v3):
elif self.mm.enable_gn_hack():
logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
if self.mm.is_hotshot:
from sgm.modules.diffusionmodules.util import GroupNorm32