use bicubic interpolation for masks
instead of nearest by default to reduce pixelation
pull/696/head
parent
17df392c25
commit
c19b4e92e2
|
|
@ -41,6 +41,6 @@ def get_word_mask(root, frame, word_mask):
|
||||||
preds = root.clipseg_model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
preds = root.clipseg_model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
|
||||||
|
|
||||||
mask = torch.sigmoid(preds[0][0]).unsqueeze(0).unsqueeze(0) # add batch, channels dims
|
mask = torch.sigmoid(preds[0][0]).unsqueeze(0).unsqueeze(0) # add batch, channels dims
|
||||||
resized_mask = interpolate(mask, size=(frame.size[1], frame.size[0])).squeeze() # rescale mask back to the target resolution
|
resized_mask = interpolate(mask, size=(frame.size[1], frame.size[0]), mode='bicubic').squeeze() # rescale mask back to the target resolution
|
||||||
numpy_array = resized_mask.multiply(255).to(dtype=torch.uint8,device='cpu').numpy()
|
numpy_array = resized_mask.multiply(255).to(dtype=torch.uint8,device='cpu').numpy()
|
||||||
return Image.fromarray(cv2.threshold(numpy_array, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1])
|
return Image.fromarray(cv2.threshold(numpy_array, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue