mirror of https://github.com/vladmandic/automatic
update
parent
e8ddc6eada
commit
0891b30ffe
|
|
@ -32,6 +32,7 @@ venv
|
|||
/*.sh
|
||||
/*.txt
|
||||
/*.mp3
|
||||
/*.lnk
|
||||
!webui.bat
|
||||
!webui.sh
|
||||
|
||||
|
|
|
|||
1
TODO.md
1
TODO.md
|
|
@ -25,6 +25,7 @@ Stuff to be added...
|
|||
|
||||
Stuff to be investigated...
|
||||
|
||||
- Gradio `app_kwargs`: <https://github.com/gradio-app/gradio/issues/4054>
|
||||
|
||||
## Merge PRs
|
||||
|
||||
|
|
|
|||
|
|
@ -223,15 +223,19 @@ def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unuse
|
|||
_, extension = os.path.splitext(checkpoint_file)
|
||||
if shared.opts.stream_load:
|
||||
if extension.lower() == ".safetensors":
|
||||
shared.log.debug('Model weights loading: type=safetensors mode=buffered')
|
||||
buffer = f.read()
|
||||
pl_sd = safetensors.torch.load(buffer)
|
||||
else:
|
||||
shared.log.debug('Model weights loading: type=checkpoint mode=buffered')
|
||||
buffer = io.BytesIO(f.read())
|
||||
pl_sd = torch.load(buffer, map_location='cpu')
|
||||
else:
|
||||
if extension.lower() == ".safetensors":
|
||||
shared.log.debug('Model weights loading: type=safetensors mode=mmap')
|
||||
pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu')
|
||||
else:
|
||||
shared.log.debug('Model weights loading: type=checkpoint mode=direct')
|
||||
pl_sd = torch.load(f, map_location='cpu')
|
||||
sd = get_state_dict_from_checkpoint(pl_sd)
|
||||
del pl_sd
|
||||
|
|
@ -244,7 +248,7 @@ def read_state_dict(checkpoint_file, map_location=None): # pylint: disable=unuse
|
|||
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
|
||||
if checkpoint_info in checkpoints_loaded:
|
||||
# use checkpoint cache
|
||||
shared.log.info("Loading weights from cache")
|
||||
shared.log.info("Model weights loading: from cache")
|
||||
return checkpoints_loaded[checkpoint_info]
|
||||
res = read_state_dict(checkpoint_info.filename)
|
||||
timer.record("load")
|
||||
|
|
|
|||
Loading…
Reference in New Issue