v0.6.3 (bugfix)

- fixes a stride overflow error for large numbers of feedback images
pull/38/head
dvruette 2023-09-30 15:46:03 +02:00
parent 43aa3ce962
commit 5a247c9d9e
2 changed files with 6 additions and 5 deletions

View File

@ -27,7 +27,7 @@ except ImportError:
from modules.ui import create_refresh_button
__version__ = "0.6.2"
__version__ = "0.6.3"
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")

View File

@ -65,17 +65,18 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
def _get_attn_bias(weights, shape=None, dtype=torch.float32):
# shape of weights needs to be divisible by 8 in order for xformers attn bias to work
last_dim = ((weights.shape[-1] - 1) // 8 + 1) * 8
w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=weights.dtype)
w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=dtype)
min_val = torch.finfo(dtype).min
w_bias[..., :weights.shape[-1]] = weights.log().clamp(min=min_val)
w_bias[..., :weights.shape[-1]] = weights.log().to(dtype=dtype).clamp(min=min_val)
if shape is not None:
assert shape[-1] == weights.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape[:-1] + (last_dim,))
# make sure not to consolidate the tensor after expanding,
# as it will lead to a stride overflow for large numbers of feedback images
# cast first in order to preserve multiple-of-8 stride
w_bias = w_bias.to(dtype=dtype)
# slice in order to preserve multiple-of-8 stride
w_bias = w_bias[..., :weights.shape[-1]]
return w_bias