mirror of https://github.com/vladmandic/automatic
RUF013 update
parent
de86927c1b
commit
3e228afa78
|
|
@ -53,13 +53,13 @@ class DeepDanbooru:
|
|||
def tag_multi(
|
||||
self,
|
||||
pil_image,
|
||||
general_threshold: float = None,
|
||||
include_rating: bool = None,
|
||||
exclude_tags: str = None,
|
||||
max_tags: int = None,
|
||||
sort_alpha: bool = None,
|
||||
use_spaces: bool = None,
|
||||
escape_brackets: bool = None,
|
||||
general_threshold: float | None = None,
|
||||
include_rating: bool | None = None,
|
||||
exclude_tags: str | None = None,
|
||||
max_tags: int | None = None,
|
||||
sort_alpha: bool | None = None,
|
||||
use_spaces: bool | None = None,
|
||||
escape_brackets: bool | None = None,
|
||||
):
|
||||
"""Run inference and return formatted tag string.
|
||||
|
||||
|
|
@ -134,7 +134,7 @@ def get_models() -> list:
|
|||
return ["DeepBooru"]
|
||||
|
||||
|
||||
def load_model(model_name: str = None) -> bool: # pylint: disable=unused-argument
|
||||
def load_model(model_name: str = "") -> bool: # pylint: disable=unused-argument
|
||||
"""Load the DeepBooru model."""
|
||||
try:
|
||||
model.load()
|
||||
|
|
|
|||
|
|
@ -53,12 +53,12 @@ class JoyOptions:
|
|||
return f'repo="{self.repo}" temp={self.temp} top_k={self.top_k} top_p={self.top_p} sample={self.sample} tokens={self.max_new_tokens}'
|
||||
|
||||
|
||||
processor: AutoProcessor = None
|
||||
llava_model: LlavaForConditionalGeneration = None
|
||||
processor: AutoProcessor | None = None
|
||||
llava_model: LlavaForConditionalGeneration | None = None
|
||||
opts = JoyOptions()
|
||||
|
||||
|
||||
def load(repo: str = None):
|
||||
def load(repo: str | None = None):
|
||||
"""Load JoyCaption model."""
|
||||
global llava_model, processor # pylint: disable=global-statement
|
||||
repo = repo or opts.repo
|
||||
|
|
@ -93,7 +93,7 @@ def unload():
|
|||
log.debug('JoyCaption unload: no model loaded')
|
||||
|
||||
|
||||
def predict(question: str, image, vqa_model: str = None) -> str:
|
||||
def predict(question: str, image, vqa_model: str | None = None) -> str:
|
||||
opts.max_new_tokens = shared.opts.caption_vlm_max_length
|
||||
load(vqa_model)
|
||||
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ def _image_hash(image: Image.Image) -> str:
|
|||
return h.hexdigest()
|
||||
|
||||
|
||||
def encode_image(image: Image.Image, cache_key: str = None):
|
||||
def encode_image(image: Image.Image, cache_key: str | None = None):
|
||||
"""
|
||||
Encode image for reuse across multiple queries.
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ def encode_image(image: Image.Image, cache_key: str = None):
|
|||
|
||||
|
||||
def query(image: Image.Image, question: str, repo: str, stream: bool = False,
|
||||
temperature: float = None, top_p: float = None, max_tokens: int = None,
|
||||
temperature: float | None = None, top_p: float | None = None, max_tokens: int | None = None,
|
||||
use_cache: bool = False, reasoning: bool = True):
|
||||
"""
|
||||
Visual question answering with optional streaming.
|
||||
|
|
@ -180,7 +180,7 @@ def query(image: Image.Image, question: str, repo: str, stream: bool = False,
|
|||
|
||||
|
||||
def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool = False,
|
||||
temperature: float = None, top_p: float = None, max_tokens: int = None):
|
||||
temperature: float | None = None, top_p: float | None = None, max_tokens: int | None = None):
|
||||
"""
|
||||
Generate image captions at different lengths.
|
||||
|
||||
|
|
@ -290,8 +290,8 @@ def detect(image: Image.Image, object_name: str, repo: str, max_objects: int = 1
|
|||
return detections
|
||||
|
||||
|
||||
def predict(question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False,
|
||||
mode: str = None, stream: bool = False, use_cache: bool = False, **kwargs):
|
||||
def predict(question: str, image: Image.Image, repo: str, model_name: str | None = None, thinking_mode: bool = False,
|
||||
mode: str | None = None, stream: bool = False, use_cache: bool = False, **kwargs):
|
||||
"""
|
||||
Main entry point for Moondream 3 VQA - auto-detects mode from question.
|
||||
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ def unload_model():
|
|||
waifudiffusion.unload_model()
|
||||
|
||||
|
||||
def tag(image, model_name: str = None, **kwargs) -> str:
|
||||
def tag(image, model_name: str | None = None, **kwargs) -> str:
|
||||
"""Unified tagging - dispatch to correct backend.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def get_prompts_for_model(model_name: str) -> list:
|
|||
return vlm_prompts_common
|
||||
|
||||
|
||||
def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str:
|
||||
def get_internal_prompt(friendly_name: str, user_prompt: str | None = None) -> str:
|
||||
"""Convert friendly prompt name to internal token/command."""
|
||||
internal = vlm_prompt_mapping.get(friendly_name, friendly_name)
|
||||
|
||||
|
|
@ -350,7 +350,7 @@ class VQA:
|
|||
self.processor = None
|
||||
devices.torch_gc(force=True, reason='vqa model switch')
|
||||
|
||||
def load(self, model_name: str = None):
|
||||
def load(self, model_name: str | None = None):
|
||||
"""Load VLM model into memory for the specified model name."""
|
||||
model_name = model_name or shared.opts.caption_vlm_model
|
||||
if not model_name:
|
||||
|
|
@ -444,7 +444,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None):
|
||||
def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str | None = None):
|
||||
debug(f'VQA caption: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}')
|
||||
self._load_fastvlm(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
|
|
@ -521,7 +521,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
||||
def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False):
|
||||
self._load_qwen(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
# Get model class name for logging
|
||||
|
|
@ -644,7 +644,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
||||
def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False):
|
||||
self._load_gemma(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
# Get model class name for logging
|
||||
|
|
@ -757,7 +757,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _mistral(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
||||
def _mistral(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False):
|
||||
self._load_mistral(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
cls_name = self.model.__class__.__name__
|
||||
|
|
@ -828,7 +828,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_paligemma(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
question = question.replace('<', '').replace('>', '').replace('_', ' ')
|
||||
|
|
@ -870,7 +870,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
try:
|
||||
pass # pylint: disable=unused-import
|
||||
except Exception:
|
||||
|
|
@ -925,7 +925,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False):
|
||||
def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False):
|
||||
self._load_smol(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
# Get model class name for logging
|
||||
|
|
@ -1019,7 +1019,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _git(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _git(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_git(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
pixel_values = self.processor(images=image, return_tensors="pt").pixel_values
|
||||
|
|
@ -1053,7 +1053,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _blip(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _blip(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_blip(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
inputs = self.processor(image, question, return_tensors="pt")
|
||||
|
|
@ -1081,7 +1081,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_vilt(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
inputs = self.processor(image, question, return_tensors="pt")
|
||||
|
|
@ -1111,7 +1111,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _pix(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _pix(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_pix(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
if len(question) > 0:
|
||||
|
|
@ -1144,7 +1144,7 @@ class VQA:
|
|||
register_aux('vqa', self.model)
|
||||
devices.torch_gc()
|
||||
|
||||
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False):
|
||||
def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str | None = None, thinking_mode: bool = False):
|
||||
debug(f'VQA caption: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}')
|
||||
self._load_moondream(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
|
|
@ -1207,7 +1207,7 @@ class VQA:
|
|||
# When keep_thinking is False, just use the answer (reasoning is discarded)
|
||||
return response
|
||||
|
||||
def _load_florence(self, repo: str, revision: str = None):
|
||||
def _load_florence(self, repo: str, revision: str | None = None):
|
||||
"""Load Florence-2 model and processor."""
|
||||
_get_imports = transformers.dynamic_module_utils.get_imports
|
||||
|
||||
|
|
@ -1247,7 +1247,7 @@ class VQA:
|
|||
self.loaded = cache_key
|
||||
devices.torch_gc()
|
||||
|
||||
def _florence(self, question: str, image: Image.Image, repo: str, revision: str = None, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _florence(self, question: str, image: Image.Image, repo: str, revision: str | None = None, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_florence(repo, revision)
|
||||
move_aux_to_gpu('vqa')
|
||||
if question.startswith('<'):
|
||||
|
|
@ -1306,7 +1306,7 @@ class VQA:
|
|||
self.loaded = repo
|
||||
devices.torch_gc()
|
||||
|
||||
def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument
|
||||
def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument
|
||||
self._load_sa2(repo)
|
||||
move_aux_to_gpu('vqa')
|
||||
if question.startswith('<'):
|
||||
|
|
@ -1325,7 +1325,18 @@ class VQA:
|
|||
response = return_dict["prediction"] # the text format answer
|
||||
return response
|
||||
|
||||
def caption(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str:
|
||||
def caption(
|
||||
self,
|
||||
question: str = "",
|
||||
system_prompt: str | None = None,
|
||||
prompt: str | None = None,
|
||||
image: list[Image.Image] | Image.Image | dict | None = None,
|
||||
model_name: str | None = None,
|
||||
prefill: str | None = None,
|
||||
thinking_mode: bool | None = None,
|
||||
quiet: bool = False,
|
||||
generation_kwargs: dict | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Main entry point for VQA captioning. Returns string answer.
|
||||
Detection data stored in self.last_detection_data for annotated image creation.
|
||||
|
|
@ -1596,7 +1607,7 @@ def unload_model():
|
|||
return get_instance().unload()
|
||||
|
||||
|
||||
def load_model(model_name: str = None):
|
||||
def load_model(model_name: str | None = None):
|
||||
return get_instance().load(model_name)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def parse_points(result) -> list:
|
|||
return points
|
||||
|
||||
|
||||
def parse_detections(result, label: str, max_objects: int = None) -> list:
|
||||
def parse_detections(result, label: str, max_objects: int | None = None) -> list:
|
||||
"""Parse and validate detection bboxes from model result.
|
||||
|
||||
Args:
|
||||
|
|
@ -74,7 +74,7 @@ def parse_detections(result, label: str, max_objects: int = None) -> list:
|
|||
return detections
|
||||
|
||||
|
||||
def parse_florence_detections(response, image_size: tuple = None) -> list:
|
||||
def parse_florence_detections(response, image_size: tuple | None = None) -> list:
|
||||
"""Parse Florence-style detection response into standard detection format.
|
||||
|
||||
Florence returns detection data in two possible formats:
|
||||
|
|
@ -286,7 +286,7 @@ def calculate_eye_position(face_bbox: dict) -> tuple:
|
|||
return (eye_x, eye_y)
|
||||
|
||||
|
||||
def draw_bounding_boxes(image: Image.Image, detections: list, points: list = None) -> Image.Image:
|
||||
def draw_bounding_boxes(image: Image.Image, detections: list, points: list | None = None) -> Image.Image:
|
||||
"""
|
||||
Draw bounding boxes and/or points on an image.
|
||||
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class WaifuDiffusionTagger:
|
|||
self.model_path = None
|
||||
self.image_size = 448 # Standard for WD models
|
||||
|
||||
def load(self, model_name: str = None):
|
||||
def load(self, model_name: str | None = None):
|
||||
"""Load the ONNX model and tags from HuggingFace."""
|
||||
import huggingface_hub
|
||||
|
||||
|
|
@ -195,15 +195,15 @@ class WaifuDiffusionTagger:
|
|||
|
||||
def predict(
|
||||
self,
|
||||
image: Image.Image,
|
||||
general_threshold: float = None,
|
||||
character_threshold: float = None,
|
||||
include_rating: bool = None,
|
||||
exclude_tags: str = None,
|
||||
max_tags: int = None,
|
||||
sort_alpha: bool = None,
|
||||
use_spaces: bool = None,
|
||||
escape_brackets: bool = None,
|
||||
image: Image.Image | list[Image.Image] | dict | None,
|
||||
general_threshold: float | None = None,
|
||||
character_threshold: float | None = None,
|
||||
include_rating: bool | None = None,
|
||||
exclude_tags: str | None = None,
|
||||
max_tags: int | None = None,
|
||||
sort_alpha: bool | None = None,
|
||||
use_spaces: bool | None = None,
|
||||
escape_brackets: bool | None = None,
|
||||
) -> str:
|
||||
"""Run inference and return formatted tag string.
|
||||
|
||||
|
|
@ -352,7 +352,7 @@ def refresh_models() -> list:
|
|||
return get_models()
|
||||
|
||||
|
||||
def load_model(model_name: str = None) -> bool:
|
||||
def load_model(model_name: str | None = None) -> bool:
|
||||
"""Load the specified WaifuDiffusion model."""
|
||||
return tagger.load(model_name)
|
||||
|
||||
|
|
@ -362,7 +362,7 @@ def unload_model():
|
|||
tagger.unload()
|
||||
|
||||
|
||||
def tag(image: Image.Image, model_name: str = None, **kwargs) -> str:
|
||||
def tag(image: Image.Image, model_name: str | None = None, **kwargs) -> str:
|
||||
"""Tag an image using WaifuDiffusion tagger.
|
||||
|
||||
Args:
|
||||
|
|
|
|||
Loading…
Reference in New Issue