parent
83dc4d0f0c
commit
55ecff3656
|
|
@ -208,9 +208,7 @@ AnimateDiff in img2img batch will be available in [v1.10.0](https://github.com/c
|
|||
[Download](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main/lora) and use them like any other LoRA you use (example: download motion lora to `stable-diffusion-webui/models/Lora` and add `<lora:mm_sd15_v2_lora_PanLeft:0.8>` to your positive prompt). **Motion LoRA only supports V2 motion modules**.
|
||||
|
||||
### V3
|
||||
V3 has the same state dict keys as V1 but slightly different inference logic (GroupNorm is not hacked for V3). This extension identifies V3 via checking "v3" and "sd15" are substrings of the model filename (for example, both `v3_sd15_mm.ckpt` and `mm_sd15_v3.safetensors` contain `v3` and `sd15`). You should NOT change the filename of the official V3 motion module (either from my link or from the official link), and you should make sure that filenames of V3 community models contain both `v3` and `sd15`; filenames of V1 community models cannot contain `v3` and `sd15` at the same time. Other motion modules are identified by guessing from the state dict, so they are not affected.
|
||||
|
||||
You may optionally use [adapter](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) for V3, in the same way as you use LoRA. You MUST use [my link](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) instead of the [official link](https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt?download=true). The official adapter won't work for A1111 due to state dict incompatibility.
|
||||
V3 has the same state dict keys as V1 but slightly different inference logic (GroupNorm is not hacked for V3). You may optionally use [adapter](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) for V3, in the same way as you use LoRA. You MUST use [my link](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) instead of the [official link](https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt?download=true). The official adapter won't work for A1111 due to state dict incompatibility.
|
||||
|
||||
### SDXL
|
||||
[AnimateDiffXL](https://github.com/guoyww/AnimateDiff/tree/sdxl) and [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) have an architecture identical to AnimateDiff-SD1.5. The only 2 differences are
|
||||
|
|
|
|||
|
|
@ -21,14 +21,14 @@ class MotionModuleType(Enum):
|
|||
|
||||
|
||||
@staticmethod
|
||||
def get_mm_type(state_dict: dict, filename: str = None):
|
||||
def get_mm_type(state_dict: dict[str, torch.Tensor]):
|
||||
keys = list(state_dict.keys())
|
||||
if any(["mid_block" in k for k in keys]):
|
||||
return MotionModuleType.AnimateDiffV2
|
||||
elif any(["temporal_attentions" in k for k in keys]):
|
||||
return MotionModuleType.HotShotXL
|
||||
elif any(["down_blocks.3" in k for k in keys]):
|
||||
if "v3" in filename and "sd15" in filename:
|
||||
if 32 in next((state_dict[key] for key in state_dict if 'pe' in key), None).shape:
|
||||
return MotionModuleType.AnimateDiffV3
|
||||
else:
|
||||
return MotionModuleType.AnimateDiffV1
|
||||
|
|
@ -52,7 +52,7 @@ class MotionWrapper(nn.Module):
|
|||
self.is_adxl = mm_type == MotionModuleType.AnimateDiffXL
|
||||
self.is_xl = self.is_hotshot or self.is_adxl
|
||||
max_len = 32 if (self.is_v2 or self.is_adxl or self.is_v3) else 24
|
||||
in_channels = (320, 640, 1280) if (self.is_hotshot or self.is_adxl) else (320, 640, 1280, 1280)
|
||||
in_channels = (320, 640, 1280) if (self.is_xl) else (320, 640, 1280, 1280)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
for c in in_channels:
|
||||
|
|
@ -65,6 +65,10 @@ class MotionWrapper(nn.Module):
|
|||
self.mm_hash = mm_hash
|
||||
|
||||
|
||||
def enable_gn_hack(self):
|
||||
return not (self.is_adxl or self.is_v3)
|
||||
|
||||
|
||||
class MotionModule(nn.Module):
|
||||
def __init__(self, in_channels, num_mm, max_len, is_hotshot=False):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class AnimateDiffMM:
|
|||
logger.info(f"Loading motion module {model_name} from {model_path}")
|
||||
model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
|
||||
mm_state_dict = sd_models.read_state_dict(model_path)
|
||||
model_type = MotionModuleType.get_mm_type(mm_state_dict, model_name)
|
||||
model_type = MotionModuleType.get_mm_type(mm_state_dict)
|
||||
logger.info(f"Guessed {model_name} architecture: {model_type}")
|
||||
self.mm = MotionWrapper(model_name, model_hash, model_type)
|
||||
missed_keys = self.mm.load_state_dict(mm_state_dict)
|
||||
|
|
@ -67,7 +67,7 @@ class AnimateDiffMM:
|
|||
if self.mm.is_v2:
|
||||
logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
|
||||
unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
|
||||
elif not (self.mm.is_adxl or self.mm.is_v3):
|
||||
elif self.mm.enable_gn_hack():
|
||||
logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
|
||||
if self.mm.is_hotshot:
|
||||
from sgm.modules.diffusionmodules.util import GroupNorm32
|
||||
|
|
@ -137,7 +137,7 @@ class AnimateDiffMM:
|
|||
if self.mm.is_v2:
|
||||
logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
|
||||
unet.middle_block.pop(-2)
|
||||
elif not (self.mm.is_adxl or self.mm.is_v3):
|
||||
elif self.mm.enable_gn_hack():
|
||||
logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
|
||||
if self.mm.is_hotshot:
|
||||
from sgm.modules.diffusionmodules.util import GroupNorm32
|
||||
|
|
|
|||
Loading…
Reference in New Issue