diff --git a/scripts/physton_prompt/get_token_counter.py b/scripts/physton_prompt/get_token_counter.py
index 169f6e5..57bb097 100644
--- a/scripts/physton_prompt/get_token_counter.py
+++ b/scripts/physton_prompt/get_token_counter.py
@@ -1,44 +1,50 @@
 from modules import script_callbacks, extra_networks, prompt_parser, sd_models
-from modules.sd_hijack import model_hijack
-from functools import partial, reduce
+from functools import reduce
+
+# Try to import sd_hijack / model_hijack (only present on A1111 / Forge Classic)
+try:
+    from modules.sd_hijack import model_hijack
+except (ImportError, ModuleNotFoundError):
+    model_hijack = None
 
 
 def get_token_counter(text, steps):
-    # FIX: Use try-except to safely handle PyTorch/model access errors (TypeError NoneType)
-    # that occur during model loading/switching when the token counter API is triggered.
     try:
-        # copy from modules.ui.py
         try:
             text, _ = extra_networks.parse_prompt(text)
-
             _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
             prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
         except Exception:
-            # a parsing error can happen here during typing, and we don't want to bother the user with
-            # messages related to it in console
             prompt_schedules = [[[steps, text]]]
 
+        # Check whether we are running under Forge
         try:
             from modules_forge import forge_version
             forge = True
-
         except:
             forge = False
 
         flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
         prompts = [prompt_text for step, prompt_text in flat_prompts]
 
+        # 🚨 Forge Neo / no hijack available: disable the token counter entirely
+        if model_hijack is None:
+            return {"token_count": 0, "max_length": 0}
+
+        # A1111 / Forge Classic
         if forge:
             cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
-            token_count, max_length = max([model_hijack.get_prompt_lengths(prompt,cond_stage_model) for prompt in prompts],
-                                          key=lambda args: args[0])
+            token_count, max_length = max(
+                [model_hijack.get_prompt_lengths(prompt, cond_stage_model) for prompt in prompts],
+                key=lambda args: args[0]
+            )
         else:
-            token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
-                                          key=lambda args: args[0])
+            token_count, max_length = max(
+                [model_hijack.get_prompt_lengths(prompt) for prompt in prompts],
+                key=lambda args: args[0]
+            )
 
         return {"token_count": token_count, "max_length": max_length}
-    except Exception as e:
-        # return 0 token count if any error (model instability, parsing error, etc.) occurs during calculation
+    except Exception:
         return {"token_count": 0, "max_length": 0}
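
Standalone sketch of the guarded-import fallback this patch introduces. It is an illustration, not part of the patch: the import path and get_prompt_lengths come from the diff above, while token_counts and the sample prompt are hypothetical names used only for demonstration. Run under plain Python (no WebUI), the import of modules.sd_hijack fails, model_hijack becomes None, and the zeroed result is returned, which is the same code path Forge Neo takes.

    # sketch: the guarded import, runnable outside the WebUI
    try:
        from modules.sd_hijack import model_hijack  # A1111 / Forge Classic only
    except (ImportError, ModuleNotFoundError):
        model_hijack = None  # Forge Neo (and plain Python) take this branch

    def token_counts(prompts):
        # Mirror the patched fallback: no hijack -> zeroed result.
        if model_hijack is None:
            return {"token_count": 0, "max_length": 0}
        # With a hijack present, report the longest prompt's counts.
        token_count, max_length = max(
            (model_hijack.get_prompt_lengths(p) for p in prompts),
            key=lambda args: args[0],
        )
        return {"token_count": token_count, "max_length": max_length}

    print(token_counts(["a photo of a cat"]))
    # outside the WebUI this prints: {'token_count': 0, 'max_length': 0}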