Fix ModuleNotFoundError for sd_hijack

by KidouEita
pull/377/head
hirorohi 2026-04-20 16:34:50 +09:00
parent 596a931725
commit 2aa33fd12a
1 changed file with 22 additions and 16 deletions

View File

@ -1,44 +1,50 @@
from modules import script_callbacks, extra_networks, prompt_parser, sd_models
from functools import reduce

# Try to load sd_hijack / model_hijack (only present on A1111 / Forge Classic).
# Forge Neo and other forks removed modules.sd_hijack, which previously made
# this extension crash with ModuleNotFoundError at import time.
try:
    from modules.sd_hijack import model_hijack
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    # Fall back to None so get_token_counter can disable itself gracefully.
    model_hijack = None
def get_token_counter(text, steps):
    """Count prompt tokens for the UI token-counter API.

    Args:
        text: Raw prompt text, possibly containing extra-network and
            prompt-scheduling syntax.
        steps: Sampling step count, used to expand prompt schedules.

    Returns:
        dict with "token_count" and "max_length" ints. Falls back to zeros
        whenever the hijack layer is unavailable (Forge Neo) or any error
        occurs during calculation (e.g. while a model is loading/switching).
    """
    # Use try/except to safely handle PyTorch/model access errors
    # (TypeError on NoneType) that occur during model loading/switching
    # when the token counter API is triggered.
    try:
        # Forge Neo / forks without sd_hijack: token counter is disabled.
        # Checked first so we skip the prompt parsing work entirely.
        if model_hijack is None:
            return {"token_count": 0, "max_length": 0}

        # Copied from modules/ui.py: expand extra networks and prompt
        # scheduling before counting.
        try:
            text, _ = extra_networks.parse_prompt(text)
            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
            prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
        except Exception:
            # A parsing error can happen here during typing, and we don't
            # want to bother the user with messages about it in the console.
            prompt_schedules = [[[steps, text]]]

        # Detect whether we are running under Forge.
        try:
            from modules_forge import forge_version  # noqa: F401
            forge = True
        except ImportError:  # FIX: was a bare `except:` (caught SystemExit too)
            forge = False

        flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
        prompts = [prompt_text for step, prompt_text in flat_prompts]

        if forge:
            # Forge's get_prompt_lengths requires the cond_stage_model argument.
            cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
            token_count, max_length = max(
                [model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts],
                key=lambda args: args[0],
            )
        else:
            token_count, max_length = max(
                [model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
                key=lambda args: args[0],
            )
        return {"token_count": token_count, "max_length": max_length}
    except Exception:
        # Best-effort: return a zero token count if any error (model
        # instability, parsing error, etc.) occurs during calculation.
        return {"token_count": 0, "max_length": 0}