Use device by reference for dino (#189)

pull/190/head
Andray 2024-02-07 02:24:41 +04:00 committed by GitHub
parent d80220ecd2
commit 5f69b959ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 7 deletions

View File

@ -6,7 +6,7 @@ import torch
from collections import OrderedDict
from modules import scripts, shared
from modules.devices import device, torch_gc, cpu
from modules import devices
import local_groundingdino
@ -96,7 +96,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
def clear_dino_cache():
dino_model_cache.clear()
gc.collect()
torch_gc()
devices.torch_gc()
def load_dino_model(dino_checkpoint, dino_install_success):
@ -104,7 +104,7 @@ def load_dino_model(dino_checkpoint, dino_install_success):
if dino_checkpoint in dino_model_cache:
dino = dino_model_cache[dino_checkpoint]
if shared.cmd_opts.lowvram:
dino.to(device=device)
dino.to(device=devices.device)
else:
clear_dino_cache()
if dino_install_success:
@ -121,7 +121,7 @@ def load_dino_model(dino_checkpoint, dino_install_success):
dino_model_info[dino_checkpoint]["url"], dino_model_dir)
dino.load_state_dict(clean_state_dict(
checkpoint['model']), strict=False)
dino.to(device=device)
dino.to(device=devices.device)
dino_model_cache[dino_checkpoint] = dino
dino.eval()
return dino
@ -148,11 +148,11 @@ def get_grounding_output(model, image, caption, box_threshold):
caption = caption.strip()
if not caption.endswith("."):
caption = caption + "."
image = image.to(device)
image = image.to(devices.device)
with torch.no_grad():
outputs = model(image[None], captions=[caption])
if shared.cmd_opts.lowvram:
model.to(cpu)
model.to(devices.cpu)
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
boxes = outputs["pred_boxes"][0] # (nq, 4)
@ -183,5 +183,5 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
boxes_filt[i][2:] += boxes_filt[i][:2]
gc.collect()
torch_gc()
devices.torch_gc()
return boxes_filt, install_success