the model is now saved locally
parent
0d43448b82
commit
eee6891281
|
|
@ -8,10 +8,11 @@ import gradio as gr
|
|||
from modules.shared import opts
|
||||
|
||||
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
|
||||
|
||||
import re
|
||||
import os
|
||||
|
||||
|
||||
# The directory to store the models
|
||||
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
|
||||
|
||||
class MBartTranslator:
|
||||
"""MBartTranslator class provides a simple interface for translating text using the MBart language model.
|
||||
|
|
@ -97,9 +98,9 @@ class MBartTranslator:
|
|||
]
|
||||
print("Building translator")
|
||||
print("Loading generator (this may take few minutes the first time as I need to download the model)")
|
||||
self.model = MBartForConditionalGeneration.from_pretrained(model_name)
|
||||
self.model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
|
||||
print("Loading tokenizer")
|
||||
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
|
||||
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang, cache_dir=cache_dir)
|
||||
print("Translator is ready")
|
||||
|
||||
def translate(self, text: str, input_language: str, output_language: str) -> str:
|
||||
|
|
@ -354,7 +355,8 @@ class Script(scripts.Script):
|
|||
self.is_active=active
|
||||
if not hasattr(self, "translator"):
|
||||
self.translator = MBartTranslator()
|
||||
return "ready"
|
||||
return "ready", self.translate_negative_prompt.update(visible=True), self.language.update(visible=True)
|
||||
|
||||
|
||||
def ui(self, is_img2img):
|
||||
"""Sets up the user interface of the script."""
|
||||
|
|
@ -386,8 +388,14 @@ class Script(scripts.Script):
|
|||
elem_id=self.elem_id("x_type")
|
||||
)
|
||||
self.output=gr.Label("After enabling translation, please Wait until I am ready")
|
||||
self.enable_translation.change(self.set_active,[self.enable_translation], [self.output], show_progress=True)
|
||||
|
||||
self.enable_translation.change(
|
||||
self.set_active,
|
||||
[self.enable_translation],
|
||||
[self.translate_negative_prompt, self.language,self.output],
|
||||
show_progress=True
|
||||
)
|
||||
self.translate_negative_prompt.visible=False
|
||||
self.language.visible=False
|
||||
return [self.language]
|
||||
|
||||
def get_prompts(self, p):
|
||||
|
|
|
|||
Loading…
Reference in New Issue