diff --git a/scripts/deforum_helpers/word_masking.py b/scripts/deforum_helpers/word_masking.py index 50752d19..342a7c76 100644 --- a/scripts/deforum_helpers/word_masking.py +++ b/scripts/deforum_helpers/word_masking.py @@ -41,6 +41,6 @@ def get_word_mask(root, frame, word_mask): preds = root.clipseg_model(img.repeat(len(word_masks),1,1,1), word_masks)[0] mask = torch.sigmoid(preds[0][0]).unsqueeze(0).unsqueeze(0) # add batch, channels dims - resized_mask = interpolate(mask, size=(frame.size[1], frame.size[0])).squeeze() # rescale mask back to the target resolution + resized_mask = interpolate(mask, size=(frame.size[1], frame.size[0]), mode='bicubic').squeeze() # rescale mask back to the target resolution numpy_array = resized_mask.multiply(255).to(dtype=torch.uint8,device='cpu').numpy() return Image.fromarray(cv2.threshold(numpy_array, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1])