110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
import torch
|
|
import os
|
|
import functools
|
|
import base64
|
|
import numpy as np
|
|
import gradio as gr
|
|
|
|
from typing import Any, Callable, Dict
|
|
|
|
|
|
def load_state_dict(ckpt_path, location="cpu"):
    """Load a checkpoint from ``ckpt_path`` onto device ``location``.

    Files whose extension is ``.safetensors`` (compared case-insensitively) are
    read with ``safetensors.torch.load_file``; every other file goes through
    ``torch.load``. The result is then unwrapped with ``get_state_dict``: once
    for safetensors files, twice for ``torch.load`` checkpoints (once inside
    the ``torch.load`` branch and once afterwards).

    Args:
        ckpt_path: Path of the checkpoint file.
        location: Device passed to the loader (``device=`` for safetensors,
            ``map_location=torch.device(location)`` for ``torch.load``).

    Returns:
        The (unwrapped) state dict.
    """
    is_safetensors = os.path.splitext(ckpt_path)[1].lower() == ".safetensors"

    if is_safetensors:
        # Imported lazily: safetensors is only needed for .safetensors files.
        import safetensors.torch

        weights = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        # NOTE(review): torch.load is pickle-based; on torch versions where
        # weights_only does not default to True, a malicious checkpoint can
        # execute code. Only load checkpoints from trusted sources.
        loaded = torch.load(ckpt_path, map_location=torch.device(location))
        weights = get_state_dict(loaded)

    weights = get_state_dict(weights)
    print(f"Loaded state_dict from [{ckpt_path}]")
    return weights
|
|
|
|
|
|
def get_state_dict(d):
    """Unwrap a checkpoint container one level.

    Returns ``d["state_dict"]`` when ``d`` has a ``"state_dict"`` key, and
    ``d`` itself otherwise. ``d`` must support dict-style ``.get``.
    """
    nested = d.get("state_dict", d)
    return nested
|
|
|
|
|
|
def ndarray_lru_cache(max_size: int = 128, typed: bool = False):
    """
    Decorator to enable caching for functions with numpy array arguments.

    Numpy arrays are mutable, and thus not directly usable as hash keys.

    The idea here is to wrap the incoming arguments with type `np.ndarray`
    as `HashableNpArray` so that `lru_cache` can correctly handle `np.ndarray`
    arguments. The wrapper is only used as the cache key: before the original
    function runs, every wrapped argument is turned back into a plain
    `np.ndarray` view of the same data. The original function therefore keeps
    normal numpy semantics (e.g. element-wise `==`), and arrays it derives
    from its inputs are plain `np.ndarray` as well.

    Two array arguments hit the same cache entry only when they have the same
    dtype, the same shape and equal elements.

    Args:
        max_size: `maxsize` of the underlying `functools.lru_cache`.
        typed: `typed` flag of the underlying `functools.lru_cache`.

    Returns:
        A decorator. The decorated function also exposes `cache_info()` and
        `cache_clear()` of the underlying cache.
    """

    def decorator(func: Callable):
        """The actual decorator that accepts a function as input."""

        class HashableNpArray(np.ndarray):
            """ndarray view with value-based `__eq__`/`__hash__`, for cache keys only."""

            def __new__(cls, input_array):
                # Input array is an instance of ndarray.
                # The view makes the input array and returned array share the
                # same data (no copy).
                return np.asarray(input_array).view(cls)

            def __eq__(self, other) -> bool:
                # dtype and shape must match too, so an int array never reuses
                # a result cached for a value-equal float array.
                return (
                    isinstance(other, np.ndarray)
                    and self.dtype == other.dtype
                    and self.shape == other.shape
                    and np.array_equal(self, other)
                )

            def __hash__(self):
                # Hash the bytes representing the data of the array, plus shape
                # and dtype to separate arrays that only share raw bytes
                # (e.g. reshaped copies).
                return hash((self.shape, self.dtype, self.tobytes()))

        def to_key(item: Any):
            """Wrap ndarray arguments so they can be part of an lru_cache key."""
            return HashableNpArray(item) if isinstance(item, np.ndarray) else item

        def from_key(item: Any):
            """Undo `to_key`: give `func` a plain ndarray view of the same data."""
            return item.view(np.ndarray) if isinstance(item, HashableNpArray) else item

        @functools.lru_cache(maxsize=max_size, typed=typed)
        def cached_func(*args, **kwargs):
            """Receives `HashableNpArray` keys; calls `func` with plain arrays."""
            # Without this unwrap, `func` would see the overridden `__eq__`
            # (scalar instead of element-wise) and leak the subclass into results.
            args = [from_key(arg) for arg in args]
            kwargs = {k: from_key(arg) for k, arg in kwargs.items()}
            return func(*args, **kwargs)

        # Preserves original function.__name__ and __doc__.
        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            """The decorated function that delegates the original function."""
            args = [to_key(arg) for arg in args]
            kwargs = {k: to_key(arg) for k, arg in kwargs.items()}
            return cached_func(*args, **kwargs)

        # Mirror functools.lru_cache's introspection helpers.
        decorated_func.cache_info = cached_func.cache_info
        decorated_func.cache_clear = cached_func.cache_clear
        return decorated_func

    return decorator
|
|
|
|
|
|
# Optional SVG support: SVG uploads are rasterized only when both svglib and
# reportlab can be imported; `svgsupport` records whether that is possible.
try:
    import io

    from svglib.svglib import svg2rlg
    from reportlab.graphics import renderPM
except ImportError:
    svgsupport = False
else:
    svgsupport = True
|
|
|
|
|
|
def svg_preprocess(inputs: Dict, preprocess: Callable):
    """Rasterize an SVG data-URL image to PNG, then delegate to ``preprocess``.

    When ``inputs`` is falsy, returns ``None`` without calling ``preprocess``.
    When ``inputs["image"]`` starts with ``"data:image/svg+xml;base64,"`` and
    SVG support is available (``svgsupport``), the SVG is rendered to PNG with
    svglib/reportlab. ``inputs["image"]`` is then replaced IN PLACE by a
    ``"data:image/png;base64,..."`` string. The (possibly modified) ``inputs``
    dict is then passed to ``preprocess`` and its result returned.
    """
    if not inputs:
        return None

    svg_prefix = "data:image/svg+xml;base64,"
    image = inputs["image"]
    # Prefix is checked first, so `svgsupport` is only consulted for SVG input.
    if image.startswith(svg_prefix) and svgsupport:
        raw_svg = base64.b64decode(image.replace(svg_prefix, ""))
        png_bytes = renderPM.drawToString(svg2rlg(io.BytesIO(raw_svg)), fmt="PNG")
        png_b64 = base64.b64encode(png_bytes).decode("utf-8")
        inputs["image"] = "data:image/png;base64," + png_b64

    return preprocess(inputs)
|
|
|