248 lines
8.8 KiB
Python
248 lines
8.8 KiB
Python
# MBartTranslator :
|
|
# Author : ParisNeo
|
|
# Description : This script translates Stable diffusion prompt from one of the 50 languages supported by MBART
|
|
# It uses MBartTranslator class that provides a simple interface for translating text using the MBart language model.
|
|
|
|
import modules.scripts as scripts
|
|
import gradio as gr
|
|
from modules.shared import opts
|
|
|
|
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
|
|
|
|
|
|
class MBartTranslator:
|
|
"""MBartTranslator class provides a simple interface for translating text using the MBart language model.
|
|
|
|
The class can translate between 50 languages and is based on the "facebook/mbart-large-50-many-to-one-mmt"
|
|
pre-trained MBart model. However, it is possible to use a different MBart model by specifying its name.
|
|
|
|
Attributes:
|
|
model (MBartForConditionalGeneration): The MBart language model.
|
|
tokenizer (MBart50TokenizerFast): The MBart tokenizer.
|
|
"""
|
|
|
|
def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
|
|
|
|
self.supported_languages = [
|
|
"ar_AR",
|
|
"de_DE",
|
|
"en_XX",
|
|
"es_XX",
|
|
"fr_XX",
|
|
"hi_IN",
|
|
"it_IT",
|
|
"ja_XX",
|
|
"ko_XX",
|
|
"pt_XX",
|
|
"ru_XX",
|
|
"zh_XX",
|
|
"af_ZA",
|
|
"bn_BD",
|
|
"bs_XX",
|
|
"ca_XX",
|
|
"cs_CZ",
|
|
"da_XX",
|
|
"el_GR",
|
|
"et_EE",
|
|
"fa_IR",
|
|
"fi_FI",
|
|
"gu_IN",
|
|
"he_IL",
|
|
"hi_XX",
|
|
"hr_HR",
|
|
"hu_HU",
|
|
"id_ID",
|
|
"is_IS",
|
|
"ja_XX",
|
|
"jv_XX",
|
|
"ka_GE",
|
|
"kk_XX",
|
|
"km_KH",
|
|
"kn_IN",
|
|
"ko_KR",
|
|
"lo_LA",
|
|
"lt_LT",
|
|
"lv_LV",
|
|
"mk_MK",
|
|
"ml_IN",
|
|
"mr_IN",
|
|
"ms_MY",
|
|
"ne_NP",
|
|
"nl_XX",
|
|
"no_XX",
|
|
"pl_XX",
|
|
"ro_RO",
|
|
"si_LK",
|
|
"sk_SK",
|
|
"sl_SI",
|
|
"sq_AL",
|
|
"sr_XX",
|
|
"sv_XX",
|
|
"sw_TZ",
|
|
"ta_IN",
|
|
"te_IN",
|
|
"th_TH",
|
|
"tl_PH",
|
|
"tr_TR",
|
|
"uk_UA",
|
|
"ur_PK",
|
|
"vi_VN",
|
|
"war_PH",
|
|
"yue_XX",
|
|
"zh_CN",
|
|
"zh_TW",
|
|
]
|
|
print("Building translator")
|
|
print("Loading generator")
|
|
self.model = MBartForConditionalGeneration.from_pretrained(model_name)
|
|
print("Loading tokenizer")
|
|
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
|
|
print("Translator ready")
|
|
|
|
def translate(self, text: str, input_language: str, output_language: str) -> str:
|
|
"""Translate the given text from the input language to the output language.
|
|
|
|
Args:
|
|
text (str): The text to translate.
|
|
input_language (str): The input language code (e.g. "hi_IN" for Hindi).
|
|
output_language (str): The output language code (e.g. "en_US" for English).
|
|
|
|
Returns:
|
|
str: The translated text.
|
|
"""
|
|
if input_language not in self.supported_languages:
|
|
raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
|
|
if output_language not in self.supported_languages:
|
|
raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
|
|
|
|
self.tokenizer.src_lang = input_language
|
|
encoded_input = self.tokenizer(text, return_tensors="pt")
|
|
generated_tokens = self.model.generate(
|
|
**encoded_input, forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language]
|
|
)
|
|
translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
|
|
|
return translated_text[0]
|
|
|
|
|
|
|
|
class AxisOption:
|
|
def __init__(self, label, language_code):
|
|
self.label = label
|
|
self.language_code = language_code
|
|
|
|
|
|
|
|
language_options = [
|
|
AxisOption("English", "en_XX"),
|
|
AxisOption("عربية", "ar_AR"),
|
|
AxisOption("Deutsch", "de_DE"),
|
|
AxisOption("Español", "es_XX"),
|
|
AxisOption("Français", "fr_XX"),
|
|
AxisOption("हिन्दी", "hi_IN"),
|
|
AxisOption("Italiano", "it_IT"),
|
|
AxisOption("日本語", "ja_XX"),
|
|
AxisOption("한국어", "ko_XX"),
|
|
AxisOption("Português", "pt_XX"),
|
|
AxisOption("Русский", "ru_XX"),
|
|
AxisOption("中文", "zh_XX"),
|
|
AxisOption("Afrikaans", "af_ZA"),
|
|
AxisOption("বাংলা", "bn_BD"),
|
|
AxisOption("Bosanski", "bs_XX"),
|
|
AxisOption("Català", "ca_XX"),
|
|
AxisOption("Čeština", "cs_CZ"),
|
|
AxisOption("Dansk", "da_XX"),
|
|
AxisOption("Ελληνικά", "el_GR"),
|
|
AxisOption("Eesti", "et_EE"),
|
|
AxisOption("فارسی", "fa_IR"),
|
|
AxisOption("Suomi", "fi_FI"),
|
|
AxisOption("ગુજરાતી", "gu_IN"),
|
|
AxisOption("עברית", "he_IL"),
|
|
AxisOption("हिन्दी", "hi_XX"),
|
|
AxisOption("Hrvatski", "hr_HR"),
|
|
AxisOption("Magyar", "hu_HU"),
|
|
AxisOption("Bahasa Indonesia", "id_ID"),
|
|
AxisOption("Íslenska", "is_IS"),
|
|
AxisOption("日本語", "ja_XX"),
|
|
AxisOption("Javanese", "jv_XX"),
|
|
AxisOption("ქართული", "ka_GE"),
|
|
AxisOption("Қазақ", "kk_XX"),
|
|
AxisOption("ខ្មែរ", "km_KH"),
|
|
AxisOption("ಕನ್ನಡ", "kn_IN"),
|
|
AxisOption("한국어", "ko_KR"),
|
|
AxisOption("ລາວ", "lo_LA"),
|
|
AxisOption("Lietuvių", "lt_LT"),
|
|
AxisOption("Latviešu", "lv_LV"),
|
|
AxisOption("Македонски", "mk_MK"),
|
|
AxisOption("മലയാളം", "ml_IN"),
|
|
AxisOption("मराठी", "mr_IN"),
|
|
AxisOption("Bahasa Melayu", "ms_MY"),
|
|
AxisOption("नेपाली", "ne_NP"),
|
|
AxisOption("Nederlands", "nl_XX"),
|
|
AxisOption("Norsk", "no_XX"),
|
|
AxisOption("Polski", "pl_XX"),
|
|
AxisOption("Română", "ro_RO"),
|
|
AxisOption("සිංහල", "si_LK"),
|
|
AxisOption("Slovenčina", "sk_SK"),
|
|
AxisOption("Slovenščina", "sl_SI"),
|
|
AxisOption("Shqip", "sq_AL"),
|
|
]
|
|
|
|
|
|
class Script(scripts.Script):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.enable_translation=False
|
|
|
|
def title(self):
|
|
return "Translate prompt to english"
|
|
|
|
def show(self, is_img2img):
|
|
# if is_img2img:
|
|
# return False
|
|
return scripts.AlwaysVisible
|
|
|
|
def set_active(self, active):
|
|
self.is_active=active
|
|
if not hasattr(self, "translator"):
|
|
self.translator = MBartTranslator()
|
|
return self.language.update(visible=True)
|
|
|
|
|
|
|
|
def ui(self, is_img2img):
|
|
self.is_active=False
|
|
self.current_axis_options = [x for x in language_options]
|
|
|
|
with gr.Row():
|
|
with gr.Column(scale=19):
|
|
with gr.Accordion("Prompt Translator",open=False):
|
|
with gr.Accordion("Help",open=False):
|
|
md = gr.Markdown("""
|
|
# Description
|
|
This script translates your prompt from another language to english before generating the image allowing you to write prompts in your native language.
|
|
# How to use
|
|
Select Enable translation and wait until you the list of languages show up.
|
|
Once the languages are shown, select the prompt language, write the prompt in the prompt field then press generate. The script will translate the prompt and generate the text.
|
|
# Note
|
|
First time you enable the script, it may take a long time (around a minute), but once loaded, it will be faster.
|
|
""")
|
|
with gr.Column():
|
|
self.enable_translation = gr.Checkbox(label="Enable translation")
|
|
self.enable_translation.value=False
|
|
self.language = gr.Dropdown(label="Source language", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
|
self.enable_translation.change(self.set_active,[self.enable_translation],[self.language])
|
|
|
|
self.language.visible=False
|
|
return [self.language]
|
|
|
|
def process(self, p, language, **kwargs):
|
|
print(f"{self.enable_translation.value}")
|
|
if hasattr(self, "translator") and self.is_active:
|
|
print(f"Translating to English from {language_options[language].label}")
|
|
print(f"Initial prompt:{p.prompt}")
|
|
ln_code = language_options[language].language_code
|
|
p.prompt = self.translator.translate(p.prompt, ln_code, "en_XX")
|
|
print(f"Translated prompt:{p.prompt}")
|
|
|