ImageReward removed, #216

pull/226/head
AlUlkesh 2023-07-25 08:51:27 +02:00
parent 6386e0c8f5
commit b984cdd169
5 changed files with 13 additions and 120 deletions

View File

@ -18,7 +18,6 @@ Please be aware that when scanning a directory for the first time, the png-cache
## Recent updates ## Recent updates
- "All"-tab showing all the images from all tabs combined - "All"-tab showing all the images from all tabs combined
- Image Reward scoring
- Size tooltip for thumbnails - Size tooltip for thumbnails
- Optimized images in the thumbnail interface - Optimized images in the thumbnail interface
- Send to ControlNet - Send to ControlNet

View File

@ -1,18 +1,4 @@
import launch import launch
import os
if not launch.is_installed("send2trash"): if not launch.is_installed("send2trash"):
launch.run_pip("install Send2Trash", "Send2Trash requirement for image browser") launch.run_pip("install Send2Trash", "Send2Trash requirement for image browser")
try:
import ImageReward
import datasets
import dill
import diffusers
import multiprocessing
import pyarrow
import xxhash
except (ImportError, ModuleNotFoundError) as e:
#print(e)
req_IR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "req_IR.txt")
launch.run_pip(f'install -r "{req_IR}" --no-deps image-reward', 'ImageReward requirement for image browser')

View File

@ -1,7 +0,0 @@
datasets
diffusers
dill
fairscale
multiprocess
pyarrow
xxhash

View File

