Merge pull request #31 from Neokmi/main
Remove incorrect language options and fix to 52 officially supported languages.main
commit
d75e0fe584
345
scripts/main.py
345
scripts/main.py
|
|
@ -3,17 +3,18 @@
|
|||
# Description : This script translates Stable diffusion prompt from one of the 50 languages supported by MBART
|
||||
# It uses MBartTranslator class that provides a simple interface for translating text using the MBart language model.
|
||||
|
||||
import modules.scripts as scripts
|
||||
import gradio as gr
|
||||
from modules.shared import opts
|
||||
|
||||
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
|
||||
import re
|
||||
import os
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
|
||||
|
||||
import modules.scripts as scripts
|
||||
|
||||
# The directory to store the models
|
||||
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
|
||||
|
||||
|
||||
class MBartTranslator:
|
||||
"""MBartTranslator class provides a simple interface for translating text using the MBart language model.
|
||||
|
||||
|
|
@ -27,80 +28,12 @@ class MBartTranslator:
|
|||
|
||||
def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
|
||||
|
||||
self.supported_languages = [
|
||||
"ar_AR",
|
||||
"de_DE",
|
||||
"en_XX",
|
||||
"es_XX",
|
||||
"fr_XX",
|
||||
"hi_IN",
|
||||
"it_IT",
|
||||
"ja_XX",
|
||||
"ko_XX",
|
||||
"pt_XX",
|
||||
"ru_RU",
|
||||
"zh_XX",
|
||||
"af_ZA",
|
||||
"bn_BD",
|
||||
"bs_XX",
|
||||
"ca_XX",
|
||||
"cs_CZ",
|
||||
"da_XX",
|
||||
"el_GR",
|
||||
"et_EE",
|
||||
"fa_IR",
|
||||
"fi_FI",
|
||||
"gu_IN",
|
||||
"he_IL",
|
||||
"hi_XX",
|
||||
"hr_HR",
|
||||
"hu_HU",
|
||||
"id_ID",
|
||||
"is_IS",
|
||||
"ja_XX",
|
||||
"jv_XX",
|
||||
"ka_GE",
|
||||
"kk_XX",
|
||||
"km_KH",
|
||||
"kn_IN",
|
||||
"ko_KR",
|
||||
"lo_LA",
|
||||
"lt_LT",
|
||||
"lv_LV",
|
||||
"mk_MK",
|
||||
"ml_IN",
|
||||
"mr_IN",
|
||||
"ms_MY",
|
||||
"ne_NP",
|
||||
"nl_XX",
|
||||
"no_XX",
|
||||
"pl_XX",
|
||||
"ro_RO",
|
||||
"si_LK",
|
||||
"sk_SK",
|
||||
"sl_SI",
|
||||
"sq_AL",
|
||||
"sr_XX",
|
||||
"sv_XX",
|
||||
"sw_TZ",
|
||||
"ta_IN",
|
||||
"te_IN",
|
||||
"th_TH",
|
||||
"tl_PH",
|
||||
"tr_TR",
|
||||
"uk_UA",
|
||||
"ur_PK",
|
||||
"vi_VN",
|
||||
"war_PH",
|
||||
"yue_XX",
|
||||
"zh_CN",
|
||||
"zh_TW",
|
||||
]
|
||||
print("Building translator")
|
||||
print("Loading generator (this may take few minutes the first time as I need to download the model)")
|
||||
self.model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
print("Loading tokenizer")
|
||||
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang, cache_dir=cache_dir)
|
||||
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang,
|
||||
cache_dir=cache_dir)
|
||||
print("Translator is ready")
|
||||
|
||||
def translate(self, text: str, input_language: str, output_language: str) -> str:
|
||||
|
|
@ -114,10 +47,12 @@ class MBartTranslator:
|
|||
Returns:
|
||||
str: The translated text.
|
||||
"""
|
||||
if input_language not in self.supported_languages:
|
||||
raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
|
||||
if output_language not in self.supported_languages:
|
||||
raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
|
||||
if input_language not in self.tokenizer.lang_code_to_id:
|
||||
raise ValueError(
|
||||
f"Input language not supported. Supported languages: {self.tokenizer.lang_code_to_id}")
|
||||
if output_language not in self.tokenizer.lang_code_to_id:
|
||||
raise ValueError(
|
||||
f"Output language not supported. Supported languages: {self.tokenizer.lang_code_to_id}")
|
||||
|
||||
self.tokenizer.src_lang = input_language
|
||||
encoded_input = self.tokenizer(text, return_tensors="pt")
|
||||
|
|
@ -129,29 +64,6 @@ class MBartTranslator:
|
|||
return translated_text[0]
|
||||
|
||||
|
||||
|
||||
class LanguageOption:
|
||||
"""
|
||||
A class representing a language option in a language selector.
|
||||
|
||||
Attributes:
|
||||
label (str): The display label for the language option.
|
||||
language_code (str): The ISO 639-1 language code for the language option.
|
||||
"""
|
||||
|
||||
def __init__(self, label, language_code):
|
||||
"""
|
||||
Initializes a new LanguageOption instance.
|
||||
|
||||
Args:
|
||||
label (str): The display label for the language option.
|
||||
language_code (str): The ISO 639-1 language code for the language option.
|
||||
"""
|
||||
self.label = label
|
||||
self.language_code = language_code
|
||||
|
||||
|
||||
|
||||
# This is a list of LanguageOption objects that represent the various language options available.
|
||||
# Each LanguageOption object contains a label that represents the display name of the language and
|
||||
# a language code that represents the code for the language that will be used by the translation model.
|
||||
|
|
@ -160,75 +72,61 @@ class LanguageOption:
|
|||
# For example, "en_XX" represents English language and "fr_FR" represents French language specific to France.
|
||||
# These LanguageOption objects will be used to display the language options to the user and to retrieve the
|
||||
# corresponding language code when the user selects a language.
|
||||
language_options = [
|
||||
LanguageOption("Arabic", "ar_AR"),
|
||||
LanguageOption("Deutsch", "de_DE"),
|
||||
LanguageOption("English", "en_XX"),
|
||||
LanguageOption("Spanish", "es_XX"),
|
||||
LanguageOption("French", "fr_XX"),
|
||||
LanguageOption("Hindi", "hi_IN"),
|
||||
LanguageOption("Italian", "it_IT"),
|
||||
LanguageOption("Japanese", "ja_XX"),
|
||||
LanguageOption("Korean", "ko_XX"),
|
||||
LanguageOption("Portuguese", "pt_XX"),
|
||||
LanguageOption("Russian", "ru_RU"),
|
||||
LanguageOption("Chinese", "zh_XX"),
|
||||
LanguageOption("Afrikaans", "af_ZA"),
|
||||
LanguageOption("Bengali", "bn_BD"),
|
||||
LanguageOption("Bosnian", "bs_XX"),
|
||||
LanguageOption("Catalan", "ca_XX"),
|
||||
LanguageOption("Czech", "cs_CZ"),
|
||||
LanguageOption("Danish", "da_XX"),
|
||||
LanguageOption("Greek", "el_GR"),
|
||||
LanguageOption("Estonian", "et_EE"),
|
||||
LanguageOption("Persian", "fa_IR"),
|
||||
LanguageOption("Finnish", "fi_FI"),
|
||||
LanguageOption("Gujarati", "gu_IN"),
|
||||
LanguageOption("Hebrew", "he_IL"),
|
||||
LanguageOption("Croatian", "hr_HR"),
|
||||
LanguageOption("Hungarian", "hu_HU"),
|
||||
LanguageOption("Indonesian", "id_ID"),
|
||||
LanguageOption("Icelandic", "is_IS"),
|
||||
LanguageOption("Javanese", "jv_XX"),
|
||||
LanguageOption("Georgian", "ka_GE"),
|
||||
LanguageOption("Kazakh", "kk_XX"),
|
||||
LanguageOption("Khmer", "km_KH"),
|
||||
LanguageOption("Kannada", "kn_IN"),
|
||||
LanguageOption("Korean", "ko_KR"),
|
||||
LanguageOption("Lao", "lo_LA"),
|
||||
LanguageOption("Lithuanian", "lt_LT"),
|
||||
LanguageOption("Latvian", "lv_LV"),
|
||||
LanguageOption("Macedonian", "mk_MK"),
|
||||
LanguageOption("Malayalam", "ml_IN"),
|
||||
LanguageOption("Marathi", "mr_IN"),
|
||||
LanguageOption("Malay", "ms_MY"),
|
||||
LanguageOption("Nepali", "ne_NP"),
|
||||
LanguageOption("Dutch", "nl_XX"),
|
||||
LanguageOption("Norwegian", "no_XX"),
|
||||
LanguageOption("Polish", "pl_XX"),
|
||||
LanguageOption("Romanian", "ro_RO"),
|
||||
LanguageOption("Sinhala", "si_LK"),
|
||||
LanguageOption("Slovak", "sk_SK"),
|
||||
LanguageOption("Slovenian", "sl_SI"),
|
||||
LanguageOption("Albanian", "sq_AL"),
|
||||
LanguageOption("Serbian", "sr_XX"),
|
||||
LanguageOption("Swedish", "sv_XX"),
|
||||
LanguageOption("Swahili", "sw_TZ"),
|
||||
LanguageOption("Tamil", "ta_IN"),
|
||||
LanguageOption("Telugu", "te_IN"),
|
||||
LanguageOption("Tamil ", "ta_IN"),
|
||||
LanguageOption("Telugu", "te_IN"),
|
||||
LanguageOption("Thai", "th_TH"),
|
||||
LanguageOption("Tagalog", "tl_PH"),
|
||||
LanguageOption("Turkish", "tr_TR"),
|
||||
LanguageOption("Ukrainian", "uk_UA"),
|
||||
LanguageOption("Urdu", "ur_PK"),
|
||||
LanguageOption("Vietnamese", "vi_VN"),
|
||||
LanguageOption("Waray", "war_PH"),
|
||||
LanguageOption("Cantonese", "yue_XX"),
|
||||
LanguageOption("Chinese", "zh_CN"),
|
||||
LanguageOption("Chinese", "zh_TW"),
|
||||
]
|
||||
language_options = {
|
||||
'Afrikaans': 'af_ZA',
|
||||
'Arabic': 'ar_AR',
|
||||
'Azerbaijani (Azerbaijan)': 'az_AZ',
|
||||
'Bengali (India)': 'bn_IN',
|
||||
'Burmese (Myanmar)': 'my_MM',
|
||||
'Chinese': 'zh_CN',
|
||||
'Croatian': 'hr_HR',
|
||||
'Czech': 'cs_CZ',
|
||||
'Deutsch': 'de_DE',
|
||||
"Dutch": "nl_XX",
|
||||
'English': 'en_XX',
|
||||
'Estonian': 'et_EE',
|
||||
'Finnish': 'fi_FI',
|
||||
'French': 'fr_XX',
|
||||
'Galician (Spain)': 'gl_ES',
|
||||
'Georgian': 'ka_GE',
|
||||
'Gujarati': 'gu_IN',
|
||||
'Hebrew': 'he_IL',
|
||||
'Hindi': 'hi_IN',
|
||||
'Indonesian': 'id_ID',
|
||||
'Italian': 'it_IT',
|
||||
'Japanese': 'ja_XX',
|
||||
'Kazakh (Kazakhstan)': 'kk_KZ',
|
||||
'Khmer': 'km_KH',
|
||||
'Korean': 'ko_KR',
|
||||
'Latvian': 'lv_LV',
|
||||
'Lithuanian': 'lt_LT',
|
||||
'Macedonian': 'mk_MK',
|
||||
'Malayalam': 'ml_IN',
|
||||
'Marathi': 'mr_IN',
|
||||
'Mongolian (Mongolia)': 'mn_MN',
|
||||
'Nepali': 'ne_NP',
|
||||
'Pashto (Afghanistan)': 'ps_AF',
|
||||
'Persian': 'fa_IR',
|
||||
'Polish (Poland)': 'pl_PL',
|
||||
'Portuguese': 'pt_XX',
|
||||
'Romanian': 'ro_RO',
|
||||
'Russian': 'ru_RU',
|
||||
'Sinhala': 'si_LK',
|
||||
'Slovenian': 'sl_SI',
|
||||
'Spanish': 'es_XX',
|
||||
'Swahili (Kenya)': 'sw_KE',
|
||||
'Swedish (Sweden)': 'sv_SE',
|
||||
'Tamil': 'ta_IN',
|
||||
'Telugu': 'te_IN',
|
||||
"Tagalog": "tl_XX",
|
||||
'Thai': 'th_TH',
|
||||
'Turkish': 'tr_TR',
|
||||
'Ukrainian': 'uk_UA',
|
||||
'Urdu': 'ur_PK',
|
||||
'Vietnamese': 'vi_VN',
|
||||
'Xhosa (South Africa)': 'xh_ZA'
|
||||
}
|
||||
|
||||
|
||||
def remove_unnecessary_spaces(text):
|
||||
"""Removes unnecessary spaces between characters."""
|
||||
|
|
@ -236,28 +134,30 @@ def remove_unnecessary_spaces(text):
|
|||
replacement = r")++"
|
||||
return re.sub(pattern, replacement, text)
|
||||
|
||||
|
||||
def correct_translation_format(original_text, translated_text):
|
||||
original_parts = original_text.split('++')
|
||||
translated_parts = translated_text.split('++')
|
||||
|
||||
|
||||
corrected_parts = []
|
||||
for i, original_part in enumerate(original_parts):
|
||||
translated_part = translated_parts[i]
|
||||
|
||||
|
||||
original_plus_count = original_part.count('+')
|
||||
translated_plus_count = translated_part.count('+')
|
||||
plus_difference = translated_plus_count - original_plus_count
|
||||
|
||||
|
||||
if plus_difference > 0:
|
||||
translated_part = translated_part.replace('+' * plus_difference, '', 1)
|
||||
elif plus_difference < 0:
|
||||
translated_part += '+' * abs(plus_difference)
|
||||
|
||||
|
||||
corrected_parts.append(translated_part)
|
||||
|
||||
|
||||
corrected_text = '++'.join(corrected_parts)
|
||||
return corrected_text
|
||||
|
||||
|
||||
def extract_plus_positions(text):
|
||||
"""
|
||||
Given a string of text, extracts the positions of all sequences of one or more '+' characters.
|
||||
|
|
@ -289,7 +189,7 @@ def extract_plus_positions(text):
|
|||
positions.append([j, last_match_end, last_match_end - j])
|
||||
|
||||
last_match_end = match.end()
|
||||
|
||||
|
||||
# If the final match extends to the end of the string, add its position to the output list
|
||||
if last_match_end is not None and last_match_end == len(text):
|
||||
j = last_match_end - 1
|
||||
|
|
@ -314,11 +214,11 @@ def match_pluses(original_text, translated_text):
|
|||
- output (str): the translated text with '+' characters replaced by those in the original text
|
||||
"""
|
||||
in_positions = extract_plus_positions(original_text)
|
||||
out_positions = extract_plus_positions(translated_text)
|
||||
|
||||
out_positions = extract_plus_positions(translated_text)
|
||||
|
||||
out_vals = []
|
||||
out_current_pos = 0
|
||||
|
||||
|
||||
if len(in_positions) == len(out_positions):
|
||||
# Iterate through the positions and replace the sequences of '+' characters in the translated text
|
||||
# with those in the original text
|
||||
|
|
@ -326,31 +226,33 @@ def match_pluses(original_text, translated_text):
|
|||
out_vals.append(translated_text[out_current_pos:out_[0]])
|
||||
out_vals.append(original_text[in_[0]:in_[1]])
|
||||
out_current_pos = out_[1]
|
||||
|
||||
|
||||
# Check that the number of '+' characters in the original and translated sequences is the same
|
||||
if in_[2] != out_[2]:
|
||||
print("detected different + count")
|
||||
|
||||
# Add any remaining text from the translated string to the output
|
||||
out_vals.append(translated_text[out_current_pos:])
|
||||
|
||||
|
||||
# Join the output values into a single string
|
||||
output = "".join(out_vals)
|
||||
return output
|
||||
|
||||
|
||||
def post_process_prompt(original, translated):
|
||||
"""Applies post-processing to the translated prompt such as removing unnecessary spaces and extra plus signs."""
|
||||
clean_prompt = remove_unnecessary_spaces(translated)
|
||||
clean_prompt = match_pluses(original, clean_prompt)
|
||||
#clean_prompt = remove_extra_plus(clean_prompt)
|
||||
return clean_prompt
|
||||
# clean_prompt = remove_extra_plus(clean_prompt)
|
||||
return clean_prompt
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self) -> None:
|
||||
"""Initializes the Script class and sets the default value for enable_translation attribute."""
|
||||
super().__init__()
|
||||
self.enable_translation=False
|
||||
self.is_negative_translate_active=False
|
||||
self.enable_translation = False
|
||||
self.is_negative_translate_active = False
|
||||
|
||||
def title(self):
|
||||
"""Returns the title of the script."""
|
||||
|
|
@ -359,11 +261,11 @@ class Script(scripts.Script):
|
|||
def show(self, is_img2img):
|
||||
"""Returns the visibility status of the script in the interface."""
|
||||
return scripts.AlwaysVisible
|
||||
|
||||
|
||||
def set_active(self, active):
|
||||
"""Sets the is_active attribute and initializes the translator object if not already created.
|
||||
Also, sets the visibility of the language dropdown to True."""
|
||||
self.is_active=active
|
||||
self.is_active = active
|
||||
if not hasattr(self, "translator"):
|
||||
self.translator = MBartTranslator()
|
||||
if self.is_active:
|
||||
|
|
@ -374,19 +276,17 @@ class Script(scripts.Script):
|
|||
def set_negative_translate_active(self, negative_translate_active):
|
||||
"""Sets the is_active attribute and initializes the translator object if not already created.
|
||||
Also, sets the visibility of the language dropdown to True."""
|
||||
self.is_negative_translate_active=negative_translate_active
|
||||
|
||||
self.is_negative_translate_active = negative_translate_active
|
||||
|
||||
def ui(self, is_img2img):
|
||||
"""Sets up the user interface of the script."""
|
||||
self.is_active=False
|
||||
self.current_axis_options = [x for x in language_options]
|
||||
self.is_active = False
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=19):
|
||||
with gr.Accordion("Prompt Translator",open=False):
|
||||
with gr.Accordion("Help",open=False):
|
||||
md = gr.Markdown("""
|
||||
with gr.Accordion("Prompt Translator", open=False):
|
||||
with gr.Accordion("Help", open=False):
|
||||
gr.Markdown("""
|
||||
# Description
|
||||
This script translates your prompt from another language to english before generating the image allowing you to write prompts in your native language.
|
||||
# How to use
|
||||
|
|
@ -398,29 +298,30 @@ class Script(scripts.Script):
|
|||
with gr.Column():
|
||||
self.enable_translation = gr.Checkbox(label="Enable translation")
|
||||
with gr.Column() as options:
|
||||
self.options=options
|
||||
self.options = options
|
||||
self.translate_negative_prompt = gr.Checkbox(label="Translate negative prompt")
|
||||
self.enable_translation.value=False
|
||||
self.enable_translation.value = False
|
||||
self.language = gr.Dropdown(
|
||||
label="Source language",
|
||||
choices=[x.label for x in self.current_axis_options],
|
||||
value="Français",
|
||||
type="index",
|
||||
elem_id=self.elem_id("x_type")
|
||||
)
|
||||
self.output=gr.Markdown(value="After enabling translation, please Wait until I am ready", visible=True)
|
||||
label="Source language",
|
||||
choices=sorted(language_options.keys()),
|
||||
value="Chinese",
|
||||
type="value",
|
||||
elem_id=self.elem_id("x_type")
|
||||
)
|
||||
self.output = gr.Markdown(value="After enabling translation, please Wait until I am ready",
|
||||
visible=True)
|
||||
self.enable_translation.change(
|
||||
self.set_active,
|
||||
[self.enable_translation],
|
||||
[self.output, self.options],
|
||||
[self.enable_translation],
|
||||
[self.output, self.options],
|
||||
show_progress=True
|
||||
)
|
||||
self.translate_negative_prompt.change(
|
||||
self.set_negative_translate_active,
|
||||
[self.translate_negative_prompt],
|
||||
[self.translate_negative_prompt],
|
||||
)
|
||||
|
||||
self.options.visible=False
|
||||
self.options.visible = False
|
||||
return [self.language]
|
||||
|
||||
def get_prompts(self, p):
|
||||
|
|
@ -433,47 +334,45 @@ class Script(scripts.Script):
|
|||
)
|
||||
|
||||
return original_prompts, original_negative_prompts
|
||||
|
||||
|
||||
def process(self, p, language, **kwargs):
|
||||
"""Translates the prompts from a non-English language to English using the MBartTranslator object."""
|
||||
|
||||
if hasattr(self, "translator") and self.is_active:
|
||||
original_prompts, original_negative_prompts = self.get_prompts(p)
|
||||
translated_prompts=[]
|
||||
translated_prompts = []
|
||||
previous_prompt = ""
|
||||
previous_translated_prompt = ""
|
||||
|
||||
for original_prompt in original_prompts:
|
||||
if previous_prompt != original_prompt:
|
||||
print(f"Translating prompt to English from {language_options[language].label}")
|
||||
print(f"Translating prompt to English from {language}")
|
||||
print(f"Initial prompt:{original_prompt}")
|
||||
ln_code = language_options[language].language_code
|
||||
ln_code = language_options[language]
|
||||
translated_prompt = self.translator.translate(original_prompt, ln_code, "en_XX")
|
||||
translated_prompt = post_process_prompt(original_prompt, translated_prompt)
|
||||
print(f"Translated prompt:{translated_prompt}")
|
||||
translated_prompts.append(translated_prompt)
|
||||
|
||||
previous_prompt=original_prompt
|
||||
previous_prompt = original_prompt
|
||||
previous_translated_prompt = translated_prompt
|
||||
else:
|
||||
translated_prompts.append(previous_translated_prompt)
|
||||
|
||||
|
||||
if p.negative_prompt!='' and self.is_negative_translate_active:
|
||||
if p.negative_prompt != '' and self.is_negative_translate_active:
|
||||
previous_negative_prompt = ""
|
||||
previous_translated_negative_prompt = ""
|
||||
translated_negative_prompts=[]
|
||||
translated_negative_prompts = []
|
||||
for negative_prompt in original_negative_prompts:
|
||||
if previous_negative_prompt!=negative_prompt:
|
||||
print(f"Translating negative prompt to English from {language_options[language].label}")
|
||||
if previous_negative_prompt != negative_prompt:
|
||||
print(f"Translating negative prompt to English from {language}")
|
||||
print(f"Initial negative prompt:{negative_prompt}")
|
||||
ln_code = language_options[language].language_code
|
||||
ln_code = language_options[language]
|
||||
translated_negative_prompt = self.translator.translate(negative_prompt, ln_code, "en_XX")
|
||||
translated_negative_prompt = post_process_prompt(negative_prompt,translated_negative_prompt)
|
||||
translated_negative_prompt = post_process_prompt(negative_prompt, translated_negative_prompt)
|
||||
print(f"Translated negative prompt:{translated_negative_prompt}")
|
||||
translated_negative_prompts.append(translated_negative_prompt)
|
||||
|
||||
|
||||
previous_negative_prompt = negative_prompt
|
||||
previous_translated_negative_prompt = translated_negative_prompt
|
||||
else:
|
||||
|
|
@ -483,4 +382,4 @@ class Script(scripts.Script):
|
|||
p.all_negative_prompts = translated_negative_prompts
|
||||
p.prompt = translated_prompts[0]
|
||||
p.prompt_for_display = translated_prompts[0]
|
||||
p.all_prompts=translated_prompts
|
||||
p.all_prompts = translated_prompts
|
||||
|
|
|
|||
Loading…
Reference in New Issue