the model is now saved locally

pull/10/head
ParisNeo 2023-03-31 22:46:27 +02:00
parent 0d43448b82
commit eee6891281
1 changed files with 15 additions and 7 deletions

View File

@ -8,10 +8,11 @@ import gradio as gr
from modules.shared import opts
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
import re
import os
# The directory to store the models
cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models')
class MBartTranslator:
"""MBartTranslator class provides a simple interface for translating text using the MBart language model.
@ -97,9 +98,9 @@ class MBartTranslator:
]
print("Building translator")
print("Loading generator (this may take few minutes the first time as I need to download the model)")
self.model = MBartForConditionalGeneration.from_pretrained(model_name)
self.model = MBartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir)
print("Loading tokenizer")
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang, cache_dir=cache_dir)
print("Translator is ready")
def translate(self, text: str, input_language: str, output_language: str) -> str:
@ -354,7 +355,8 @@ class Script(scripts.Script):
self.is_active=active
if not hasattr(self, "translator"):
self.translator = MBartTranslator()
return "ready"
return "ready", self.translate_negative_prompt.update(visible=True), self.language.update(visible=True)
def ui(self, is_img2img):
"""Sets up the user interface of the script."""
@ -386,8 +388,14 @@ class Script(scripts.Script):
elem_id=self.elem_id("x_type")
)
self.output=gr.Label("After enabling translation, please Wait until I am ready")
self.enable_translation.change(self.set_active,[self.enable_translation], [self.output], show_progress=True)
self.enable_translation.change(
self.set_active,
[self.enable_translation],
[self.translate_negative_prompt, self.language,self.output],
show_progress=True
)
self.translate_negative_prompt.visible=False
self.language.visible=False
return [self.language]
def get_prompts(self, p):