sd-webui-prompt-all-in-one/scripts/physton_prompt/get_token_counter.py

51 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from modules import script_callbacks, extra_networks, prompt_parser, sd_models
from functools import reduce
# Try to load model_hijack from sd_hijack (only available on A1111 / Forge Classic)
try:
from modules.sd_hijack import model_hijack
except (ImportError, ModuleNotFoundError):
model_hijack = None
def get_token_counter(text, steps):
    """Return the token count and token limit for a prompt.

    The prompt is run through ``extra_networks.parse_prompt`` and the
    ``prompt_parser`` scheduling helpers; if any of that parsing fails, the
    text is used as a single prompt. Every resulting prompt is measured with
    ``model_hijack.get_prompt_lengths`` and the pair with the highest token
    count is reported.

    Args:
        text: Prompt text to measure.
        steps: Step count forwarded to the prompt-schedule parser.

    Returns:
        A dict with the keys ``"token_count"`` and ``"max_length"``. Both
        values are 0 when ``model_hijack`` is unavailable or when counting
        fails with any ordinary exception (this function is best-effort and
        does not propagate ``Exception`` subclasses).
    """
    try:
        # Forge Neo / builds without sd_hijack: the token counter is disabled,
        # so skip all prompt parsing and report zeros right away.
        if model_hijack is None:
            return {"token_count": 0, "max_length": 0}

        try:
            text, _ = extra_networks.parse_prompt(text)
            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
            prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(
                prompt_flat_list, steps
            )
        except Exception:
            # Parsing is best-effort: fall back to one schedule holding the text.
            prompt_schedules = [[[steps, text]]]

        # Detect Forge by probing for its version module. Only ordinary
        # exceptions are treated as "not Forge"; KeyboardInterrupt and
        # SystemExit are allowed to propagate.
        try:
            from modules_forge import forge_version  # noqa: F401
            is_forge = True
        except Exception:
            is_forge = False

        # Flatten the per-prompt schedules and keep only the prompt texts.
        prompts = [
            prompt_text
            for schedule in prompt_schedules
            for _step, prompt_text in schedule
        ]

        if is_forge:
            # On Forge the conditioning model is passed to get_prompt_lengths.
            cond_stage_model = sd_models.model_data.sd_model.cond_stage_model
            lengths = [
                model_hijack.get_prompt_lengths(prompt, cond_stage_model)
                for prompt in prompts
            ]
        else:
            # A1111 / Forge Classic
            lengths = [model_hijack.get_prompt_lengths(prompt) for prompt in prompts]

        # Report the (token_count, max_length) pair with the largest token count.
        token_count, max_length = max(lengths, key=lambda pair: pair[0])
        return {"token_count": token_count, "max_length": max_length}
    except Exception:
        return {"token_count": 0, "max_length": 0}