mirror of https://github.com/vladmandic/automatic
Update typing for model data access
parent
a6029afb04
commit
775ff59fdb
|
|
@ -1,9 +1,16 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from modules import shared, errors
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from modules import errors, shared
|
||||||
from modules.logger import log
|
from modules.logger import log
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
def get_model_type(pipe):
|
def get_model_type(pipe):
|
||||||
name = pipe.__class__.__name__
|
name = pipe.__class__.__name__
|
||||||
|
|
@ -120,8 +127,8 @@ def get_model_type(pipe):
|
||||||
|
|
||||||
class ModelData:
|
class ModelData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.sd_model = None
|
self.sd_model: DiffusionPipeline | None = None
|
||||||
self.sd_refiner = None
|
self.sd_refiner: DiffusionPipeline | None = None
|
||||||
self.sd_dict = 'None'
|
self.sd_dict = 'None'
|
||||||
self.initial = True
|
self.initial = True
|
||||||
self.locked = True
|
self.locked = True
|
||||||
|
|
|
||||||
|
|
@ -236,12 +236,14 @@ def restore_defaults(restart=True):
|
||||||
restart_server(restart)
|
restart_server(restart)
|
||||||
|
|
||||||
|
|
||||||
# startup def of shared.sd_model before its redefined in modeldata
|
from modules.modeldata import Shared # pylint: disable=ungrouped-imports
|
||||||
sd_model: DiffusionPipeline | None = None # dummy and overwritten by class
|
|
||||||
sd_refiner: DiffusionPipeline | None = None # dummy and overwritten by class
|
|
||||||
sd_model_type: str = '' # dummy and overwritten by class
|
|
||||||
sd_refiner_type: str = '' # dummy and overwritten by class
|
|
||||||
sd_loaded: bool = False # dummy and overwritten by class
|
|
||||||
|
|
||||||
from modules.modeldata import Shared # pylint: disable=ungrouped-imports
|
|
||||||
sys.modules[__name__].__class__ = Shared
|
sys.modules[__name__].__class__ = Shared
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# From Shared class
|
||||||
|
sd_model: DiffusionPipeline | None
|
||||||
|
sd_refiner: DiffusionPipeline | None
|
||||||
|
sd_model_type: str
|
||||||
|
sd_refiner_type: str
|
||||||
|
sd_loaded: bool
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue