diff --git a/modules/caption/deepbooru.py b/modules/caption/deepbooru.py index 6ccb848b1..878b3852a 100644 --- a/modules/caption/deepbooru.py +++ b/modules/caption/deepbooru.py @@ -53,13 +53,13 @@ class DeepDanbooru: def tag_multi( self, pil_image, - general_threshold: float = None, - include_rating: bool = None, - exclude_tags: str = None, - max_tags: int = None, - sort_alpha: bool = None, - use_spaces: bool = None, - escape_brackets: bool = None, + general_threshold: float | None = None, + include_rating: bool | None = None, + exclude_tags: str | None = None, + max_tags: int | None = None, + sort_alpha: bool | None = None, + use_spaces: bool | None = None, + escape_brackets: bool | None = None, ): """Run inference and return formatted tag string. @@ -134,7 +134,7 @@ def get_models() -> list: return ["DeepBooru"] -def load_model(model_name: str = None) -> bool: # pylint: disable=unused-argument +def load_model(model_name: str | None = None) -> bool: # pylint: disable=unused-argument """Load the DeepBooru model.""" try: model.load() diff --git a/modules/caption/joycaption.py b/modules/caption/joycaption.py index dbee83eba..a18b46800 100644 --- a/modules/caption/joycaption.py +++ b/modules/caption/joycaption.py @@ -53,12 +53,12 @@ class JoyOptions: return f'repo="{self.repo}" temp={self.temp} top_k={self.top_k} top_p={self.top_p} sample={self.sample} tokens={self.max_new_tokens}' -processor: AutoProcessor = None -llava_model: LlavaForConditionalGeneration = None +processor: AutoProcessor | None = None +llava_model: LlavaForConditionalGeneration | None = None opts = JoyOptions() -def load(repo: str = None): +def load(repo: str | None = None): """Load JoyCaption model.""" global llava_model, processor # pylint: disable=global-statement repo = repo or opts.repo @@ -93,7 +93,7 @@ def unload(): log.debug('JoyCaption unload: no model loaded') -def predict(question: str, image, vqa_model: str = None) -> str: +def predict(question: str, image, vqa_model: str | None = None) -> str: 
opts.max_new_tokens = shared.opts.caption_vlm_max_length load(vqa_model) diff --git a/modules/caption/moondream3.py b/modules/caption/moondream3.py index 20418134c..7859f0f06 100644 --- a/modules/caption/moondream3.py +++ b/modules/caption/moondream3.py @@ -87,7 +87,7 @@ def _image_hash(image: Image.Image) -> str: return h.hexdigest() -def encode_image(image: Image.Image, cache_key: str = None): +def encode_image(image: Image.Image, cache_key: str | None = None): """ Encode image for reuse across multiple queries. @@ -119,7 +119,7 @@ def encode_image(image: Image.Image, cache_key: str = None): def query(image: Image.Image, question: str, repo: str, stream: bool = False, - temperature: float = None, top_p: float = None, max_tokens: int = None, + temperature: float | None = None, top_p: float | None = None, max_tokens: int | None = None, use_cache: bool = False, reasoning: bool = True): """ Visual question answering with optional streaming. @@ -180,7 +180,7 @@ def query(image: Image.Image, question: str, repo: str, stream: bool = False, def caption(image: Image.Image, repo: str, length: str = 'normal', stream: bool = False, - temperature: float = None, top_p: float = None, max_tokens: int = None): + temperature: float | None = None, top_p: float | None = None, max_tokens: int | None = None): """ Generate image captions at different lengths. @@ -290,8 +290,8 @@ def detect(image: Image.Image, object_name: str, repo: str, max_objects: int = 1 return detections -def predict(question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False, - mode: str = None, stream: bool = False, use_cache: bool = False, **kwargs): +def predict(question: str, image: Image.Image, repo: str, model_name: str | None = None, thinking_mode: bool = False, + mode: str | None = None, stream: bool = False, use_cache: bool = False, **kwargs): """ Main entry point for Moondream 3 VQA - auto-detects mode from question. 
diff --git a/modules/caption/tagger.py b/modules/caption/tagger.py index e73e7df4d..6f2436b38 100644 --- a/modules/caption/tagger.py +++ b/modules/caption/tagger.py @@ -65,7 +65,7 @@ def unload_model(): waifudiffusion.unload_model() -def tag(image, model_name: str = None, **kwargs) -> str: +def tag(image, model_name: str | None = None, **kwargs) -> str: """Unified tagging - dispatch to correct backend. Args: diff --git a/modules/caption/vqa.py b/modules/caption/vqa.py index a10818624..807479144 100644 --- a/modules/caption/vqa.py +++ b/modules/caption/vqa.py @@ -50,7 +50,7 @@ def get_prompts_for_model(model_name: str) -> list: return vlm_prompts_common -def get_internal_prompt(friendly_name: str, user_prompt: str = None) -> str: +def get_internal_prompt(friendly_name: str, user_prompt: str | None = None) -> str: """Convert friendly prompt name to internal token/command.""" internal = vlm_prompt_mapping.get(friendly_name, friendly_name) @@ -350,7 +350,7 @@ class VQA: self.processor = None devices.torch_gc(force=True, reason='vqa model switch') - def load(self, model_name: str = None): + def load(self, model_name: str | None = None): """Load VLM model into memory for the specified model name.""" model_name = model_name or shared.opts.caption_vlm_model if not model_name: @@ -444,7 +444,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str = None): + def _fastvlm(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): debug(f'VQA caption: handler=fastvlm model_name="{model_name}" repo="{repo}" question="{question}" image_size={image.size if image else None}') self._load_fastvlm(repo) move_aux_to_gpu('vqa') @@ -521,7 +521,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _qwen(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False): + def _qwen(self, 
question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False): self._load_qwen(repo) move_aux_to_gpu('vqa') # Get model class name for logging @@ -644,7 +644,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False): + def _gemma(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False): self._load_gemma(repo) move_aux_to_gpu('vqa') # Get model class name for logging @@ -757,7 +757,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _mistral(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False): + def _mistral(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False): self._load_mistral(repo) move_aux_to_gpu('vqa') cls_name = self.model.__class__.__name__ @@ -828,7 +828,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _paligemma(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument self._load_paligemma(repo) move_aux_to_gpu('vqa') question = question.replace('<', '').replace('>', '').replace('_', ' ') @@ -870,7 +870,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _ovis(self, question: str, image: Image.Image, repo: str, model_name: str | 
None = None): # pylint: disable=unused-argument try: pass # pylint: disable=unused-import except Exception: @@ -925,7 +925,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str = None, model_name: str = None, prefill: str = None, thinking_mode: bool = False): + def _smol(self, question: str, image: Image.Image, repo: str, system_prompt: str | None = None, model_name: str | None = None, prefill: str | None = None, thinking_mode: bool = False): self._load_smol(repo) move_aux_to_gpu('vqa') # Get model class name for logging @@ -1019,7 +1019,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _git(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _git(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument self._load_git(repo) move_aux_to_gpu('vqa') pixel_values = self.processor(images=image, return_tensors="pt").pixel_values @@ -1053,7 +1053,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _blip(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _blip(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument self._load_blip(repo) move_aux_to_gpu('vqa') inputs = self.processor(image, question, return_tensors="pt") @@ -1081,7 +1081,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _vilt(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument self._load_vilt(repo) move_aux_to_gpu('vqa') inputs = self.processor(image, question, return_tensors="pt") @@ -1111,7 +1111,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def 
_pix(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _pix(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument self._load_pix(repo) move_aux_to_gpu('vqa') if len(question) > 0: @@ -1144,7 +1144,7 @@ class VQA: register_aux('vqa', self.model) devices.torch_gc() - def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str = None, thinking_mode: bool = False): + def _moondream(self, question: str, image: Image.Image, repo: str, model_name: str | None = None, thinking_mode: bool = False): debug(f'VQA caption: handler=moondream model_name="{model_name}" repo="{repo}" question="{question}" thinking_mode={thinking_mode}') self._load_moondream(repo) move_aux_to_gpu('vqa') @@ -1207,7 +1207,7 @@ class VQA: # When keep_thinking is False, just use the answer (reasoning is discarded) return response - def _load_florence(self, repo: str, revision: str = None): + def _load_florence(self, repo: str, revision: str | None = None): """Load Florence-2 model and processor.""" _get_imports = transformers.dynamic_module_utils.get_imports @@ -1247,7 +1247,7 @@ class VQA: self.loaded = cache_key devices.torch_gc() - def _florence(self, question: str, image: Image.Image, repo: str, revision: str = None, model_name: str = None): # pylint: disable=unused-argument + def _florence(self, question: str, image: Image.Image, repo: str, revision: str | None = None, model_name: str | None = None): # pylint: disable=unused-argument self._load_florence(repo, revision) move_aux_to_gpu('vqa') if question.startswith('<'): @@ -1306,7 +1306,7 @@ class VQA: self.loaded = repo devices.torch_gc() - def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str = None): # pylint: disable=unused-argument + def _sa2(self, question: str, image: Image.Image, repo: str, model_name: str | None = None): # pylint: disable=unused-argument 
self._load_sa2(repo) move_aux_to_gpu('vqa') if question.startswith('<'): @@ -1325,7 +1325,18 @@ class VQA: response = return_dict["prediction"] # the text format answer return response - def caption(self, question: str = '', system_prompt: str = None, prompt: str = None, image: Image.Image = None, model_name: str = None, prefill: str = None, thinking_mode: bool = None, quiet: bool = False, generation_kwargs: dict = None) -> str: + def caption( + self, + question: str = "", + system_prompt: str | None = None, + prompt: str | None = None, + image: list[Image.Image] | Image.Image | dict | None = None, + model_name: str | None = None, + prefill: str | None = None, + thinking_mode: bool | None = None, + quiet: bool = False, + generation_kwargs: dict | None = None, + ) -> str: """ Main entry point for VQA captioning. Returns string answer. Detection data stored in self.last_detection_data for annotated image creation. @@ -1596,7 +1607,7 @@ def unload_model(): return get_instance().unload() -def load_model(model_name: str = None): +def load_model(model_name: str | None = None): return get_instance().load(model_name) diff --git a/modules/caption/vqa_detection.py b/modules/caption/vqa_detection.py index d0aadf970..683067250 100644 --- a/modules/caption/vqa_detection.py +++ b/modules/caption/vqa_detection.py @@ -38,7 +38,7 @@ def parse_points(result) -> list: return points -def parse_detections(result, label: str, max_objects: int = None) -> list: +def parse_detections(result, label: str, max_objects: int | None = None) -> list: """Parse and validate detection bboxes from model result. Args: @@ -74,7 +74,7 @@ def parse_detections(result, label: str, max_objects: int = None) -> list: return detections -def parse_florence_detections(response, image_size: tuple = None) -> list: +def parse_florence_detections(response, image_size: tuple | None = None) -> list: """Parse Florence-style detection response into standard detection format. 
Florence returns detection data in two possible formats: @@ -286,7 +286,7 @@ def calculate_eye_position(face_bbox: dict) -> tuple: return (eye_x, eye_y) -def draw_bounding_boxes(image: Image.Image, detections: list, points: list = None) -> Image.Image: +def draw_bounding_boxes(image: Image.Image, detections: list, points: list | None = None) -> Image.Image: """ Draw bounding boxes and/or points on an image. diff --git a/modules/caption/waifudiffusion.py b/modules/caption/waifudiffusion.py index 3e284db75..4fe455553 100644 --- a/modules/caption/waifudiffusion.py +++ b/modules/caption/waifudiffusion.py @@ -50,7 +50,7 @@ class WaifuDiffusionTagger: self.model_path = None self.image_size = 448 # Standard for WD models - def load(self, model_name: str = None): + def load(self, model_name: str | None = None): """Load the ONNX model and tags from HuggingFace.""" import huggingface_hub @@ -195,15 +195,15 @@ class WaifuDiffusionTagger: def predict( self, - image: Image.Image, - general_threshold: float = None, - character_threshold: float = None, - include_rating: bool = None, - exclude_tags: str = None, - max_tags: int = None, - sort_alpha: bool = None, - use_spaces: bool = None, - escape_brackets: bool = None, + image: Image.Image | list[Image.Image] | dict | None, + general_threshold: float | None = None, + character_threshold: float | None = None, + include_rating: bool | None = None, + exclude_tags: str | None = None, + max_tags: int | None = None, + sort_alpha: bool | None = None, + use_spaces: bool | None = None, + escape_brackets: bool | None = None, ) -> str: """Run inference and return formatted tag string. 
@@ -352,7 +352,7 @@ def refresh_models() -> list: return get_models() -def load_model(model_name: str = None) -> bool: +def load_model(model_name: str | None = None) -> bool: """Load the specified WaifuDiffusion model.""" return tagger.load(model_name) @@ -362,7 +362,7 @@ def unload_model(): tagger.unload() -def tag(image: Image.Image, model_name: str = None, **kwargs) -> str: +def tag(image: Image.Image, model_name: str | None = None, **kwargs) -> str: """Tag an image using WaifuDiffusion tagger. Args: