diff --git a/XTI_hijack.py b/XTI_hijack.py index ec08494..1dbc263 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,7 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index be61b3d..982dc8a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -11,15 +11,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index be43847..a207ad5 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,15 +66,10 @@ import diffusers import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/library/ipex_interop.py b/library/ipex_interop.py new file mode 100644 index 0000000..6fe320c --- /dev/null +++ b/library/ipex_interop.py @@ -0,0 +1,24 @@ +import torch + + +def init_ipex(): + """ + Try to import `intel_extension_for_pytorch`, and apply + the hijacks using `library.ipex.ipex_init`. + + If IPEX is not installed, this function does nothing. + """ + try: + import intel_extension_for_pytorch as ipex # noqa + except ImportError: + return + + try: + from library.ipex import ipex_init + + if torch.xpu.is_available(): + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 1f40ce3..4361b49 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,15 +5,9 @@ import math import os import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass +init_ipex() import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab53998..0db9e34 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,15 +18,10 @@ import diffusers import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd..15a7067 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,13 +9,11 @@ import random from einops import repeat import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index b4ce277..a3f6f3a 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,15 +11,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4436dd3..7a88feb 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,13 +14,11 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 6ae5377..b94bf5c 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,13 +11,11 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d810ce7..5d36328 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,15 +1,10 @@ import argparse import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f8a1d7b..df39371 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,13 +3,9 @@ import os import regex import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/train_controlnet.py b/train_controlnet.py index cc0eaab..7b0b2bb 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -12,15 +12,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index 14d9dff..888cad2 100644 --- a/train_db.py +++ b/train_db.py @@ -12,15 +12,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index ef7d419..7529aac 100644 --- a/train_network.py +++ b/train_network.py @@ -14,15 +14,10 @@ from tqdm import tqdm import torch from torch.nn.parallel import DistributedDataParallel as DDP -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f1cf6fb..441c1e0 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,15 +8,10 @@ import toml from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 71b4354..7046a48 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,13 +8,11 @@ from multiprocessing import Value from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler