pull/1109/head
Vladimir Mandic 2023-05-20 08:29:02 -04:00
parent e8ddc6eada
commit 0891b30ffe
3 changed files with 7 additions and 1 deletions

1
.gitignore vendored
View File

@ -32,6 +32,7 @@ venv
/*.sh
/*.txt
/*.mp3
/*.lnk
!webui.bat
!webui.sh

View File

@ -25,6 +25,7 @@ Stuff to be added...
Stuff to be investigated...
- Gradio `app_kwargs`: <https://github.com/gradio-app/gradio/issues/4054>
## Merge PRs

View File

@ -223,15 +223,19 @@ def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unuse
_, extension = os.path.splitext(checkpoint_file)
if shared.opts.stream_load:
if extension.lower() == ".safetensors":
shared.log.debug('Model weights loading: type=safetensors mode=buffered')
buffer = f.read()
pl_sd = safetensors.torch.load(buffer)
else:
shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
buffer = io.BytesIO(f.read())
pl_sd = torch.load(buffer, map_location='cpu')
else:
if extension.lower() == ".safetensors":
shared.log.debug('Model weights loading: type=safetensors mode=mmap')
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
else:
shared.log.debug('Model weights loading: type=checkpoint mode=direct')
pl_sd = torch.load(f, map_location='cpu')
sd = get_state_dict_from_checkpoint(pl_sd)
del pl_sd
@ -244,7 +248,7 @@ def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unuse
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if checkpoint_info in checkpoints_loaded:
# use checkpoint cache
shared.log.info("Loading weights from cache")
shared.log.info("Model weights loading: from cache")
return checkpoints_loaded[checkpoint_info]
res = read_state_dict(checkpoint_info.filename)
timer.record("load")