feat: add job status (#42)

* feat: add job status

* fix lint
pull/52/head
Maiko Sinkyaet Tan 2023-01-28 00:13:20 +08:00 committed by GitHub
parent 8a81c8baf0
commit da0e8e8fe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 27 additions and 3 deletions

View File

@ -1,5 +1,6 @@
import asyncio
import base64
from enum import Enum
import io
import json
from os import path
@ -129,6 +130,17 @@ class State:
}
class JobStatus(Enum):
PENDING = "pending"
RUNNING = "running"
GENERATED = "generated"
SUBMITTING = "submitting"
UPLOADED = "uploaded"
SUBMITTED = "submitted"
DONE = "done"
ERROR = "error"
class HordeJob:
retry_interval: int = 1
@ -157,6 +169,8 @@ class HordeJob:
source_mask: Optional[Image.Image] = None,
r2_upload: Optional[str] = None,
):
self.status: JobStatus = JobStatus.PENDING
self.session = session
self.id = id
self.model = model
self.prompt = prompt
@ -182,7 +196,9 @@ class HordeJob:
self.source_mask = source_mask
self.r2_upload = r2_upload
async def submit(self, image: Image.Image, session: aiohttp.ClientSession):
async def submit(self, image: Image.Image):
self.status = JobStatus.SUBMITTING
bytesio = io.BytesIO()
image.save(bytesio, format="WebP", quality=95)
@ -199,6 +215,8 @@ class HordeJob:
continue
generation = "R2"
self.status = JobStatus.UPLOADED
else:
generation = base64.b64encode(bytesio.getvalue()).decode("utf8")
@ -211,7 +229,7 @@ class HordeJob:
attempts = 10
while attempts > 0:
try:
r = await session.post("/api/v2/generate/submit", json=post_data)
r = await self.session.post("/api/v2/generate/submit", json=post_data)
try:
res = await r.json()
@ -229,7 +247,11 @@ class HordeJob:
continue
if r.ok:
return res.get("reward", None)
self.status = JobStatus.SUBMITTED
reward = res.get("reward", None)
if reward:
self.status = JobStatus.DONE
return reward
else:
print(
"Failed to submit job with status code"
@ -245,6 +267,8 @@ class HordeJob:
await asyncio.sleep(self.retry_interval)
continue
self.status = JobStatus.ERROR
@classmethod
async def get(
cls,