@ -45,16 +45,6 @@ except ImportError:
print("Image Browser: send2trash is not installed. recycle bin cannot be used.") print("Image Browser: send2trash is not installed. recycle bin cannot be used.")
send2trash_installed = False send2trash_installed = False
try:
import ImageReward
import pyarrow._s3fs
pyarrow._s3fs.finalize_s3()
image_reward_installed = True
except ImportError as e:
print("Image Browser: ImageReward components are not installed, cannot be used.")
print(e)
image_reward_installed = False
# Force reload wib_db, as it doesn't get reloaded otherwise, if an extension update is started from webui # Force reload wib_db, as it doesn't get reloaded otherwise, if an extension update is started from webui
importlib.reload(wib_db) importlib.reload(wib_db)
@ -82,7 +72,6 @@ openoutpaint = False
controlnet = False controlnet = False
js_dummy_return = None js_dummy_return = None
log_file = os.path.join(scripts.basedir(), "image_browser.log") log_file = os.path.join(scripts.basedir(), "image_browser.log")
image_reward_model = None
db_version = wib_db.check() db_version = wib_db.check()
@ -509,7 +498,7 @@ def cache_exif(fileinfos):
wib_db.update_exif_data(conn, fi_info[0], allExif) wib_db.update_exif_data(conn, fi_info[0], allExif)
new_exif = new_exif + 1 new_exif = new_exif + 1
m = re.search("(?:aesthetic_score:|Score:) (\d+.\d+)", allExif) m = re.search("(?:aesthetic_score:|Score:) (\d+.\d+)", allExif, flags=re.IGNORECASE)
if m: if m:
aes_value = m.group(1) aes_value = m.group(1)
else: else:
@ -528,7 +517,7 @@ def cache_exif(fileinfos):
wib_db.update_exif_data_by_key(conn, fi_info[0], geninfo) wib_db.update_exif_data_by_key(conn, fi_info[0], geninfo)
new_exif = new_exif + 1 new_exif = new_exif + 1
m = re.search("(?:aesthetic_score:|Score:) (\d+.\d+)", geninfo) m = re.search("(?:aesthetic_score:|Score:) (\d+.\d+)", geninfo, flags=re.IGNORECASE)
if m: if m:
aes_value = m.group(1) aes_value = m.group(1)
else: else:
@ -729,7 +718,7 @@ def exif_search(needle, haystack, use_regex, case_sensitive):
found = True found = True
return found return found
def get_all_images(dir_name, sort_by, sort_order, keyword, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, score_type, exif_keyword, negative_prompt_search, use_regex, case_sensitive): def get_all_images(dir_name, sort_by, sort_order, keyword, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, exif_keyword, negative_prompt_search, use_regex, case_sensitive):
global current_depth global current_depth
logger.debug("get_all_images") logger.debug("get_all_images")
current_depth = 0 current_depth = 0
@ -798,7 +787,7 @@ def get_all_images(dir_name, sort_by, sort_order, keyword, tab_base_tag_box, img
except ValueError: except ValueError:
aes_filter_max_num = sys.float_info.max aes_filter_max_num = sys.float_info.max
fileinfos = wib_db.filter_aes(cursor, fileinfos, aes_filter_min_num, aes_filter_max_num, score_type) fileinfos = wib_db.filter_aes(cursor, fileinfos, aes_filter_min_num, aes_filter_max_num)
filenames = [finfo[0] for finfo in fileinfos] filenames = [finfo[0] for finfo in fileinfos]
if ranking_filter != "All": if ranking_filter != "All":
ranking_filter_min_num = 1 ranking_filter_min_num = 1
@ -856,7 +845,7 @@ def get_all_images(dir_name, sort_by, sort_order, keyword, tab_base_tag_box, img
sort_values[k] = match.group().strip() sort_values[k] = match.group().strip()
else: else:
sort_values[k] = "0" sort_values[k] = "0"
if sort_by == "aesthetic_score" or sort_by == "ImageRewardScore" or sort_by == "cfg scale": if sort_by == "aesthetic_score" or sort_by == "cfg scale":
sort_float = True sort_float = True
else: else:
sort_float = False sort_float = False
@ -924,17 +913,17 @@ def set_tooltip_info(image_list):
image_browser_img_info_json = json.dumps(image_browser_img_info) image_browser_img_info_json = json.dumps(image_browser_img_info)
return image_browser_img_info_json return image_browser_img_info_json
def get_image_page(img_path, page_index, filenames, keyword, sort_by, sort_order, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, score_type, exif_keyword, negative_prompt_search, use_regex, case_sensitive, image_reward_button): def get_image_page(img_path, page_index, filenames, keyword, sort_by, sort_order, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, exif_keyword, negative_prompt_search, use_regex, case_sensitive):
logger.debug("get_image_page") logger.debug("get_image_page")
if img_path == "": if img_path == "":
return [], page_index, [], "", "", "", 0, "", None, "", "[]", image_reward_button return [], page_index, [], "", "", "", 0, "", None, "", "[]"
# Set temp_dir from webui settings, so gradio uses it # Set temp_dir from webui settings, so gradio uses it
if shared.opts.temp_dir != "": if shared.opts.temp_dir != "":
tempfile.tempdir = shared.opts.temp_dir tempfile.tempdir = shared.opts.temp_dir
img_path, _ = pure_path(img_path) img_path, _ = pure_path(img_path)
filenames = get_all_images(img_path, sort_by, sort_order, keyword, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, score_type, exif_keyword, negative_prompt_search, use_regex, case_sensitive) filenames = get_all_images(img_path, sort_by, sort_order, keyword, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, exif_keyword, negative_prompt_search, use_regex, case_sensitive)
page_index = int(page_index) page_index = int(page_index)
length = len(filenames) length = len(filenames)
max_page_index = math.ceil(length / num_of_imgs_per_page) max_page_index = math.ceil(length / num_of_imgs_per_page)
@ -1090,30 +1079,6 @@ def update_ranking(img_file_name, ranking_current, ranking, img_file_info):
img_file_info = update_exif(img_file_name, "Ranking", ranking) img_file_info = update_exif(img_file_name, "Ranking", ranking)
return ranking, None, img_file_info return ranking, None, img_file_info
def generate_image_reward(filenames, turn_page_switch, aes_filter_min, aes_filter_max):
global image_reward_model
if image_reward_model is None:
image_reward_model = ImageReward.load("ImageReward-v1.0")
conn, cursor = wib_db.transaction_begin()
for filename in filenames:
saved_image_reward_score, saved_image_reward_prompt = wib_db.select_image_reward_score(cursor, filename)
if saved_image_reward_score is None and saved_image_reward_prompt is not None:
try:
with torch.no_grad():
image_reward_score = image_reward_model.score(saved_image_reward_prompt, filename)
image_reward_score = f"{image_reward_score:.2f}"
try:
logger.warning(f"Generated ImageRewardScore: {image_reward_score} for {filename}")
except UnicodeEncodeError:
pass
wib_db.update_image_reward_score(cursor, filename, image_reward_score)
if any(filename.endswith(ext) for ext in image_ext_list):
img_file_info = update_exif(filename, "ImageRewardScore", image_reward_score)
except UnidentifiedImageError as e:
logger.warning(f"UnidentifiedImageError: {e}")
wib_db.transaction_end(conn, cursor)
return -turn_page_switch, aes_filter_min, aes_filter_max
def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab): def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
global init, exif_cache, aes_cache, openoutpaint, controlnet, js_dummy_return global init, exif_cache, aes_cache, openoutpaint, controlnet, js_dummy_return
dir_name = None dir_name = None
@ -1211,7 +1176,7 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row() as sort_panel: with gr.Row() as sort_panel:
sort_by = gr.Dropdown(value="date", choices=["path name", "date", "aesthetic_score", "ImageRewardScore", "random", "cfg scale", "steps", "seed", "sampler", "size", "model", "model hash", "ranking"], label="Sort by") sort_by = gr.Dropdown(value="date", choices=["path name", "date", "aesthetic_score", "random", "cfg scale", "steps", "seed", "sampler", "size", "model", "model hash", "ranking"], label="Sort by")
sort_order = ToolButton(value=down_symbol) sort_order = ToolButton(value=down_symbol)
with gr.Row() as filename_search_panel: with gr.Row() as filename_search_panel:
filename_keyword_search = gr.Textbox(value="", label="Filename keyword search") filename_keyword_search = gr.Textbox(value="", label="Filename keyword search")
@ -1233,11 +1198,6 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
with gr.Column(scale=4, min_width=20): with gr.Column(scale=4, min_width=20):
gr.Textbox(value="Choose Min-max to activate these controls", label="", interactive=False) gr.Textbox(value="Choose Min-max to activate these controls", label="", interactive=False)
with gr.Box() as aesthetic_score_filter_panel: with gr.Box() as aesthetic_score_filter_panel:
with gr.Row():
with gr.Column(scale=4, min_width=20):
score_type = gr.Dropdown(value=opts.image_browser_scoring_type, choices=["aesthetic_score", "ImageReward Score"], label="Scoring type", interactive=True)
with gr.Column(scale=2, min_width=20):
image_reward_button = gr.Button(value="Generate ImageReward Scores for all images", interactive=image_reward_installed, visible=False)
with gr.Row(): with gr.Row():
aes_filter_min = gr.Textbox(value="", label="Minimum score") aes_filter_min = gr.Textbox(value="", label="Minimum score")
aes_filter_max = gr.Textbox(value="", label="Maximum score") aes_filter_max = gr.Textbox(value="", label="Maximum score")
@ -1584,8 +1544,8 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
if standard_ui or others_dir: if standard_ui or others_dir:
turn_page_switch.change( turn_page_switch.change(
fn=get_image_page, fn=get_image_page,
inputs=[img_path, page_index, filenames, filename_keyword_search, sort_by, sort_order, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, score_type, exif_keyword_search, negative_prompt_search, use_regex, case_sensitive, image_reward_button], inputs=[img_path, page_index, filenames, filename_keyword_search, sort_by, sort_order, tab_base_tag_box, img_path_depth, ranking_filter, ranking_filter_min, ranking_filter_max, aes_filter_min, aes_filter_max, exif_keyword_search, negative_prompt_search, use_regex, case_sensitive],
outputs=[filenames, page_index, image_gallery, img_file_name, img_file_time, img_file_info, visible_img_num, warning_box, hidden, image_page_list, image_browser_img_info, image_reward_button], outputs=[filenames, page_index, image_gallery, img_file_name, img_file_time, img_file_info, visible_img_num, warning_box, hidden, image_page_list, image_browser_img_info],
show_progress=opts.image_browser_show_progress show_progress=opts.image_browser_show_progress
).then( ).then(
fn=None, fn=None,
@ -1615,12 +1575,6 @@ def create_tab(tab: ImageBrowserTab, current_gr_tab: gr.Tab):
_js="image_browser_controlnet_send_img2img", _js="image_browser_controlnet_send_img2img",
show_progress=opts.image_browser_show_progress show_progress=opts.image_browser_show_progress
) )
image_reward_button.click(
fn=generate_image_reward,
inputs=[filenames, turn_page_switch, aes_filter_min, aes_filter_max],
outputs=[turn_page_switch, aes_filter_min, aes_filter_max],
show_progress=True
)
def run_pnginfo(image, image_path, image_file_name): def run_pnginfo(image, image_path, image_file_name):
if image is None: if image is None:
@ -1737,7 +1691,6 @@ def on_ui_settings():
("image_browser_thumbnail_size", None, 200, "Size of the thumbnails (px)"), ("image_browser_thumbnail_size", None, 200, "Size of the thumbnails (px)"),
("image_browser_swipe", None, False, "Swipe left/right navigates to the next image"), ("image_browser_swipe", None, False, "Swipe left/right navigates to the next image"),
("image_browser_img_tooltips", None, True, "Enable thumbnail tooltips"), ("image_browser_img_tooltips", None, True, "Enable thumbnail tooltips"),
("image_browser_scoring_type", None, "aesthetic_score", "Default scoring type", gr.Dropdown, lambda: {"choices": ["aesthetic_score", "ImageReward Score"]}),
("image_browser_show_progress", None, True, "Show progress indicator"), ("image_browser_show_progress", None, True, "Show progress indicator"),
("image_browser_info_add", None, False, "Show Additional Generation Info"), ("image_browser_info_add", None, False, "Show Additional Generation Info"),
] ]

