Merge pull request #31 from Neokmi/main

Remove incorrect language options and fix to 52 officially supported languages.
main
Saifeddine ALOUI 2024-12-25 16:32:21 +01:00 committed by GitHub
commit d75e0fe584
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 122 additions and 223 deletions

View File

@ -3,17 +3,18 @@
# Description : This script translates a Stable Diffusion prompt from one of the 52 languages supported by mBART-50
# It uses MBartTranslator class that provides a simple interface for translating text using the MBart language model.
import modules.scripts as scripts
import gradio as gr
from modules.shared import opts
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
import re
import os
import re
import gradio as gr
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
import modules.scripts as scripts
# The directory to store the models
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
class MBartTranslator:
"""MBartTranslator class provides a simple interface for translating text using the MBart language model.
@ -27,80 +28,12 @@ class MBartTranslator:
def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
    """Load the MBart model and its tokenizer into this translator.

    Args:
        model_name (str): Checkpoint name or path passed to ``from_pretrained``.
        src_lang (str, optional): Source language code forwarded to the tokenizer.
        tgt_lang (str, optional): Target language code forwarded to the tokenizer.
    """
    print("Building translator")
    print("Loading generator (this may take few minutes the first time as I need to download the model)")
    self.model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
    print("Loading tokenizer")
    # Load the tokenizer exactly once; both weights and tokenizer files go to cache_dir.
    self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang,
                                                          cache_dir=cache_dir)
    # Derive the supported codes from the tokenizer instead of a hand-maintained list,
    # which contained duplicates and codes the tokenizer does not define.
    # The attribute is kept so existing readers of `supported_languages` keep working.
    self.supported_languages = list(self.tokenizer.lang_code_to_id)
    print("Translator is ready")
def translate(self, text: str, input_language: str, output_language: str) -> str:
@ -114,10 +47,12 @@ class MBartTranslator:
Returns:
str: The translated text.
"""
if input_language not in self.supported_languages:
raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
if output_language not in self.supported_languages:
raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
if input_language not in self.tokenizer.lang_code_to_id:
raise ValueError(
f"Input language not supported. Supported languages: {self.tokenizer.lang_code_to_id}")
if output_language not in self.tokenizer.lang_code_to_id:
raise ValueError(
f"Output language not supported. Supported languages: {self.tokenizer.lang_code_to_id}")
self.tokenizer.src_lang = input_language
encoded_input = self.tokenizer(text, return_tensors="pt")
@ -129,29 +64,6 @@ class MBartTranslator:
return translated_text[0]
class LanguageOption:
"""
A class representing a language option in a language selector.
Attributes:
label (str): The display label for the language option.
language_code (str): The ISO 639-1 language code for the language option.
"""
def __init__(self, label, language_code):
"""
Initializes a new LanguageOption instance.
Args:
label (str): The display label for the language option.
language_code (str): The ISO 639-1 language code for the language option.
"""
self.label = label
self.language_code = language_code
# This is a list of LanguageOption objects that represent the various language options available.
# Each LanguageOption object contains a label that represents the display name of the language and
# a language code that represents the code for the language that will be used by the translation model.
@ -160,75 +72,61 @@ class LanguageOption:
# For example, "en_XX" represents English language and "de_DE" represents German language specific to Germany.
# These LanguageOption objects will be used to display the language options to the user and to retrieve the
# corresponding language code when the user selects a language.
language_options = [
LanguageOption("Arabic", "ar_AR"),
LanguageOption("Deutsch", "de_DE"),
LanguageOption("English", "en_XX"),
LanguageOption("Spanish", "es_XX"),
LanguageOption("French", "fr_XX"),
LanguageOption("Hindi", "hi_IN"),
LanguageOption("Italian", "it_IT"),
LanguageOption("Japanese", "ja_XX"),
LanguageOption("Korean", "ko_XX"),
LanguageOption("Portuguese", "pt_XX"),
LanguageOption("Russian", "ru_RU"),
LanguageOption("Chinese", "zh_XX"),
LanguageOption("Afrikaans", "af_ZA"),
LanguageOption("Bengali", "bn_BD"),
LanguageOption("Bosnian", "bs_XX"),
LanguageOption("Catalan", "ca_XX"),
LanguageOption("Czech", "cs_CZ"),
LanguageOption("Danish", "da_XX"),
LanguageOption("Greek", "el_GR"),
LanguageOption("Estonian", "et_EE"),
LanguageOption("Persian", "fa_IR"),
LanguageOption("Finnish", "fi_FI"),
LanguageOption("Gujarati", "gu_IN"),
LanguageOption("Hebrew", "he_IL"),
LanguageOption("Croatian", "hr_HR"),
LanguageOption("Hungarian", "hu_HU"),
LanguageOption("Indonesian", "id_ID"),
LanguageOption("Icelandic", "is_IS"),
LanguageOption("Javanese", "jv_XX"),
LanguageOption("Georgian", "ka_GE"),
LanguageOption("Kazakh", "kk_XX"),
LanguageOption("Khmer", "km_KH"),
LanguageOption("Kannada", "kn_IN"),
LanguageOption("Korean", "ko_KR"),
LanguageOption("Lao", "lo_LA"),
LanguageOption("Lithuanian", "lt_LT"),
LanguageOption("Latvian", "lv_LV"),
LanguageOption("Macedonian", "mk_MK"),
LanguageOption("Malayalam", "ml_IN"),
LanguageOption("Marathi", "mr_IN"),
LanguageOption("Malay", "ms_MY"),
LanguageOption("Nepali", "ne_NP"),
LanguageOption("Dutch", "nl_XX"),
LanguageOption("Norwegian", "no_XX"),
LanguageOption("Polish", "pl_XX"),
LanguageOption("Romanian", "ro_RO"),
LanguageOption("Sinhala", "si_LK"),
LanguageOption("Slovak", "sk_SK"),
LanguageOption("Slovenian", "sl_SI"),
LanguageOption("Albanian", "sq_AL"),
LanguageOption("Serbian", "sr_XX"),
LanguageOption("Swedish", "sv_XX"),
LanguageOption("Swahili", "sw_TZ"),
LanguageOption("Tamil", "ta_IN"),
LanguageOption("Telugu", "te_IN"),
LanguageOption("Tamil ", "ta_IN"),
LanguageOption("Telugu", "te_IN"),
LanguageOption("Thai", "th_TH"),
LanguageOption("Tagalog", "tl_PH"),
LanguageOption("Turkish", "tr_TR"),
LanguageOption("Ukrainian", "uk_UA"),
LanguageOption("Urdu", "ur_PK"),
LanguageOption("Vietnamese", "vi_VN"),
LanguageOption("Waray", "war_PH"),
LanguageOption("Cantonese", "yue_XX"),
LanguageOption("Chinese", "zh_CN"),
LanguageOption("Chinese", "zh_TW"),
]
# Maps the display name offered in the "Source language" dropdown to the
# language code handed to the translator for that language.
language_options = {
    "Afrikaans": "af_ZA",
    "Arabic": "ar_AR",
    "Azerbaijani (Azerbaijan)": "az_AZ",
    "Bengali (India)": "bn_IN",
    "Burmese (Myanmar)": "my_MM",
    "Chinese": "zh_CN",
    "Croatian": "hr_HR",
    "Czech": "cs_CZ",
    "Deutsch": "de_DE",
    "Dutch": "nl_XX",
    "English": "en_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Galician (Spain)": "gl_ES",
    "Georgian": "ka_GE",
    "Gujarati": "gu_IN",
    "Hebrew": "he_IL",
    "Hindi": "hi_IN",
    "Indonesian": "id_ID",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh (Kazakhstan)": "kk_KZ",
    "Khmer": "km_KH",
    "Korean": "ko_KR",
    "Latvian": "lv_LV",
    "Lithuanian": "lt_LT",
    "Macedonian": "mk_MK",
    "Malayalam": "ml_IN",
    "Marathi": "mr_IN",
    "Mongolian (Mongolia)": "mn_MN",
    "Nepali": "ne_NP",
    "Pashto (Afghanistan)": "ps_AF",
    "Persian": "fa_IR",
    "Polish (Poland)": "pl_PL",
    "Portuguese": "pt_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Slovenian": "sl_SI",
    "Spanish": "es_XX",
    "Swahili (Kenya)": "sw_KE",
    "Swedish (Sweden)": "sv_SE",
    "Tamil": "ta_IN",
    "Telugu": "te_IN",
    "Tagalog": "tl_XX",
    "Thai": "th_TH",
    "Turkish": "tr_TR",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Vietnamese": "vi_VN",
    "Xhosa (South Africa)": "xh_ZA",
}
def remove_unnecessary_spaces(text):
"""Removes unnecessary spaces between characters."""
@ -236,28 +134,30 @@ def remove_unnecessary_spaces(text):
replacement = r")++"
return re.sub(pattern, replacement, text)
def correct_translation_format(original_text, translated_text):
    """Re-balance the '+' characters of a translation against the original text.

    Both texts are split on '++' and compared segment by segment. When a
    translated segment has more '+' characters than its original counterpart,
    the first run of that many consecutive '+' is removed (if no such run
    exists the segment is left as is). When it has fewer, the missing '+'
    characters are appended to the segment.

    Args:
        original_text (str): The text before translation.
        translated_text (str): The text after translation.

    Returns:
        str: The translated text with per-segment '+' counts adjusted. If the
        two texts do not split into the same number of segments they cannot be
        aligned, so the translated text is returned unchanged.
    """
    original_parts = original_text.split('++')
    translated_parts = translated_text.split('++')

    # Segments can only be paired when both sides have the same number of them;
    # indexing blindly would raise IndexError or silently drop translated text.
    if len(original_parts) != len(translated_parts):
        return translated_text

    corrected_parts = []
    for original_part, translated_part in zip(original_parts, translated_parts):
        plus_difference = translated_part.count('+') - original_part.count('+')

        if plus_difference > 0:
            translated_part = translated_part.replace('+' * plus_difference, '', 1)
        elif plus_difference < 0:
            translated_part += '+' * -plus_difference

        corrected_parts.append(translated_part)

    return '++'.join(corrected_parts)
