Fix type (#2691)
parent
b379b90ab4
commit
2091b6fb21
|
|
@ -83,10 +83,10 @@ downloads = {
|
|||
|
||||
|
||||
clip_vision_h_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_h_uc.data')
|
||||
clip_vision_h_uc = torch.load(clip_vision_h_uc, map_location=torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else 'cpu'))['uc']
|
||||
clip_vision_h_uc = torch.load(clip_vision_h_uc, map_location=devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device('cpu'))['uc']
|
||||
|
||||
clip_vision_vith_uc = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'clip_vision_vith_uc.data')
|
||||
clip_vision_vith_uc = torch.load(clip_vision_vith_uc, map_location=torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else 'cpu'))['uc']
|
||||
clip_vision_vith_uc = torch.load(clip_vision_vith_uc, map_location=devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device('cpu'))['uc']
|
||||
|
||||
|
||||
class ClipVisionDetector:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class CrossEntropy2d(nn.Module):
|
|||
self.ignore_label = ignore_label
|
||||
self.weights = weights
|
||||
if self.weights is not None:
|
||||
device = torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else 'cpu')
|
||||
device = devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device('cpu')
|
||||
self.weights = torch.FloatTensor(constant_weights[weights]).to(device)
|
||||
|
||||
def forward(self, predict, target):
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op
|
|||
print("Initialize")
|
||||
|
||||
# select device
|
||||
device = torch.device(devices.get_device_for("controlnet") if torch.cuda.is_available() else "cpu")
|
||||
device = devices.get_device_for("controlnet") if torch.cuda.is_available() else torch.device("cpu")
|
||||
print("Device: %s" % device)
|
||||
|
||||
model, transform, net_w, net_h = load_model(device, model_path, model_type, optimize, height, square)
|
||||
|
|
|
|||
Loading…
Reference in New Issue