diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index c1b0d2a..ed55660 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -245,8 +245,10 @@ class AnimateDiffControl: else: model_net = cn_script.load_control_model(p, unet, unit.model) model_net.reset() - if model_net is not None and getattr(devices, "fp8", False): - model_net.to(torch.float8_e4m3fn) + if model_net is not None and getattr(devices, "fp8", False) and not isinstance(model_net, PlugableIPAdapter): + for _module in model_net.modules(): + if isinstance(_module, (torch.nn.Conv2d, torch.nn.Linear)): + _module.to(torch.float8_e4m3fn) if getattr(model_net, 'is_control_lora', False): control_lora = model_net.control_model diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index f9a12e9..8516b0f 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -49,7 +49,7 @@ class AnimateDiffMM: self.mm.half() if getattr(devices, "fp8", False): for module in self.mm.modules(): - if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): module.to(torch.float8_e4m3fn)