diff --git a/scripts/dino.py b/scripts/dino.py index a034196..11a4fc9 100644 --- a/scripts/dino.py +++ b/scripts/dino.py @@ -6,7 +6,7 @@ import torch from collections import OrderedDict from modules import scripts, shared -from modules.devices import device, torch_gc, cpu +from modules import devices import local_groundingdino @@ -96,7 +96,7 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index= def clear_dino_cache(): dino_model_cache.clear() gc.collect() - torch_gc() + devices.torch_gc() def load_dino_model(dino_checkpoint, dino_install_success): @@ -104,7 +104,7 @@ def load_dino_model(dino_checkpoint, dino_install_success): if dino_checkpoint in dino_model_cache: dino = dino_model_cache[dino_checkpoint] if shared.cmd_opts.lowvram: - dino.to(device=device) + dino.to(device=devices.device) else: clear_dino_cache() if dino_install_success: @@ -121,7 +121,7 @@ def load_dino_model(dino_checkpoint, dino_install_success): dino_model_info[dino_checkpoint]["url"], dino_model_dir) dino.load_state_dict(clean_state_dict( checkpoint['model']), strict=False) - dino.to(device=device) + dino.to(device=devices.device) dino_model_cache[dino_checkpoint] = dino dino.eval() return dino @@ -148,11 +148,11 @@ def get_grounding_output(model, image, caption, box_threshold): caption = caption.strip() if not caption.endswith("."): caption = caption + "." - image = image.to(device) + image = image.to(devices.device) with torch.no_grad(): outputs = model(image[None], captions=[caption]) if shared.cmd_opts.lowvram: - model.to(cpu) + model.to(devices.cpu) logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256) boxes = outputs["pred_boxes"][0] # (nq, 4) @@ -183,5 +183,5 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 boxes_filt[i][2:] += boxes_filt[i][:2] gc.collect() - torch_gc() + devices.torch_gc() return boxes_filt, install_success