View File

@ -510,41 +510,6 @@ def update_ranking(file, ranking):
return return
def select_image_reward_score(cursor, file):
cursor.execute('''
SELECT value
FROM exif_data
WHERE file = ?
AND key = 'ImageRewardScore'
''', (file,))
image_reward_score = cursor.fetchone()
if image_reward_score is None:
return_image_reward_score = None
else:
(return_image_reward_score,) = image_reward_score
cursor.execute('''
SELECT value
FROM exif_data
WHERE file = ?
AND key = 'prompt'
''', (file,))
image_reward_prompt = cursor.fetchone()
if image_reward_prompt is None:
return_image_reward_prompt = None
else:
(return_image_reward_prompt,) = image_reward_prompt
return return_image_reward_score, return_image_reward_prompt
def update_image_reward_score(cursor, file, image_reward_score):
cursor.execute('''
INSERT OR REPLACE
INTO exif_data (file, key, value)
VALUES (?, ?, ?)
''', (file, "ImageRewardScore", image_reward_score))
return
def update_path_recorder(path, depth, path_display): def update_path_recorder(path, depth, path_display):
with sqlite3.connect(db_file, timeout=timeout) as conn: with sqlite3.connect(db_file, timeout=timeout) as conn:
cursor = conn.cursor() cursor = conn.cursor()
@ -802,11 +767,8 @@ def fill_work_files(cursor, fileinfos):
return return
def filter_aes(cursor, fileinfos, aes_filter_min_num, aes_filter_max_num, score_type): def filter_aes(cursor, fileinfos, aes_filter_min_num, aes_filter_max_num):
if score_type == "aesthetic_score":
key = "aesthetic_score" key = "aesthetic_score"
else:
key = "ImageRewardScore"
cursor.execute(''' cursor.execute('''
DELETE DELETE