Hypernetwork-MonkeyPatch-Ex.../patches/external_pr/sd_hijack_checkpoint.py

23 lines
754 B
Python

from torch.utils.checkpoint import checkpoint
def BasicTransformerBlock_forward(self, x, context=None):
    """Replacement ``forward`` that routes ``self._forward`` through checkpointing.

    Intended to be assigned as the ``forward`` method of a transformer block
    class, so ``self`` is the block instance.

    Args:
        self: Object exposing a ``_forward(x, context)`` callable.
        x: Input passed unchanged to ``self._forward``.
        context: Optional second argument passed unchanged to
            ``self._forward``; defaults to ``None``.

    Returns:
        Whatever ``torch.utils.checkpoint.checkpoint`` returns for
        ``self._forward(x, context)``.
    """
    # use_reentrant=True is the variant checkpoint() used when the argument
    # was omitted; passing it explicitly keeps the same behaviour while
    # avoiding the "use_reentrant should be passed explicitly" deprecation.
    return checkpoint(self._forward, x, context, use_reentrant=True)
def AttentionBlock_forward(self, x):
    """Replacement ``forward`` that routes ``self._forward`` through checkpointing.

    Intended to be assigned as the ``forward`` method of an attention block
    class, so ``self`` is the block instance.

    Args:
        self: Object exposing a ``_forward(x)`` callable.
        x: Input passed unchanged to ``self._forward``.

    Returns:
        Whatever ``torch.utils.checkpoint.checkpoint`` returns for
        ``self._forward(x)``.
    """
    # Explicit use_reentrant=True preserves the behaviour checkpoint() had
    # when the argument was omitted and avoids the deprecation warning.
    return checkpoint(self._forward, x, use_reentrant=True)
def ResBlock_forward(self, x, emb):
    """Replacement ``forward`` that routes ``self._forward`` through checkpointing.

    Intended to be assigned as the ``forward`` method of a residual block
    class, so ``self`` is the block instance.

    Args:
        self: Object exposing a ``_forward(x, emb)`` callable.
        x: First input, passed unchanged to ``self._forward``.
        emb: Second input, passed unchanged to ``self._forward``.

    Returns:
        Whatever ``torch.utils.checkpoint.checkpoint`` returns for
        ``self._forward(x, emb)``.
    """
    # Explicit use_reentrant=True preserves the behaviour checkpoint() had
    # when the argument was omitted and avoids the deprecation warning.
    return checkpoint(self._forward, x, emb, use_reentrant=True)
# Best-effort monkey patch: when the `ldm` package is importable, replace the
# `forward` methods of its transformer, residual and attention blocks with the
# checkpoint-wrapping versions defined in this module.
try:
    import ldm.modules.attention
    import ldm.modules.diffusionmodules.model
    import ldm.modules.diffusionmodules.openaimodel

    ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
except (ImportError, AttributeError):
    # `ldm` is missing (ImportError) or does not expose the expected classes
    # (AttributeError): skip the patch silently, as before. Anything else
    # (KeyboardInterrupt, SystemExit, real bugs) is no longer swallowed.
    pass