From 4cd4b43ee5d417967566e3527c6f3795d809a7f5 Mon Sep 17 00:00:00 2001
From: ParisNeo
Date: Tue, 28 Mar 2023 00:17:33 +0200
Subject: [PATCH] First version

---
 .vscode/settings.json |   3 +
 README.md             |  16 ++-
 scripts/main.py       | 247 ++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 265 insertions(+), 1 deletion(-)
 create mode 100644 .vscode/settings.json
 create mode 100644 scripts/main.py

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..5d7d730
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,3 @@
+{
+    "ros.distro": "noetic"
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 4ffaaad..c1e9e6e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,16 @@
 # prompt_translator
-A stable diffusion extension for translating prompts from 50 languages. The objective is to give users the possibility to use their own language to perform text prompting.
+Prompt_translator is an extension for Stable Diffusion Web UI (sd_webui), which adds an automatic translation tool to the Gradio UI. This tool allows users to generate images based on prompts written in 50 different languages.
+
+## Installation
+To install prompt_translator, clone the repository, or extract the zip file, into the extensions folder of the sd_webui host application.
+
+## Usage
+After installing prompt_translator, a new "Prompt Translator" section will be added to the Gradio UI. To use the automatic translation tool, tick the "Enable translation" checkbox to load the translation model. Once the model is loaded, a "Source language" dropdown will be displayed, where the user can select the language of their prompt.
+
+The user can then write their prompt in the desired language and press the "Generate" button to generate the image. The prompt will automatically be translated to English, and the resulting image will look as described in the text.
+
+## License
+This project is licensed under the MIT license.
+
+## Contributing
+Contributions to prompt_translator are welcome! If you find a bug or have an idea for a new feature, please create an issue on the project's GitHub page. If you'd like to contribute code, please fork the repository, make your changes, and submit a pull request.
\ No newline at end of file
diff --git a/scripts/main.py b/scripts/main.py
new file mode 100644
index 0000000..1e42daa
--- /dev/null
+++ b/scripts/main.py
@@ -0,0 +1,247 @@
+# MBartTranslator :
+# Author : ParisNeo
+# Description : This script translates Stable diffusion prompt from one of the 50 languages supported by MBART
+#               It uses MBartTranslator class that provides a simple interface for translating text using the MBart language model.
+
+import modules.scripts as scripts
+import gradio as gr
+from modules.shared import opts
+
+from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
+
+# Language code every prompt is translated to before image generation.
+TARGET_LANGUAGE = "en_XX"
+
+# Help text shown in the UI. Kept at column 0 on purpose: Markdown renders
+# lines indented by four or more spaces as a code block.
+HELP_TEXT = """
+# Description
+This script translates your prompt from another language to english before generating the image allowing you to write prompts in your native language.
+# How to use
+Select Enable translation and wait until the list of languages shows up.
+Once the languages are shown, select the prompt language, write the prompt in the prompt field then press generate. The script will translate the prompt and generate the image.
+# Note
+First time you enable the script, it may take a long time (around a minute), but once loaded, it will be faster.
+"""
+
+
+class MBartTranslator:
+    """Simple interface for translating text with an MBart-50 language model.
+
+    The default checkpoint is "facebook/mbart-large-50-many-to-many-mmt"; any
+    other MBart-50 checkpoint can be used by passing its name.
+
+    Attributes:
+        model (MBartForConditionalGeneration): The MBart language model.
+        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
+        supported_languages (list[str]): Language codes known to the tokenizer.
+    """
+
+    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
+        """Load the model and the tokenizer from the given checkpoint.
+
+        Args:
+            model_name (str): Name or path of the MBart-50 checkpoint to load.
+            src_lang (str): Optional default source language code for the tokenizer.
+            tgt_lang (str): Optional default target language code for the tokenizer.
+        """
+        print("Building translator")
+        print("Loading generator")
+        self.model = MBartForConditionalGeneration.from_pretrained(model_name)
+        print("Loading tokenizer")
+        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
+        # Ask the tokenizer which codes exist instead of hard-coding them: a
+        # hand-written list can drift from the checkpoint, and a code missing
+        # from lang_code_to_id fails with a KeyError during translation.
+        self.supported_languages = list(self.tokenizer.lang_code_to_id)
+        print("Translator ready")
+
+    def translate(self, text: str, input_language: str, output_language: str) -> str:
+        """Translate the given text from the input language to the output language.
+
+        Args:
+            text (str): The text to translate.
+            input_language (str): The input language code (e.g. "hi_IN" for Hindi).
+            output_language (str): The output language code (e.g. "en_XX" for English).
+
+        Returns:
+            str: The translated text. Blank input is returned unchanged.
+
+        Raises:
+            ValueError: If either language code is unknown to the tokenizer.
+        """
+        if input_language not in self.supported_languages:
+            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
+        if output_language not in self.supported_languages:
+            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
+        if not text.strip():
+            # Nothing to translate, so skip the model call.
+            return text
+
+        self.tokenizer.src_lang = input_language
+        encoded_input = self.tokenizer(text, return_tensors="pt")
+        generated_tokens = self.model.generate(
+            **encoded_input, forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language]
+        )
+        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+
+        return translated_text[0]
+
+
+class AxisOption:
+    """A source language choice: the label shown in the UI and its MBart-50 code."""
+
+    def __init__(self, label, language_code):
+        self.label = label
+        self.language_code = language_code
+
+
+# Languages offered in the dropdown. Each code must be one of the MBart-50
+# tokenizer language codes, and labels must be unique because the dropdown
+# reports the position of the selected label in this list.
+language_options = [
+    AxisOption("English", "en_XX"),
+    AxisOption("عربية", "ar_AR"),
+    AxisOption("Deutsch", "de_DE"),
+    AxisOption("Español", "es_XX"),
+    AxisOption("Français", "fr_XX"),
+    AxisOption("हिन्दी", "hi_IN"),
+    AxisOption("Italiano", "it_IT"),
+    AxisOption("日本語", "ja_XX"),
+    AxisOption("한국어", "ko_KR"),
+    AxisOption("Português", "pt_XX"),
+    AxisOption("Русский", "ru_RU"),
+    AxisOption("中文", "zh_CN"),
+    AxisOption("Afrikaans", "af_ZA"),
+    AxisOption("Azərbaycan", "az_AZ"),
+    AxisOption("বাংলা", "bn_IN"),
+    AxisOption("Čeština", "cs_CZ"),
+    AxisOption("Eesti", "et_EE"),
+    AxisOption("فارسی", "fa_IR"),
+    AxisOption("Suomi", "fi_FI"),
+    AxisOption("Galego", "gl_ES"),
+    AxisOption("ગુજરાતી", "gu_IN"),
+    AxisOption("עברית", "he_IL"),
+    AxisOption("Hrvatski", "hr_HR"),
+    AxisOption("Bahasa Indonesia", "id_ID"),
+    AxisOption("ქართული", "ka_GE"),
+    AxisOption("Қазақ", "kk_KZ"),
+    AxisOption("ខ្មែរ", "km_KH"),
+    AxisOption("Lietuvių", "lt_LT"),
+    AxisOption("Latviešu", "lv_LV"),
+    AxisOption("Македонски", "mk_MK"),
+    AxisOption("മലയാളം", "ml_IN"),
+    AxisOption("Монгол", "mn_MN"),
+    AxisOption("मराठी", "mr_IN"),
+    AxisOption("မြန်မာ", "my_MM"),
+    AxisOption("नेपाली", "ne_NP"),
+    AxisOption("Nederlands", "nl_XX"),
+    AxisOption("Polski", "pl_PL"),
+    AxisOption("پښتو", "ps_AF"),
+    AxisOption("Română", "ro_RO"),
+    AxisOption("සිංහල", "si_LK"),
+    AxisOption("Slovenščina", "sl_SI"),
+    AxisOption("Svenska", "sv_SE"),
+    AxisOption("Kiswahili", "sw_KE"),
+    AxisOption("தமிழ்", "ta_IN"),
+    AxisOption("తెలుగు", "te_IN"),
+    AxisOption("ไทย", "th_TH"),
+    AxisOption("Tagalog", "tl_XX"),
+    AxisOption("Türkçe", "tr_TR"),
+    AxisOption("Українська", "uk_UA"),
+    AxisOption("اردو", "ur_PK"),
+    AxisOption("Tiếng Việt", "vi_VN"),
+    AxisOption("isiXhosa", "xh_ZA"),
+]
+
+
+class Script(scripts.Script):
+    """Always-visible script that translates the prompt to English before generation.
+
+    The translation model is loaded lazily, the first time it is needed.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.is_active = False
+        # Created by ui(); None until the interface is built.
+        self.enable_translation = None
+        self.language = None
+        self.translator = None
+
+    def title(self):
+        """Return the name of the script."""
+        return "Translate prompt to english"
+
+    def show(self, is_img2img):
+        """Show the script in both the txt2img and the img2img tabs."""
+        return scripts.AlwaysVisible
+
+    def get_translator(self):
+        """Return the translator, loading the model on first use."""
+        if self.translator is None:
+            self.translator = MBartTranslator()
+        return self.translator
+
+    def set_active(self, active):
+        """Checkbox handler: load the model when enabling, show or hide the language list."""
+        self.is_active = bool(active)
+        if self.is_active:
+            self.get_translator()
+        return self.language.update(visible=self.is_active)
+
+    def ui(self, is_img2img):
+        """Build the accordion with the help text, the enable checkbox and the language list."""
+        self.is_active = False
+        labels = [option.label for option in language_options]
+
+        with gr.Row():
+            with gr.Column(scale=19):
+                with gr.Accordion("Prompt Translator", open=False):
+                    with gr.Accordion("Help", open=False):
+                        gr.Markdown(HELP_TEXT)
+                    with gr.Column():
+                        self.enable_translation = gr.Checkbox(label="Enable translation", value=False)
+                        self.language = gr.Dropdown(
+                            label="Source language",
+                            choices=labels,
+                            value=labels[1],
+                            type="index",
+                            visible=False,
+                            elem_id=self.elem_id("x_type"),
+                        )
+                        self.enable_translation.change(self.set_active, [self.enable_translation], [self.language])
+
+        # Both values are handed to before_process_batch, so translation follows
+        # the checkbox the user actually sees instead of state kept on the server.
+        return [self.enable_translation, self.language]
+
+    def before_process_batch(self, p, enabled, language, **kwargs):
+        """Translate ``p.prompt`` to English before a batch is generated.
+
+        Args:
+            p: The processing object whose ``prompt`` is rewritten in place.
+            enabled (bool): State of the "Enable translation" checkbox.
+            language (int): Index of the selected entry in ``language_options``.
+            **kwargs: Extra batch information passed by the caller; unused.
+        """
+        if not enabled or language is None:
+            return
+        # This hook is expected to run once per batch; translate only once per
+        # job so an already translated prompt is not fed to the model again.
+        if getattr(p, "prompt_translator_done", False):
+            return
+        p.prompt_translator_done = True
+
+        source = language_options[language]
+        if source.language_code == TARGET_LANGUAGE:
+            return  # Already English, nothing to do.
+
+        print(f"Translating to English from {source.label}")
+        print(f"Initial prompt:{p.prompt}")
+        translator = self.get_translator()
+        if isinstance(p.prompt, list):
+            p.prompt = [translator.translate(text, source.language_code, TARGET_LANGUAGE) for text in p.prompt]
+        else:
+            p.prompt = translator.translate(p.prompt, source.language_code, TARGET_LANGUAGE)
+        print(f"Translated prompt:{p.prompt}")