diff --git a/scripts/fabric.py b/scripts/fabric.py
index 46b09ce..b6754dc 100644
--- a/scripts/fabric.py
+++ b/scripts/fabric.py
@@ -27,7 +27,7 @@
 except ImportError:
     from modules.ui import create_refresh_button
 
-__version__ = "0.6.2"
+__version__ = "0.6.3"
 
 DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
 
diff --git a/scripts/weighted_attention.py b/scripts/weighted_attention.py
index 4263632..0092845 100644
--- a/scripts/weighted_attention.py
+++ b/scripts/weighted_attention.py
@@ -65,17 +65,18 @@ def weighted_attention(self, attn_fn, x, context=None, weights=None, **kwargs):
 
 
 def _get_attn_bias(weights, shape=None, dtype=torch.float32):
     # shape of weights needs to be divisible by 8 in order for xformers attn bias to work
     last_dim = ((weights.shape[-1] - 1) // 8 + 1) * 8
-    w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=weights.dtype)
+    w_bias = torch.zeros(weights.shape[:-1] + (last_dim,), device=weights.device, dtype=dtype)
     min_val = torch.finfo(dtype).min
-    w_bias[..., :weights.shape[-1]] = weights.log().clamp(min=min_val)
+    w_bias[..., :weights.shape[-1]] = weights.log().to(dtype=dtype).clamp(min=min_val)
 
     if shape is not None:
         assert shape[-1] == weights.shape[-1], "Last dimension of shape must match last dimension of weights (number of keys)"
         w_bias = w_bias.view([1] * (len(shape) - 1) + [-1]).expand(shape[:-1] + (last_dim,))
+        # make sure not to consolidate the tensor after expanding,
+        # as it will lead to a stride overflow for large numbers of feedback images
 
-    # cast first in order to preserve multiple-of-8 stride
-    w_bias = w_bias.to(dtype=dtype)
+    # slice in order to preserve multiple-of-8 stride
     w_bias = w_bias[..., :weights.shape[-1]]
     return w_bias