commit
f9f8073e64
|
|
@ -80,7 +80,7 @@ def xformers_attnblock_forward(self, h_):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
|
@ -158,7 +158,7 @@ def sdp_attnblock_forward(self, h_):
|
|||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
dtype = q.dtype
|
||||
if shared.opts.upcast_attn:
|
||||
q, k = q.float(), k.float()
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
|
|
|||
Loading…
Reference in New Issue