pull/2708/head
Chenlei Hu 2024-03-15 02:32:35 +00:00 committed by GitHub
parent b379b90ab4
commit 2091b6fb21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 4 deletions

View File

@ -83,10 +83,10 @@ downloads = {
clip_vision_h_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_h_uc.data')
clip_vision_h_uc = torch.load(clip_vision_h_uc, map_location=torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else 'cpu'))['uc']
clip_vision_h_uc = torch.load(clip_vision_h_uc, map_location=devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device('cpu'))['uc']
clip_vision_vith_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_vith_uc.data')
clip_vision_vith_uc = torch.load(clip_vision_vith_uc, map_location=torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else 'cpu'))['uc']
clip_vision_vith_uc = torch.load(clip_vision_vith_uc, map_location=devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device('cpu'))['uc']
class ClipVisionDetector:

View File

@ -18,7 +18,7 @@ class CrossEntropy2d(nn.Module):
self.ignore_label = ignore_label
self.weights = weights
if self.weights is not None:
device = torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else 'cpu')
device = devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device('cpu')
self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
def forward(self, predict, target):

View File

@ -122,7 +122,7 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op
print("Initialize")
# select device
device = torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else "cpu")
device = devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device("cpu")
print("Device: %s" % device)
model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square)