Update main.py

Remove incorrect language options and fix to 52 officially supported languages.
pull/31/head
klx 2024-12-25 23:07:31 +08:00 committed by GitHub
parent c3eb5039f4
commit d8babcfc55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 122 additions and 223 deletions

View File

@ -3,17 +3,18 @@
# Description : This script translates Stable diffusion prompt from one of the 50 languages supported by MBART
# It uses MBartTranslator class that provides a simple interface for translating text using the MBart language model.
import modules.scripts as scripts
import gradio as gr
from modules.shared import opts
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
import re
import os
import re
import gradio as gr
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
import modules.scripts as scripts
# The directory to store the models
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
class MBartTranslator:
"""MBartTranslator class provides a simple interface for translating text using the MBart language model.
@ -27,80 +28,12 @@ class MBartTranslator:
def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
self.supported_languages = [
"ar_AR",
"de_DE",
"en_XX",
"es_XX",
"fr_XX",
"hi_IN",
"it_IT",
"ja_XX",
"ko_XX",
"pt_XX",
"ru_RU",
"zh_XX",
"af_ZA",
"bn_BD",
"bs_XX",
"ca_XX",
"cs_CZ",
"da_XX",
"el_GR",
"et_EE",
"fa_IR",
"fi_FI",
"gu_IN",
"he_IL",
"hi_XX",
"hr_HR",
"hu_HU",
"id_ID",
"is_IS",
"ja_XX",
"jv_XX",
"ka_GE",
"kk_XX",
"km_KH",
"kn_IN",
"ko_KR",
"lo_LA",
"lt_LT",
"lv_LV",
"mk_MK",
"ml_IN",
"mr_IN",
"ms_MY",
"ne_NP",
"nl_XX",
"no_XX",
"pl_XX",
"ro_RO",
"si_LK",
"sk_SK",
"sl_SI",
"sq_AL",
"sr_XX",
"sv_XX",
"sw_TZ",
"ta_IN",
"te_IN",
"th_TH",
"tl_PH",
"tr_TR",
"uk_UA",
"ur_PK",
"vi_VN",
"war_PH",
"yue_XX",
"zh_CN",
"zh_TW",
]
print("Building translator")
print("Loading generator (this may take few minutes the first time as I need to download the model)")
self.model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
print("Loading tokenizer")
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang, cache_dir=cache_dir)
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang,
cache_dir=cache_dir)
print("Translator is ready")
def translate(self, text: str, input_language: str, output_language: str) -> str:
@ -114,10 +47,12 @@ class MBartTranslator:
Returns:
str: The translated text.
"""
if input_language not in self.supported_languages:
raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
if output_language not in self.supported_languages:
raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
if input_language not in self.tokenizer.lang_code_to_id:
raise ValueError(
f"Input language not supported. Supported languages: {self.tokenizer.lang_code_to_id}")
if output_language not in self.tokenizer.lang_code_to_id:
raise ValueError(
f"Output language not supported. Supported languages: {self.tokenizer.lang_code_to_id}")
self.tokenizer.src_lang = input_language
encoded_input = self.tokenizer(text, return_tensors="pt")
@ -129,29 +64,6 @@ class MBartTranslator:
return translated_text[0]
class LanguageOption:
"""
A class representing a language option in a language selector.
Attributes:
label (str): The display label for the language option.
language_code (str): The ISO 639-1 language code for the language option.
"""
def __init__(self, label, language_code):
"""
Initializes a new LanguageOption instance.
Args:
label (str): The display label for the language option.
language_code (str): The ISO 639-1 language code for the language option.
"""
self.label = label
self.language_code = language_code
# This is a list of LanguageOption objects that represent the various language options available.
# Each LanguageOption object contains a label that represents the display name of the language and
# a language code that represents the code for the language that will be used by the translation model.
@ -160,75 +72,61 @@ class LanguageOption:
# For example, "en_XX" represents English language and "fr_FR" represents French language specific to France.
# These LanguageOption objects will be used to display the language options to the user and to retrieve the
# corresponding language code when the user selects a language.
language_options = [
LanguageOption("Arabic", "ar_AR"),
LanguageOption("Deutsch", "de_DE"),
LanguageOption("English", "en_XX"),
LanguageOption("Spanish", "es_XX"),
LanguageOption("French", "fr_XX"),
LanguageOption("Hindi", "hi_IN"),
LanguageOption("Italian", "it_IT"),
LanguageOption("Japanese", "ja_XX"),
LanguageOption("Korean", "ko_XX"),
LanguageOption("Portuguese", "pt_XX"),
LanguageOption("Russian", "ru_RU"),
LanguageOption("Chinese", "zh_XX"),
LanguageOption("Afrikaans", "af_ZA"),
LanguageOption("Bengali", "bn_BD"),
LanguageOption("Bosnian", "bs_XX"),
LanguageOption("Catalan", "ca_XX"),
LanguageOption("Czech", "cs_CZ"),
LanguageOption("Danish", "da_XX"),
LanguageOption("Greek", "el_GR"),
LanguageOption("Estonian", "et_EE"),
LanguageOption("Persian", "fa_IR"),
LanguageOption("Finnish", "fi_FI"),
LanguageOption("Gujarati", "gu_IN"),
LanguageOption("Hebrew", "he_IL"),
LanguageOption("Croatian", "hr_HR"),
LanguageOption("Hungarian", "hu_HU"),
LanguageOption("Indonesian", "id_ID"),
LanguageOption("Icelandic", "is_IS"),
LanguageOption("Javanese", "jv_XX"),
LanguageOption("Georgian", "ka_GE"),
LanguageOption("Kazakh", "kk_XX"),
LanguageOption("Khmer", "km_KH"),
LanguageOption("Kannada", "kn_IN"),
LanguageOption("Korean", "ko_KR"),
LanguageOption("Lao", "lo_LA"),
LanguageOption("Lithuanian", "lt_LT"),
LanguageOption("Latvian", "lv_LV"),
LanguageOption("Macedonian", "mk_MK"),
LanguageOption("Malayalam", "ml_IN"),
LanguageOption("Marathi", "mr_IN"),
LanguageOption("Malay", "ms_MY"),
LanguageOption("Nepali", "ne_NP"),
LanguageOption("Dutch", "nl_XX"),
LanguageOption("Norwegian", "no_XX"),
LanguageOption("Polish", "pl_XX"),
LanguageOption("Romanian", "ro_RO"),
LanguageOption("Sinhala", "si_LK"),
LanguageOption("Slovak", "sk_SK"),
LanguageOption("Slovenian", "sl_SI"),
LanguageOption("Albanian", "sq_AL"),
LanguageOption("Serbian", "sr_XX"),
LanguageOption("Swedish", "sv_XX"),
LanguageOption("Swahili", "sw_TZ"),
LanguageOption("Tamil", "ta_IN"),
LanguageOption("Telugu", "te_IN"),
LanguageOption("Tamil ", "ta_IN"),
LanguageOption("Telugu", "te_IN"),
LanguageOption("Thai", "th_TH"),
LanguageOption("Tagalog", "tl_PH"),
LanguageOption("Turkish", "tr_TR"),
LanguageOption("Ukrainian", "uk_UA"),
LanguageOption("Urdu", "ur_PK"),
LanguageOption("Vietnamese", "vi_VN"),
LanguageOption("Waray", "war_PH"),
LanguageOption("Cantonese", "yue_XX"),
LanguageOption("Chinese", "zh_CN"),
LanguageOption("Chinese", "zh_TW"),
]
language_options = {
'Afrikaans': 'af_ZA',
'Arabic': 'ar_AR',
'Azerbaijani (Azerbaijan)': 'az_AZ',
'Bengali (India)': 'bn_IN',
'Burmese (Myanmar)': 'my_MM',
'Chinese': 'zh_CN',
'Croatian': 'hr_HR',
'Czech': 'cs_CZ',
'Deutsch': 'de_DE',
"Dutch": "nl_XX",
'English': 'en_XX',
'Estonian': 'et_EE',
'Finnish': 'fi_FI',
'French': 'fr_XX',
'Galician (Spain)': 'gl_ES',
'Georgian': 'ka_GE',
'Gujarati': 'gu_IN',
'Hebrew': 'he_IL',
'Hindi': 'hi_IN',
'Indonesian': 'id_ID',
'Italian': 'it_IT',
'Japanese': 'ja_XX',
'Kazakh (Kazakhstan)': 'kk_KZ',
'Khmer': 'km_KH',
'Korean': 'ko_KR',
'Latvian': 'lv_LV',
'Lithuanian': 'lt_LT',
'Macedonian': 'mk_MK',
'Malayalam': 'ml_IN',
'Marathi': 'mr_IN',
'Mongolian (Mongolia)': 'mn_MN',
'Nepali': 'ne_NP',
'Pashto (Afghanistan)': 'ps_AF',
'Persian': 'fa_IR',
'Polish (Poland)': 'pl_PL',
'Portuguese': 'pt_XX',
'Romanian': 'ro_RO',
'Russian': 'ru_RU',
'Sinhala': 'si_LK',
'Slovenian': 'sl_SI',
'Spanish': 'es_XX',
'Swahili (Kenya)': 'sw_KE',
'Swedish (Sweden)': 'sv_SE',
'Tamil': 'ta_IN',
'Telugu': 'te_IN',
"Tagalog": "tl_XX",
'Thai': 'th_TH',
'Turkish': 'tr_TR',
'Ukrainian': 'uk_UA',
'Urdu': 'ur_PK',
'Vietnamese': 'vi_VN',
'Xhosa (South Africa)': 'xh_ZA'
}
def remove_unnecessary_spaces(text):
"""Removes unnecessary spaces between characters."""
@ -236,6 +134,7 @@ def remove_unnecessary_spaces(text):
replacement = r")++"
return re.sub(pattern, replacement, text)
def correct_translation_format(original_text, translated_text):
original_parts = original_text.split('++')
translated_parts = translated_text.split('++')
@ -258,6 +157,7 @@ def correct_translation_format(original_text, translated_text):
corrected_text = '++'.join(corrected_parts)
return corrected_text
def extract_plus_positions(text):
"""
Given a string of text, extracts the positions of all sequences of one or more '+' characters.
@ -338,6 +238,7 @@ def match_pluses(original_text, translated_text):
output = "".join(out_vals)
return output
def post_process_prompt(original, translated):
"""Applies post-processing to the translated prompt such as removing unnecessary spaces and extra plus signs."""
clean_prompt = remove_unnecessary_spaces(translated)
@ -345,6 +246,7 @@ def post_process_prompt(original, translated):
# clean_prompt = remove_extra_plus(clean_prompt)
return clean_prompt
class Script(scripts.Script):
def __init__(self) -> None:
"""Initializes the Script class and sets the default value for enable_translation attribute."""
@ -376,17 +278,15 @@ class Script(scripts.Script):
Also, sets the visibility of the language dropdown to True."""
self.is_negative_translate_active = negative_translate_active
def ui(self, is_img2img):
"""Sets up the user interface of the script."""
self.is_active = False
self.current_axis_options = [x for x in language_options]
with gr.Row():
with gr.Column(scale=19):
with gr.Accordion("Prompt Translator", open=False):
with gr.Accordion("Help", open=False):
md = gr.Markdown("""
gr.Markdown("""
# Description
This script translates your prompt from another language to english before generating the image allowing you to write prompts in your native language.
# How to use
@ -403,12 +303,13 @@ class Script(scripts.Script):
self.enable_translation.value = False
self.language = gr.Dropdown(
label="Source language",
choices=[x.label for x in self.current_axis_options],
value="Français",
type="index",
choices=sorted(language_options.keys()),
value="Chinese",
type="value",
elem_id=self.elem_id("x_type")
)
self.output=gr.Markdown(value="After enabling translation, please Wait until I am ready", visible=True)
self.output = gr.Markdown(value="After enabling translation, please Wait until I am ready",
visible=True)
self.enable_translation.change(
self.set_active,
[self.enable_translation],
@ -445,9 +346,9 @@ class Script(scripts.Script):
for original_prompt in original_prompts:
if previous_prompt != original_prompt:
print(f"Translating prompt to English from {language_options[language].label}")
print(f"Translating prompt to English from {language}")
print(f"Initial prompt:{original_prompt}")
ln_code = language_options[language].language_code
ln_code = language_options[language]
translated_prompt = self.translator.translate(original_prompt, ln_code, "en_XX")
translated_prompt = post_process_prompt(original_prompt, translated_prompt)
print(f"Translated prompt:{translated_prompt}")
@ -458,22 +359,20 @@ class Script(scripts.Script):
else:
translated_prompts.append(previous_translated_prompt)
if p.negative_prompt != '' and self.is_negative_translate_active:
previous_negative_prompt = ""
previous_translated_negative_prompt = ""
translated_negative_prompts = []
for negative_prompt in original_negative_prompts:
if previous_negative_prompt != negative_prompt:
print(f"Translating negative prompt to English from {language_options[language].label}")
print(f"Translating negative prompt to English from {language}")
print(f"Initial negative prompt:{negative_prompt}")
ln_code = language_options[language].language_code
ln_code = language_options[language]
translated_negative_prompt = self.translator.translate(negative_prompt, ln_code, "en_XX")
translated_negative_prompt = post_process_prompt(negative_prompt, translated_negative_prompt)
print(f"Translated negative prompt:{translated_negative_prompt}")
translated_negative_prompts.append(translated_negative_prompt)
previous_negative_prompt = negative_prompt
previous_translated_negative_prompt = translated_negative_prompt
else: