Added support for medvram/lowvram

main
CodeExplode 2022-11-09 06:26:24 +11:00 committed by GitHub
parent 90d709e281
commit ef4a8d462d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 1 deletions

View File

@ -38,8 +38,14 @@ def determine_embedding_distribution():
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
# fix for medvram/lowvram - can't figure out how to detect the device of the model in torch, so will try to guess from the web ui options
device = devices.device
if cmd_opts.medvram or cmd_opts.lowvram:
device = torch.device("cpu")
#
for i in range(49405): # guessing that's the range of CLIP tokens given that 49406 and 49407 are special tokens presumably appended to the end
embedding = embedding_layer.token_embedding.wrapped(torch.LongTensor([i]).to(devices.device)).squeeze(0)
embedding = embedding_layer.token_embedding.wrapped(torch.LongTensor([i]).to(device)).squeeze(0)
if i == 0:
distribution_floor = embedding
distribution_ceiling = embedding