v0.6.3 (bugfix)
- fixes a stride overflow error for large numbers of feedback images
parent 43aa3ce962
commit 5a247c9d9e
@@ -27,7 +27,7 @@ except ImportError:
     from modules.ui import create_refresh_button
 
 
-__version__ = "0.6.2"
+__version__ = "0.6.3"
 
 DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
 
@@ -65,17 +65,18 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
 
 def _get_attn_bias(weights, shape=None, dtype=torch.float32):
     # shape of weights needs to be divisible by 8 in order for xformers attn bias to work
     last_dim = ((weights.shape[-1] - 1) // 8 + 1) * 8
-    w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=weights.dtype)
+    w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=dtype)
 
     min_val = torch.finfo(dtype).min
-    w_bias[..., :weights.shape[-1]] = weights.log().clamp(min=min_val)
+    w_bias[..., :weights.shape[-1]] = weights.log().to(dtype=dtype).clamp(min=min_val)
 
     if shape is not None:
         assert shape[-1] == weights.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
         w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape[:-1] + (last_dim,))
+        # make sure not to consolidate the tensor after expanding,
+        # as it will lead to a stride overflow for large numbers of feedback images
 
-    # cast first in order to preserve multiple-of-8 stride
-    w_bias = w_bias.to(dtype=dtype)
+    # slice in order to preserve multiple-of-8 stride
     w_bias = w_bias[..., :weights.shape[-1]]
     return w_bias
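Below is a minimal sketch of the two stride properties the patched _get_attn_bias relies on. It is not part of the commit; the shapes and token counts are made up for illustration, but the pad-to-a-multiple-of-8-then-slice trick and the lazy expand are the same ideas the diff comments describe.

import torch

num_keys = 693                              # e.g. key tokens, growing with the number of feedback images
last_dim = ((num_keys - 1) // 8 + 1) * 8    # round up to the next multiple of 8 -> 696

# Without broadcasting, pad-then-slice keeps the row stride 8-aligned for xformers:
direct = torch.zeros(4096, last_dim)[..., :num_keys]
print(direct.shape, direct.stride())        # torch.Size([4096, 693]) (696, 1)

# With broadcasting, the expand is kept as a lazy view so nothing large is ever materialized:
w_bias = torch.zeros(1, 1, last_dim)        # allocate padded, already in the target dtype
w_bias = w_bias.expand(16, 4096, last_dim)  # broadcast view: strides become (0, 0, 1), no copy
w_bias = w_bias[..., :num_keys]             # slice back to the real key count
print(w_bias.stride())                      # (0, 0, 1)

# Forcing a copy after the expand (e.g. a dtype cast or .contiguous()) materializes the full
# tensor; its leading stride, 4096 * 693 = 2838528 elements here, keeps growing with the
# number of feedback images, which is the overflow the commit message refers to.
print(w_bias.contiguous().stride())         # (2838528, 693, 1)

In short, the fix moves the dtype cast to the allocation and to the log-weights write, so the bias is never copied after the expand; only the cheap slice view is taken at the end.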