automatic/modules/sd_hijack_safetensors.py

57 lines
2.0 KiB
Python

import safetensors.torch
import transformers
from installer import install, log
from modules import errors
# References to the library loaders, captured when this module is imported.
# The hijacked functions delegate non-safetensors files to them, and
# restore_safetensors() assigns them back to undo the patch.
orig_load_file = safetensors.torch.load_file
orig_load_state_dict = transformers.modeling_utils.load_state_dict
def hijacked_load_file(checkpoint_file, device="cpu"):
    """Load a checkpoint, streaming `.safetensors` files via runai-model-streamer.

    Intended as a replacement for `safetensors.torch.load_file`. Files whose
    name does not end in `.safetensors` are handed to the original loader
    unchanged.

    Args:
        checkpoint_file: Checkpoint location, as `str` or a path-like object.
        device: Target passed to `tensor.to()` for every streamed tensor.

    Returns:
        Dict mapping tensor names to tensors. If streaming fails, the error is
        logged and the result of the original loader is returned instead.
    """
    # Path-like objects (e.g. pathlib.Path) have no `.endswith`; normalize once.
    path = str(checkpoint_file)
    if not path.endswith('.safetensors'):
        return orig_load_file(checkpoint_file, device=device)
    install('runai_model_streamer>=0.15.1')
    from runai_model_streamer import SafetensorsStreamer
    state_dict = {}
    try:
        with SafetensorsStreamer() as streamer:
            streamer.stream_file(path)
            # NOTE(review): `.to()` returns the same tensor when it is already on
            # the target device - confirm streamed tensors stay valid once the
            # streamer context has exited.
            for key, tensor in streamer.get_tensors():
                state_dict[key] = tensor.to(device)
    except Exception as e:
        log.error(f'Loader: {e}')
        errors.display(e, 'runai')
        # A partially filled dict would be consumed as a model with missing
        # weights; drop it and fall back to the regular loader instead.
        state_dict.clear()
        return orig_load_file(checkpoint_file, device=device)
    return state_dict
def hijacked_load_state_dict(checkpoint_file, is_quantized: bool = False, map_location: str = "cpu", weights_only: bool = True):
    """Load a state dict, streaming `.safetensors` files via runai-model-streamer.

    Intended as a replacement for `transformers.modeling_utils.load_state_dict`.
    Files whose name does not end in `.safetensors` are handed to the original
    loader with all arguments forwarded.

    Args:
        checkpoint_file: Checkpoint location, as `str` or a path-like object.
        is_quantized: Only forwarded to the original loader; not used while streaming.
        map_location: Target passed to `tensor.to()` for every streamed tensor.
            With `"meta"` the streamed tensors are stored without being moved.
        weights_only: Only forwarded to the original loader; not used while streaming.

    Returns:
        Dict mapping tensor names to tensors. If streaming fails, the error is
        logged and the result of the original loader is returned instead.
    """
    # Path-like objects (e.g. pathlib.Path) have no `.endswith`; normalize once.
    path = str(checkpoint_file)
    if not path.endswith(".safetensors"):
        return orig_load_state_dict(checkpoint_file=checkpoint_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only)
    install('runai_model_streamer>=0.15.1')
    from runai_model_streamer import SafetensorsStreamer
    state_dict = {}
    keep_as_streamed = map_location == "meta"
    try:
        with SafetensorsStreamer() as streamer:
            streamer.stream_file(path)
            for key, tensor in streamer.get_tensors():
                state_dict[key] = tensor if keep_as_streamed else tensor.to(map_location)
    except Exception as e:
        log.error(f'Loader: {e}')
        errors.display(e, 'runai')
        # A partially filled dict would be consumed as a model with missing
        # weights; drop it and fall back to the regular loader instead.
        state_dict.clear()
        return orig_load_state_dict(checkpoint_file=checkpoint_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only)
    return state_dict
def hijack_safetensors(_diffusers: bool = True, _transformers: bool = True):
    """Swap the library loaders for the streaming implementations.

    Args:
        _diffusers: When true, `safetensors.torch.load_file` is replaced with
            `hijacked_load_file`.
        _transformers: When true, `transformers.modeling_utils.load_state_dict`
            is replaced with `hijacked_load_state_dict`.
    """
    patches = (
        (_diffusers, safetensors.torch, 'load_file', hijacked_load_file),
        (_transformers, transformers.modeling_utils, 'load_state_dict', hijacked_load_state_dict),
    )
    for enabled, module, attr, replacement in patches:
        if enabled:
            setattr(module, attr, replacement)
def restore_safetensors():
    """Reinstate the original `load_file` and `load_state_dict` loaders."""
    originals = (
        (safetensors.torch, 'load_file', orig_load_file),
        (transformers.modeling_utils, 'load_state_dict', orig_load_state_dict),
    )
    for module, attr, original in originals:
        setattr(module, attr, original)