Update typing for model data access

pull/4713/head
awsr 2026-03-26 16:42:04 -07:00
parent a6029afb04
commit 775ff59fdb
No known key found for this signature in database
2 changed files with 19 additions and 10 deletions

View File

@ -1,9 +1,16 @@
from __future__ import annotations
import os import os
import sys import sys
import threading import threading
from modules import shared, errors from typing import TYPE_CHECKING
from modules import errors, shared
from modules.logger import log from modules.logger import log
if TYPE_CHECKING:
from diffusers import DiffusionPipeline
def get_model_type(pipe): def get_model_type(pipe):
name = pipe.__class__.__name__ name = pipe.__class__.__name__
@ -120,8 +127,8 @@ def get_model_type(pipe):
class ModelData: class ModelData:
def __init__(self): def __init__(self):
self.sd_model = None self.sd_model: DiffusionPipeline | None = None
self.sd_refiner = None self.sd_refiner: DiffusionPipeline | None = None
self.sd_dict = 'None' self.sd_dict = 'None'
self.initial = True self.initial = True
self.locked = True self.locked = True

View File

@ -236,12 +236,14 @@ def restore_defaults(restart=True):
restart_server(restart) restart_server(restart)
# startup def of shared.sd_model before its redefined in modeldata from modules.modeldata import Shared # pylint: disable=ungrouped-imports
sd_model: DiffusionPipeline | None = None # dummy and overwritten by class
sd_refiner: DiffusionPipeline | None = None # dummy and overwritten by class
sd_model_type: str = '' # dummy and overwritten by class
sd_refiner_type: str = '' # dummy and overwritten by class
sd_loaded: bool = False # dummy and overwritten by class
from modules.modeldata import Shared # pylint: disable=ungrouped-imports
sys.modules[__name__].__class__ = Shared sys.modules[__name__].__class__ = Shared
if TYPE_CHECKING:
# From Shared class
sd_model: DiffusionPipeline | None
sd_refiner: DiffusionPipeline | None
sd_model_type: str
sd_refiner_type: str
sd_loaded: bool