diff --git a/tile_utils/attn.py b/tile_utils/attn.py index a6f2522..5a74a36 100644 --- a/tile_utils/attn.py +++ b/tile_utils/attn.py @@ -80,7 +80,7 @@ def xformers_attnblock_forward(self, h_): q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) dtype = q.dtype if shared.opts.upcast_attn: - q, k = q.float(), k.float() + q, k, v = q.float(), k.float(), v.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() @@ -158,7 +158,7 @@ def sdp_attnblock_forward(self, h_): q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) dtype = q.dtype if shared.opts.upcast_attn: - q, k = q.float(), k.float() + q, k, v = q.float(), k.float(), v.float() q = q.contiguous() k = k.contiguous() v = v.contiguous()