diff --git a/modules/modeldata.py b/modules/modeldata.py index 3c22c32b5..fa0d94f64 100644 --- a/modules/modeldata.py +++ b/modules/modeldata.py @@ -1,9 +1,16 @@ +from __future__ import annotations + import os import sys import threading -from modules import shared, errors +from typing import TYPE_CHECKING + +from modules import errors, shared from modules.logger import log +if TYPE_CHECKING: + from diffusers import DiffusionPipeline + def get_model_type(pipe): name = pipe.__class__.__name__ @@ -120,8 +127,8 @@ def get_model_type(pipe): class ModelData: def __init__(self): - self.sd_model = None - self.sd_refiner = None + self.sd_model: DiffusionPipeline | None = None + self.sd_refiner: DiffusionPipeline | None = None self.sd_dict = 'None' self.initial = True self.locked = True diff --git a/modules/shared.py b/modules/shared.py index 0a9a20a9a..1a292c3dc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,12 +236,14 @@ def restore_defaults(restart=True): restart_server(restart) -# startup def of shared.sd_model before its redefined in modeldata -sd_model: DiffusionPipeline | None = None # dummy and overwritten by class -sd_refiner: DiffusionPipeline | None = None # dummy and overwritten by class -sd_model_type: str = '' # dummy and overwritten by class -sd_refiner_type: str = '' # dummy and overwritten by class -sd_loaded: bool = False # dummy and overwritten by class +from modules.modeldata import Shared # pylint: disable=ungrouped-imports -from modules.modeldata import Shared # pylint: disable=ungrouped-imports sys.modules[__name__].__class__ = Shared + +if TYPE_CHECKING: + # From Shared class + sd_model: DiffusionPipeline | None + sd_refiner: DiffusionPipeline | None + sd_model_type: str + sd_refiner_type: str + sd_loaded: bool