diff --git a/scripts/main.py b/scripts/main.py index a0bcbd5..b4971b3 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -3,17 +3,18 @@ # Description : This script translates Stable diffusion prompt from one of the 50 languages supported by MBART # It uses MBartTranslator class that provides a simple interface for translating text using the MBart language model. -import modules.scripts as scripts -import gradio as gr -from modules.shared import opts - -from transformers import MBart50TokenizerFast, MBartForConditionalGeneration -import re import os +import re + +import gradio as gr +from transformers import MBart50TokenizerFast, MBartForConditionalGeneration + +import modules.scripts as scripts # The directory to store the models cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models') + class MBartTranslator: """MBartTranslator class provides a simple interface for translating text using the MBart language model. @@ -27,80 +28,12 @@ class MBartTranslator: def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None): - self.supported_languages = [ - "ar_AR", - "de_DE", - "en_XX", - "es_XX", - "fr_XX", - "hi_IN", - "it_IT", - "ja_XX", - "ko_XX", - "pt_XX", - "ru_RU", - "zh_XX", - "af_ZA", - "bn_BD", - "bs_XX", - "ca_XX", - "cs_CZ", - "da_XX", - "el_GR", - "et_EE", - "fa_IR", - "fi_FI", - "gu_IN", - "he_IL", - "hi_XX", - "hr_HR", - "hu_HU", - "id_ID", - "is_IS", - "ja_XX", - "jv_XX", - "ka_GE", - "kk_XX", - "km_KH", - "kn_IN", - "ko_KR", - "lo_LA", - "lt_LT", - "lv_LV", - "mk_MK", - "ml_IN", - "mr_IN", - "ms_MY", - "ne_NP", - "nl_XX", - "no_XX", - "pl_XX", - "ro_RO", - "si_LK", - "sk_SK", - "sl_SI", - "sq_AL", - "sr_XX", - "sv_XX", - "sw_TZ", - "ta_IN", - "te_IN", - "th_TH", - "tl_PH", - "tr_TR", - "uk_UA", - "ur_PK", - "vi_VN", - "war_PH", - "yue_XX", - "zh_CN", - "zh_TW", - ] print("Building translator") print("Loading generator (this may take few minutes the first time as I need to download the model)") self.model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir) print("Loading tokenizer") - self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang, cache_dir=cache_dir) + self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang, + cache_dir=cache_dir) print("Translator is ready") def translate(self, text: str, input_language: str, output_language: str) -> str: @@ -114,10 +47,12 @@ class MBartTranslator: Returns: str: The translated text. """ - if input_language not in self.supported_languages: - raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}") - if output_language not in self.supported_languages: - raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}") + if input_language not in self.tokenizer.lang_code_to_id: + raise ValueError( + f"Input language not supported. Supported languages: {self.tokenizer.lang_code_to_id}") + if output_language not in self.tokenizer.lang_code_to_id: + raise ValueError( + f"Output language not supported. Supported languages: {self.tokenizer.lang_code_to_id}") self.tokenizer.src_lang = input_language encoded_input = self.tokenizer(text, return_tensors="pt") @@ -129,29 +64,6 @@ class MBartTranslator: return translated_text[0] - -class LanguageOption: - """ - A class representing a language option in a language selector. - - Attributes: - label (str): The display label for the language option. - language_code (str): The ISO 639-1 language code for the language option. - """ - - def __init__(self, label, language_code): - """ - Initializes a new LanguageOption instance. - - Args: - label (str): The display label for the language option. - language_code (str): The ISO 639-1 language code for the language option. - """ - self.label = label - self.language_code = language_code - - - # This is a list of LanguageOption objects that represent the various language options available. # Each LanguageOption object contains a label that represents the display name of the language and # a language code that represents the code for the language that will be used by the translation model. @@ -160,75 +72,61 @@ class LanguageOption: # For example, "en_XX" represents English language and "fr_FR" represents French language specific to France. # These LanguageOption objects will be used to display the language options to the user and to retrieve the # corresponding language code when the user selects a language. -language_options = [ - LanguageOption("Arabic", "ar_AR"), - LanguageOption("Deutsch", "de_DE"), - LanguageOption("English", "en_XX"), - LanguageOption("Spanish", "es_XX"), - LanguageOption("French", "fr_XX"), - LanguageOption("Hindi", "hi_IN"), - LanguageOption("Italian", "it_IT"), - LanguageOption("Japanese", "ja_XX"), - LanguageOption("Korean", "ko_XX"), - LanguageOption("Portuguese", "pt_XX"), - LanguageOption("Russian", "ru_RU"), - LanguageOption("Chinese", "zh_XX"), - LanguageOption("Afrikaans", "af_ZA"), - LanguageOption("Bengali", "bn_BD"), - LanguageOption("Bosnian", "bs_XX"), - LanguageOption("Catalan", "ca_XX"), - LanguageOption("Czech", "cs_CZ"), - LanguageOption("Danish", "da_XX"), - LanguageOption("Greek", "el_GR"), - LanguageOption("Estonian", "et_EE"), - LanguageOption("Persian", "fa_IR"), - LanguageOption("Finnish", "fi_FI"), - LanguageOption("Gujarati", "gu_IN"), - LanguageOption("Hebrew", "he_IL"), - LanguageOption("Croatian", "hr_HR"), - LanguageOption("Hungarian", "hu_HU"), - LanguageOption("Indonesian", "id_ID"), - LanguageOption("Icelandic", "is_IS"), - LanguageOption("Javanese", "jv_XX"), - LanguageOption("Georgian", "ka_GE"), - LanguageOption("Kazakh", "kk_XX"), - LanguageOption("Khmer", "km_KH"), - LanguageOption("Kannada", "kn_IN"), - LanguageOption("Korean", "ko_KR"), - LanguageOption("Lao", "lo_LA"), - LanguageOption("Lithuanian", "lt_LT"), - LanguageOption("Latvian", "lv_LV"), - LanguageOption("Macedonian", "mk_MK"), - LanguageOption("Malayalam", "ml_IN"), - LanguageOption("Marathi", "mr_IN"), - LanguageOption("Malay", "ms_MY"), - LanguageOption("Nepali", "ne_NP"), - LanguageOption("Dutch", "nl_XX"), - LanguageOption("Norwegian", "no_XX"), - LanguageOption("Polish", "pl_XX"), - LanguageOption("Romanian", "ro_RO"), - LanguageOption("Sinhala", "si_LK"), - LanguageOption("Slovak", "sk_SK"), - LanguageOption("Slovenian", "sl_SI"), - LanguageOption("Albanian", "sq_AL"), - LanguageOption("Serbian", "sr_XX"), - LanguageOption("Swedish", "sv_XX"), - LanguageOption("Swahili", "sw_TZ"), - LanguageOption("Tamil", "ta_IN"), - LanguageOption("Telugu", "te_IN"), - LanguageOption("Tamil ", "ta_IN"), - LanguageOption("Telugu", "te_IN"), - LanguageOption("Thai", "th_TH"), - LanguageOption("Tagalog", "tl_PH"), - LanguageOption("Turkish", "tr_TR"), - LanguageOption("Ukrainian", "uk_UA"), - LanguageOption("Urdu", "ur_PK"), - LanguageOption("Vietnamese", "vi_VN"), - LanguageOption("Waray", "war_PH"), - LanguageOption("Cantonese", "yue_XX"), - LanguageOption("Chinese", "zh_CN"), - LanguageOption("Chinese", "zh_TW"), -] +language_options = { + 'Afrikaans': 'af_ZA', + 'Arabic': 'ar_AR', + 'Azerbaijani (Azerbaijan)': 'az_AZ', + 'Bengali (India)': 'bn_IN', + 'Burmese (Myanmar)': 'my_MM', + 'Chinese': 'zh_CN', + 'Croatian': 'hr_HR', + 'Czech': 'cs_CZ', + 'Deutsch': 'de_DE', + "Dutch": "nl_XX", + 'English': 'en_XX', + 'Estonian': 'et_EE', + 'Finnish': 'fi_FI', + 'French': 'fr_XX', + 'Galician (Spain)': 'gl_ES', + 'Georgian': 'ka_GE', + 'Gujarati': 'gu_IN', + 'Hebrew': 'he_IL', + 'Hindi': 'hi_IN', + 'Indonesian': 'id_ID', + 'Italian': 'it_IT', + 'Japanese': 'ja_XX', + 'Kazakh (Kazakhstan)': 'kk_KZ', + 'Khmer': 'km_KH', + 'Korean': 'ko_KR', + 'Latvian': 'lv_LV', + 'Lithuanian': 'lt_LT', + 'Macedonian': 'mk_MK', + 'Malayalam': 'ml_IN', + 'Marathi': 'mr_IN', + 'Mongolian (Mongolia)': 'mn_MN', + 'Nepali': 'ne_NP', + 'Pashto (Afghanistan)': 'ps_AF', + 'Persian': 'fa_IR', + 'Polish (Poland)': 'pl_PL', + 'Portuguese': 'pt_XX', + 'Romanian': 'ro_RO', + 'Russian': 'ru_RU', + 'Sinhala': 'si_LK', + 'Slovenian': 'sl_SI', + 'Spanish': 'es_XX', + 'Swahili (Kenya)': 'sw_KE', + 'Swedish (Sweden)': 'sv_SE', + 'Tamil': 'ta_IN', + 'Telugu': 'te_IN', + "Tagalog": "tl_XX", + 'Thai': 'th_TH', + 'Turkish': 'tr_TR', + 'Ukrainian': 'uk_UA', + 'Urdu': 'ur_PK', + 'Vietnamese': 'vi_VN', + 'Xhosa (South Africa)': 'xh_ZA' +} + def remove_unnecessary_spaces(text): """Removes unnecessary spaces between characters.""" @@ -236,28 +134,30 @@ def remove_unnecessary_spaces(text): replacement = r")++" return re.sub(pattern, replacement, text) + def correct_translation_format(original_text, translated_text): original_parts = original_text.split('++') translated_parts = translated_text.split('++') - + corrected_parts = [] for i, original_part in enumerate(original_parts): translated_part = translated_parts[i] - + original_plus_count = original_part.count('+') translated_plus_count = translated_part.count('+') plus_difference = translated_plus_count - original_plus_count - + if plus_difference > 0: translated_part = translated_part.replace('+' * plus_difference, '', 1) elif plus_difference < 0: translated_part += '+' * abs(plus_difference) - + corrected_parts.append(translated_part) - + corrected_text = '++'.join(corrected_parts) return corrected_text + def extract_plus_positions(text): """ Given a string of text, extracts the positions of all sequences of one or more '+' characters. @@ -289,7 +189,7 @@ def extract_plus_positions(text): positions.append([j, last_match_end, last_match_end - j]) last_match_end = match.end() - + # If the final match extends to the end of the string, add its position to the output list if last_match_end is not None and last_match_end == len(text): j = last_match_end - 1 @@ -314,11 +214,11 @@ def match_pluses(original_text, translated_text): - output (str): the translated text with '+' characters replaced by those in the original text """ in_positions = extract_plus_positions(original_text) - out_positions = extract_plus_positions(translated_text) - + out_positions = extract_plus_positions(translated_text) + out_vals = [] out_current_pos = 0 - + if len(in_positions) == len(out_positions): # Iterate through the positions and replace the sequences of '+' characters in the translated text # with those in the original text @@ -326,31 +226,33 @@ def match_pluses(original_text, translated_text): out_vals.append(translated_text[out_current_pos:out_[0]]) out_vals.append(original_text[in_[0]:in_[1]]) out_current_pos = out_[1] - + # Check that the number of '+' characters in the original and translated sequences is the same if in_[2] != out_[2]: print("detected different + count") # Add any remaining text from the translated string to the output out_vals.append(translated_text[out_current_pos:]) - + # Join the output values into a single string output = "".join(out_vals) return output + def post_process_prompt(original, translated): """Applies post-processing to the translated prompt such as removing unnecessary spaces and extra plus signs.""" clean_prompt = remove_unnecessary_spaces(translated) clean_prompt = match_pluses(original, clean_prompt) - #clean_prompt = remove_extra_plus(clean_prompt) - return clean_prompt + # clean_prompt = remove_extra_plus(clean_prompt) + return clean_prompt + class Script(scripts.Script): def __init__(self) -> None: """Initializes the Script class and sets the default value for enable_translation attribute.""" super().__init__() - self.enable_translation=False - self.is_negative_translate_active=False + self.enable_translation = False + self.is_negative_translate_active = False def title(self): """Returns the title of the script.""" @@ -359,11 +261,11 @@ class Script(scripts.Script): def show(self, is_img2img): """Returns the visibility status of the script in the interface.""" return scripts.AlwaysVisible - + def set_active(self, active): """Sets the is_active attribute and initializes the translator object if not already created. Also, sets the visibility of the language dropdown to True.""" - self.is_active=active + self.is_active = active if not hasattr(self, "translator"): self.translator = MBartTranslator() if self.is_active: @@ -374,19 +276,17 @@ class Script(scripts.Script): def set_negative_translate_active(self, negative_translate_active): """Sets the is_active attribute and initializes the translator object if not already created. Also, sets the visibility of the language dropdown to True.""" - self.is_negative_translate_active=negative_translate_active - + self.is_negative_translate_active = negative_translate_active def ui(self, is_img2img): """Sets up the user interface of the script.""" - self.is_active=False - self.current_axis_options = [x for x in language_options] + self.is_active = False with gr.Row(): with gr.Column(scale=19): - with gr.Accordion("Prompt Translator",open=False): - with gr.Accordion("Help",open=False): - md = gr.Markdown(""" + with gr.Accordion("Prompt Translator", open=False): + with gr.Accordion("Help", open=False): + gr.Markdown(""" # Description This script translates your prompt from another language to english before generating the image allowing you to write prompts in your native language. # How to use @@ -398,29 +298,30 @@ class Script(scripts.Script): with gr.Column(): self.enable_translation = gr.Checkbox(label="Enable translation") with gr.Column() as options: - self.options=options + self.options = options self.translate_negative_prompt = gr.Checkbox(label="Translate negative prompt") - self.enable_translation.value=False + self.enable_translation.value = False self.language = gr.Dropdown( - label="Source language", - choices=[x.label for x in self.current_axis_options], - value="Français", - type="index", - elem_id=self.elem_id("x_type") - ) - self.output=gr.Markdown(value="After enabling translation, please Wait until I am ready", visible=True) + label="Source language", + choices=sorted(language_options.keys()), + value="Chinese", + type="value", + elem_id=self.elem_id("x_type") + ) + self.output = gr.Markdown(value="After enabling translation, please Wait until I am ready", + visible=True) self.enable_translation.change( self.set_active, - [self.enable_translation], - [self.output, self.options], + [self.enable_translation], + [self.output, self.options], show_progress=True ) self.translate_negative_prompt.change( self.set_negative_translate_active, - [self.translate_negative_prompt], + [self.translate_negative_prompt], ) - self.options.visible=False + self.options.visible = False return [self.language] def get_prompts(self, p): @@ -433,47 +334,45 @@ class Script(scripts.Script): ) return original_prompts, original_negative_prompts - + def process(self, p, language, **kwargs): """Translates the prompts from a non-English language to English using the MBartTranslator object.""" if hasattr(self, "translator") and self.is_active: original_prompts, original_negative_prompts = self.get_prompts(p) - translated_prompts=[] + translated_prompts = [] previous_prompt = "" previous_translated_prompt = "" for original_prompt in original_prompts: if previous_prompt != original_prompt: - print(f"Translating prompt to English from {language_options[language].label}") + print(f"Translating prompt to English from {language}") print(f"Initial prompt:{original_prompt}") - ln_code = language_options[language].language_code + ln_code = language_options[language] translated_prompt = self.translator.translate(original_prompt, ln_code, "en_XX") translated_prompt = post_process_prompt(original_prompt, translated_prompt) print(f"Translated prompt:{translated_prompt}") translated_prompts.append(translated_prompt) - previous_prompt=original_prompt + previous_prompt = original_prompt previous_translated_prompt = translated_prompt else: translated_prompts.append(previous_translated_prompt) - - if p.negative_prompt!='' and self.is_negative_translate_active: + if p.negative_prompt != '' and self.is_negative_translate_active: previous_negative_prompt = "" previous_translated_negative_prompt = "" - translated_negative_prompts=[] + translated_negative_prompts = [] for negative_prompt in original_negative_prompts: - if previous_negative_prompt!=negative_prompt: - print(f"Translating negative prompt to English from {language_options[language].label}") + if previous_negative_prompt != negative_prompt: + print(f"Translating negative prompt to English from {language}") print(f"Initial negative prompt:{negative_prompt}") - ln_code = language_options[language].language_code + ln_code = language_options[language] translated_negative_prompt = self.translator.translate(negative_prompt, ln_code, "en_XX") - translated_negative_prompt = post_process_prompt(negative_prompt,translated_negative_prompt) + translated_negative_prompt = post_process_prompt(negative_prompt, translated_negative_prompt) print(f"Translated negative prompt:{translated_negative_prompt}") translated_negative_prompts.append(translated_negative_prompt) - previous_negative_prompt = negative_prompt previous_translated_negative_prompt = translated_negative_prompt else: @@ -483,4 +382,4 @@ class Script(scripts.Script): p.all_negative_prompts = translated_negative_prompts p.prompt = translated_prompts[0] p.prompt_for_display = translated_prompts[0] - p.all_prompts=translated_prompts + p.all_prompts = translated_prompts