diff --git a/scripts/td_abg.py b/scripts/td_abg.py index 7ae90e9..417da63 100644 --- a/scripts/td_abg.py +++ b/scripts/td_abg.py @@ -98,6 +98,8 @@ def get_foreground(img, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_r print(mask.shape) else: mask = get_mask(img) + mask = (mask * 255).astype(np.uint8) + mask = mask.repeat(3, axis=2) mask = refinement(img, mask, fast, psp_L) mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) @@ -132,7 +134,7 @@ def get_foreground(img, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_r img_df.loc[img_df['bg_cls'] == 0, ['a']] = 0 img_df.loc[img_df['bg_cls'] != 0, ['a']] = 255 img = df2rgba(img_df) - + if cascadePSP_enabled == True and td_abg_enabled == False: if sa_enabled == True: mask = get_sa_mask(img, query, model_name, predicted_iou_threshold, stability_score_threshold, clip_threshold) @@ -149,11 +151,7 @@ def get_foreground(img, td_abg_enabled, h_split, v_split, n_cluster, alpha, th_r if cascadePSP_enabled == False and td_abg_enabled == False: mask, img = rmbg_fn(img) - + + mask = img[:, :, 3] return mask, img - - - - -