diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index a1ce1a1..82a0711 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -2,11 +2,12 @@ import os os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -import warnings +import warnings # noqa: E402 -warnings.simplefilter("ignore", UserWarning) +warnings.filterwarnings("ignore", category=UserWarning, module="pydantic") +warnings.filterwarnings("ignore", category=UserWarning, module="lama_cleaner") -from lama_cleaner.parse_args import parse_args +from lama_cleaner.parse_args import parse_args # noqa: E402 def entry_point(): diff --git a/mobile_sam/__init__.py b/mobile_sam/__init__.py index 7140702..6e2138b 100644 --- a/mobile_sam/__init__.py +++ b/mobile_sam/__init__.py @@ -4,16 +4,14 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from .build_sam import ( - build_sam, - build_sam_vit_h, - build_sam_vit_l, - build_sam_vit_b, - build_sam_vit_t, - sam_model_registry, -) -from .predictor import SamPredictor -from .automatic_mask_generator import SamAutomaticMaskGenerator +import warnings + +warnings.filterwarnings("ignore", category=UserWarning, module="mobile_sam") + +from .automatic_mask_generator import SamAutomaticMaskGenerator # noqa: E402 +from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h, build_sam_vit_l, # noqa: E402 + build_sam_vit_t, sam_model_registry) +from .predictor import SamPredictor # noqa: E402 __all__ = [ "build_sam", diff --git a/mobile_sam/modeling/tiny_vit_sam.py b/mobile_sam/modeling/tiny_vit_sam.py index 79205db..f67e3ba 100644 --- a/mobile_sam/modeling/tiny_vit_sam.py +++ b/mobile_sam/modeling/tiny_vit_sam.py @@ -8,7 +8,6 @@ # -------------------------------------------------------- import itertools -import warnings from typing import Tuple import torch @@ -19,8 +18,6 @@ from timm.models.layers import DropPath as TimmDropPath from timm.models.layers import to_2tuple, trunc_normal_ from timm.models.registry import register_model -warnings.simplefilter("ignore", category=UserWarning) - class Conv2d_BN(torch.nn.Sequential): def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,