enable GPU support

pull/101/head
Somdev Sangwan 2023-06-22 16:17:58 +05:30 committed by GitHub
parent 124f7d7acf
commit a116b2cb71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 4 deletions

View File

@ -17,7 +17,7 @@ from modules.face_restoration import FaceRestoration, restore_faces
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from scripts.roop_logging import logger from scripts.roop_logging import logger
providers = ["CPUExecutionProvider"] providers = onnxruntime.get_available_providers()
@dataclass @dataclass
@ -28,8 +28,6 @@ class UpscaleOptions:
face_restorer: FaceRestoration = None face_restorer: FaceRestoration = None
restorer_visibility: float = 0.5 restorer_visibility: float = 0.5
ANALYSIS_MODEL = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers)
FS_MODEL = None FS_MODEL = None
CURRENT_FS_MODEL_PATH = None CURRENT_FS_MODEL_PATH = None
@ -75,7 +73,7 @@ def upscale_image(image: Image, upscale_options: UpscaleOptions):
def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)): def get_face_single(img_data: np.ndarray, face_index=0, det_size=(640, 640)):
face_analyser = copy.deepcopy(ANALYSIS_MODEL) face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", providers=providers)
face_analyser.prepare(ctx_id=0, det_size=det_size) face_analyser.prepare(ctx_id=0, det_size=det_size)
face = face_analyser.get(img_data) face = face_analyser.get(img_data)