def extract_plus_positions(text):
"""
Given a string of text, extracts the positions of all sequences of one or more '+' characters.
@ -289,7 +189,7 @@ def extract_plus_positions(text):
positions.append([j, last_match_end, last_match_end - j])
last_match_end = match.end()
# If the final match extends to the end of the string, add its position to the output list
if last_match_end is not None and last_match_end == len(text):
j = last_match_end - 1
@ -314,11 +214,11 @@ def match_pluses(original_text, translated_text):
- output (str): the translated text with '+' characters replaced by those in the original text
"""
in_positions = extract_plus_positions(original_text)
out_positions = extract_plus_positions(translated_text)
out_positions = extract_plus_positions(translated_text)
out_vals = []
out_current_pos = 0
if len(in_positions) == len(out_positions):
# Iterate through the positions and replace the sequences of '+' characters in the translated text
# with those in the original text
@ -326,31 +226,33 @@ def match_pluses(original_text, translated_text):
out_vals.append(translated_text[out_current_pos:out_[0]])
out_vals.append(original_text[in_[0]:in_[1]])
out_current_pos = out_[1]
# Check that the number of '+' characters in the original and translated sequences is the same
if in_[2] != out_[2]:
print("detected different + count")
# Add any remaining text from the translated string to the output
out_vals.append(translated_text[out_current_pos:])
# Join the output values into a single string
output = "".join(out_vals)
return output
def post_process_prompt(original, translated):
    """Clean up a translated prompt so its formatting follows the original.

    Args:
        original (str): The prompt before translation.
        translated (str): The prompt returned by the translator.

    Returns:
        str: The translated prompt after `remove_unnecessary_spaces` has been
        applied and its '+' sequences have been matched against the original
        by `match_pluses`.
    """
    clean_prompt = remove_unnecessary_spaces(translated)
    return match_pluses(original, clean_prompt)
