From 794875cadb9ccf6c245e4b0d3fcb517bf2cc4e1e Mon Sep 17 00:00:00 2001 From: Chengsong Zhang Date: Mon, 20 Nov 2023 23:27:03 -0600 Subject: [PATCH] Fix fp8 (#333) * fix ) * fix ipadapter --- scripts/animatediff_cn.py | 6 ++++-- scripts/animatediff_mm.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/scripts/animatediff_cn.py b/scripts/animatediff_cn.py index c1b0d2a..ed55660 100644 --- a/scripts/animatediff_cn.py +++ b/scripts/animatediff_cn.py @@ -245,8 +245,10 @@ class AnimateDiffControl: else: model_net = cn_script.load_control_model(p, unet, unit.model) model_net.reset() - if model_net is not None and getattr(devices, "fp8", False): - model_net.to(torch.float8_e4m3fn) + if model_net is not None and getattr(devices, "fp8", False) and not isinstance(model_net, PlugableIPAdapter): + for _module in model_net.modules(): + if isinstance(_module, (torch.nn.Conv2d, torch.nn.Linear)): + _module.to(torch.float8_e4m3fn) if getattr(model_net, 'is_control_lora', False): control_lora = model_net.control_model diff --git a/scripts/animatediff_mm.py b/scripts/animatediff_mm.py index f9a12e9..8516b0f 100644 --- a/scripts/animatediff_mm.py +++ b/scripts/animatediff_mm.py @@ -49,7 +49,7 @@ class AnimateDiffMM: self.mm.half() if getattr(devices, "fp8", False): for module in self.mm.modules(): - if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): module.to(torch.float8_e4m3fn)