一波优化

pull/21/head
SpenserCai 2023-08-03 15:11:22 +08:00
parent a365308d3e
commit 2d616fb151
6 changed files with 10 additions and 8 deletions

4
app.py
View File

@ -3,7 +3,7 @@ Author: SpenserCai
Date: 2023-07-28 15:49:52 Date: 2023-07-28 15:49:52
version: version:
LastEditors: SpenserCai LastEditors: SpenserCai
LastEditTime: 2023-08-02 10:12:03 LastEditTime: 2023-08-03 15:04:13
Description: file content Description: file content
''' '''
from deoldify import device from deoldify import device
@ -30,7 +30,7 @@ def image_to_base64(image_path):
return image_b64 return image_b64
def ColorizeImage(base64str, render_factor=50, artistic=False): def ColorizeImage(base64str, render_factor=50, artistic=False):
vis = get_image_colorizer(render_factor=render_factor, artistic=artistic) vis = get_image_colorizer(root_folder=Path("models"),render_factor=render_factor, artistic=artistic)
# 把base64转换成图片 PIL.Image # 把base64转换成图片 PIL.Image
img = Image.open(BytesIO(base64.b64decode(base64str))) img = Image.open(BytesIO(base64.b64decode(base64str)))
print("loaded image") print("loaded image")

View File

@ -157,7 +157,7 @@ class Learner():
wd:Floats=defaults.wd wd:Floats=defaults.wd
train_bn:bool=True train_bn:bool=True
path:str = None path:str = None
model_dir:PathOrStr = 'models' model_dir:PathOrStr = 'deoldify'
callback_fns:Collection[Callable]=None callback_fns:Collection[Callable]=None
callbacks:Collection[Callback]=field(default_factory=list) callbacks:Collection[Callback]=field(default_factory=list)
layer_groups:Collection[nn.Module]=None layer_groups:Collection[nn.Module]=None

View File

@ -3,7 +3,7 @@ Author: SpenserCai
Date: 2023-07-28 14:37:09 Date: 2023-07-28 14:37:09
version: version:
LastEditors: SpenserCai LastEditors: SpenserCai
LastEditTime: 2023-08-03 14:31:04 LastEditTime: 2023-08-03 14:53:27
Description: file content Description: file content
''' '''
import os import os

View File

@ -3,7 +3,7 @@ Author: SpenserCai
Date: 2023-07-28 14:37:40 Date: 2023-07-28 14:37:40
version: version:
LastEditors: SpenserCai LastEditors: SpenserCai
LastEditTime: 2023-08-03 14:41:58 LastEditTime: 2023-08-03 15:06:53
Description: file content Description: file content
''' '''
# DeOldify API # DeOldify API
@ -30,7 +30,7 @@ def deoldify_api(_: gr.Blocks, app: FastAPI):
render_factor: int = Body(35,title="render factor"), render_factor: int = Body(35,title="render factor"),
artistic: bool = Body(False,title="artistic") artistic: bool = Body(False,title="artistic")
): ):
vis = get_image_colorizer(root_folder=Path("models/deoldify"),render_factor=render_factor, artistic=artistic) vis = get_image_colorizer(root_folder=Path("models"),render_factor=render_factor, artistic=artistic)
# 把base64转换成图片 PIL.Image # 把base64转换成图片 PIL.Image
img = Image.open(BytesIO(base64.b64decode(input_image))) img = Image.open(BytesIO(base64.b64decode(input_image)))
outImg = vis.get_transformed_image_from_image(img, render_factor=render_factor) outImg = vis.get_transformed_image_from_image(img, render_factor=render_factor)

View File

@ -3,7 +3,7 @@ Author: SpenserCai
Date: 2023-07-28 14:41:28 Date: 2023-07-28 14:41:28
version: version:
LastEditors: SpenserCai LastEditors: SpenserCai
LastEditTime: 2023-08-03 14:40:50 LastEditTime: 2023-08-03 15:11:07
Description: file content Description: file content
''' '''
# DeOldify UI & Processing # DeOldify UI & Processing
@ -44,13 +44,15 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
} }
def process_image(self, image, render_factor, artistic): def process_image(self, image, render_factor, artistic):
vis = get_image_colorizer(root_folder=Path("models/deoldify"),render_factor=render_factor, artistic=artistic) vis = get_image_colorizer(root_folder=Path("models"),render_factor=render_factor, artistic=artistic)
outImg = vis.get_transformed_image_from_image(image, render_factor=render_factor) outImg = vis.get_transformed_image_from_image(image, render_factor=render_factor)
return outImg return outImg
def process(self, pp: scripts_postprocessing.PostprocessedImage, is_enabled, render_factor, artistic): def process(self, pp: scripts_postprocessing.PostprocessedImage, is_enabled, render_factor, artistic):
if not is_enabled or is_enabled is False: if not is_enabled or is_enabled is False:
return return
print(type(pp.image))
pp.image = self.process_image(pp.image, render_factor, artistic) pp.image = self.process_image(pp.image, render_factor, artistic)
pp.info["deoldify"] = f"render_factor={render_factor}, artistic={artistic}" pp.info["deoldify"] = f"render_factor={render_factor}, artistic={artistic}"