ImageReward reactivated. Fix #179

pull/190/head
AlUlkesh 2023-06-02 00:16:29 +02:00
parent 75af6d0c32
commit ba9bcee144
3 changed files with 19 additions and 16 deletions

View File

@ -1,8 +1,9 @@
import launch
import os
if not launch.is_installed("send2trash"):
launch.run_pip("install Send2Trash", "Send2Trash requirement for image browser")
# temporarily deactivated
#if not launch.is_installed("ImageReward"):
#launch.run_pip("install image-reward", "ImageReward requirement for image browser")
if not launch.is_installed("ImageReward"):
req_IR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "req_IR.txt")
launch.run_pip(f'install -r "{req_IR}" --no-deps image-reward', 'ImageReward requirement for image browser')

1
req_IR.txt Normal file
View File

@ -0,0 +1 @@
fairscale

View File

@ -47,9 +47,7 @@ except ImportError:
try:
import ImageReward
# temporarily deactivated
# image_reward_installed = True
image_reward_installed = False
image_reward_installed = True
except ImportError:
print("Image Browser: ImageReward is not installed, cannot be used.")
image_reward_installed = False
@ -1072,16 +1070,19 @@ def generate_image_reward(filenames, turn_page_switch, aes_filter_min, aes_filte
for filename in filenames:
saved_image_reward_score, saved_image_reward_prompt = wib_db.select_image_reward_score(cursor, filename)
if saved_image_reward_score is None and saved_image_reward_prompt is not None:
with torch.no_grad():
image_reward_score = image_reward_model.score(saved_image_reward_prompt, filename)
image_reward_score = f"{image_reward_score:.2f}"
try:
logger.warning(f"Generated ImageRewardScore: {image_reward_score} for {filename}")
except UnicodeEncodeError:
pass
wib_db.update_image_reward_score(cursor, filename, image_reward_score)
if any(filename.endswith(ext) for ext in image_ext_list):
img_file_info = update_exif(filename, "ImageRewardScore", image_reward_score)
with torch.no_grad():
image_reward_score = image_reward_model.score(saved_image_reward_prompt, filename)
image_reward_score = f"{image_reward_score:.2f}"
try:
logger.warning(f"Generated ImageRewardScore: {image_reward_score} for {filename}")
except UnicodeEncodeError:
pass
wib_db.update_image_reward_score(cursor, filename, image_reward_score)
if any(filename.endswith(ext) for ext in image_ext_list):
img_file_info = update_exif(filename, "ImageRewardScore", image_reward_score)
except UnidentifiedImageError as e:
logger.warning(f"UnidentifiedImageError: {e}")
wib_db.transaction_end(conn, cursor)
return -turn_page_switch, aes_filter_min, aes_filter_max
@ -1635,7 +1636,7 @@ def on_ui_tabs():
create_tab(tab, current_gr_tab)
gr.Checkbox(value=opts.image_browser_preload, elem_id="image_browser_preload", visible=False)
gr.Textbox(",".join( [tab.base_tag for tab in tabs_list] ), elem_id="image_browser_tab_base_tags_list", visible=False)
gr.Checkbox(value=opts.image_browser_swipe, elem_id=f"image_browser_swipe")
gr.Checkbox(value=opts.image_browser_swipe, elem_id=f"image_browser_swipe", visible=False)
javascript_level_value, (javascript_level, javascript_level_text) = debug_levels(arg_level="javascript")
level_value, (level, level_text) = debug_levels(arg_text=opts.image_browser_debug_level)