class Script(scripts.Script):
def __init__(self) -> None:
    """Initialize the script with translation disabled.

    Sets `enable_translation` and `is_negative_translate_active` to False.
    """
    super().__init__()
    self.enable_translation = False
    self.is_negative_translate_active = False
def title(self):
"""Returns the title of the script."""
@ -359,11 +261,11 @@ class Script(scripts.Script):
def show(self, is_img2img):
    """Tell the host UI when to display this script.

    The `is_img2img` argument is not consulted; the same value,
    `scripts.AlwaysVisible`, is returned in every case.
    """
    visibility = scripts.AlwaysVisible
    return visibility
def set_active(self, active):
"""Sets the is_active attribute and initializes the translator object if not already created.
Also, sets the visibility of the language dropdown to True."""
self.is_active=active
self.is_active = active
if not hasattr(self, "translator"):
self.translator = MBartTranslator()
if self.is_active:
@ -374,19 +276,17 @@ class Script(scripts.Script):
def set_negative_translate_active(self, negative_translate_active):
    """Record whether the negative prompt should be translated as well.

    Args:
        negative_translate_active (bool): New value stored in
            `is_negative_translate_active`.
    """
    self.is_negative_translate_active = negative_translate_active
def ui(self, is_img2img):
"""Sets up the user interface of the script."""
self.is_active=False
self.current_axis_options = [x for x in language_options]
self.is_active = False
with gr.Row():
with gr.Column(scale=19):
with gr.Accordion("Prompt Translator",open=False):
with gr.Accordion("Help",open=False):
md = gr.Markdown("""
with gr.Accordion("Prompt Translator", open=False):
with gr.Accordion("Help", open=False):
gr.Markdown("""
# Description
This script translates your prompt from another language to english before generating the image allowing you to write prompts in your native language.
# How to use
@ -398,29 +298,30 @@ class Script(scripts.Script):
with gr.Column():
self.enable_translation = gr.Checkbox(label="Enable translation")
with gr.Column() as options:
self.options=options
self.options = options
self.translate_negative_prompt = gr.Checkbox(label="Translate negative prompt")
self.enable_translation.value=False
self.enable_translation.value = False
self.language = gr.Dropdown(
label="Source language",
choices=[x.label for x in self.current_axis_options],
value="Français",
type="index",
elem_id=self.elem_id("x_type")
)
self.output=gr.Markdown(value="After enabling translation, please Wait until I am ready", visible=True)
label="Source language",
choices=sorted(language_options.keys()),
value="Chinese",
type="value",
elem_id=self.elem_id("x_type")
)
self.output = gr.Markdown(value="After enabling translation, please Wait until I am ready",
visible=True)
self.enable_translation.change(
self.set_active,
[self.enable_translation],
[self.output, self.options],
[self.enable_translation],
[self.output, self.options],
show_progress=True
)
self.translate_negative_prompt.change(
self.set_negative_translate_active,
[self.translate_negative_prompt],
[self.translate_negative_prompt],
)
self.options.visible=False
self.options.visible = False
return [self.language]
def get_prompts(self, p):
@ -433,47 +334,45 @@ class Script(scripts.Script):
)
return original_prompts, original_negative_prompts
def process(self, p, language, **kwargs):
"""Translates the prompts from a non-English language to English using the MBartTranslator object."""
if hasattr(self, "translator") and self.is_active:
original_prompts, original_negative_prompts = self.get_prompts(p)
translated_prompts=[]
translated_prompts = []
previous_prompt = ""
previous_translated_prompt = ""
for original_prompt in original_prompts:
if previous_prompt != original_prompt:
print(f"Translating prompt to English from {language_options[language].label}")
print(f"Translating prompt to English from {language}")
print(f"Initial prompt:{original_prompt}")
ln_code = language_options[language].language_code
ln_code = language_options[language]
translated_prompt = self.translator.translate(original_prompt, ln_code, "en_XX")
translated_prompt = post_process_prompt(original_prompt, translated_prompt)
print(f"Translated prompt:{translated_prompt}")
translated_prompts.append(translated_prompt)
previous_prompt=original_prompt
previous_prompt = original_prompt
previous_translated_prompt = translated_prompt
else:
translated_prompts.append(previous_translated_prompt)
if p.negative_prompt!='' and self.is_negative_translate_active:
if p.negative_prompt != '' and self.is_negative_translate_active:
previous_negative_prompt = ""
previous_translated_negative_prompt = ""
translated_negative_prompts=[]
translated_negative_prompts = []
for negative_prompt in original_negative_prompts:
if previous_negative_prompt!=negative_prompt:
print(f"Translating negative prompt to English from {language_options[language].label}")
if previous_negative_prompt != negative_prompt:
print(f"Translating negative prompt to English from {language}")
print(f"Initial negative prompt:{negative_prompt}")
ln_code = language_options[language].language_code
ln_code = language_options[language]
translated_negative_prompt = self.translator.translate(negative_prompt, ln_code, "en_XX")
translated_negative_prompt = post_process_prompt(negative_prompt,translated_negative_prompt)
translated_negative_prompt = post_process_prompt(negative_prompt, translated_negative_prompt)
print(f"Translated negative prompt:{translated_negative_prompt}")
translated_negative_prompts.append(translated_negative_prompt)
previous_negative_prompt = negative_prompt
previous_translated_negative_prompt = translated_negative_prompt
else:
@ -483,4 +382,4 @@ class Script(scripts.Script):
p.all_negative_prompts = translated_negative_prompts
p.prompt = translated_prompts[0]
p.prompt_for_display = translated_prompts[0]
p.all_prompts=translated_prompts
p.all_prompts = translated_prompts