Use device by reference for dino (#189)
parent
d80220ecd2
commit
5f69b959ee
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from collections import OrderedDict
|
||||
|
||||
from modules import scripts, shared
|
||||
from modules.devices import device, torch_gc, cpu
|
||||
from modules import devices
|
||||
import local_groundingdino
|
||||
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
|
|||
def clear_dino_cache():
|
||||
dino_model_cache.clear()
|
||||
gc.collect()
|
||||
torch_gc()
|
||||
devices.torch_gc()
|
||||
|
||||
|
||||
def load_dino_model(dino_checkpoint, dino_install_success):
|
||||
|
|
@ -104,7 +104,7 @@ def load_dino_model(dino_checkpoint, dino_install_success):
|
|||
if dino_checkpoint in dino_model_cache:
|
||||
dino = dino_model_cache[dino_checkpoint]
|
||||
if shared.cmd_opts.lowvram:
|
||||
dino.to(device=device)
|
||||
dino.to(device=devices.device)
|
||||
else:
|
||||
clear_dino_cache()
|
||||
if dino_install_success:
|
||||
|
|
@ -121,7 +121,7 @@ def load_dino_model(dino_checkpoint, dino_install_success):
|
|||
dino_model_info[dino_checkpoint]["url"], dino_model_dir)
|
||||
dino.load_state_dict(clean_state_dict(
|
||||
checkpoint['model']), strict=False)
|
||||
dino.to(device=device)
|
||||
dino.to(device=devices.device)
|
||||
dino_model_cache[dino_checkpoint] = dino
|
||||
dino.eval()
|
||||
return dino
|
||||
|
|
@ -148,11 +148,11 @@ def get_grounding_output(model, image, caption, box_threshold):
|
|||
caption = caption.strip()
|
||||
if not caption.endswith("."):
|
||||
caption = caption + "."
|
||||
image = image.to(device)
|
||||
image = image.to(devices.device)
|
||||
with torch.no_grad():
|
||||
outputs = model(image[None], captions=[caption])
|
||||
if shared.cmd_opts.lowvram:
|
||||
model.to(cpu)
|
||||
model.to(devices.cpu)
|
||||
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
|
||||
boxes = outputs["pred_boxes"][0] # (nq, 4)
|
||||
|
||||
|
|
@ -183,5 +183,5 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho
|
|||
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
|
||||
boxes_filt[i][2:] += boxes_filt[i][:2]
|
||||
gc.collect()
|
||||
torch_gc()
|
||||
devices.torch_gc()
|
||||
return boxes_filt, install_success
|
||||
|
|
|
|||
Loading…
Reference in New Issue