parent
8a81c8baf0
commit
da0e8e8fe9
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import base64
|
||||
from enum import Enum
|
||||
import io
|
||||
import json
|
||||
from os import path
|
||||
|
|
@ -129,6 +130,17 @@ class State:
|
|||
}
|
||||
|
||||
|
||||
class JobStatus(Enum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
GENERATED = "generated"
|
||||
SUBMITTING = "submitting"
|
||||
UPLOADED = "uploaded"
|
||||
SUBMITTED = "submitted"
|
||||
DONE = "done"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class HordeJob:
|
||||
retry_interval: int = 1
|
||||
|
||||
|
|
@ -157,6 +169,8 @@ class HordeJob:
|
|||
source_mask: Optional[Image.Image] = None,
|
||||
r2_upload: Optional[str] = None,
|
||||
):
|
||||
self.status: JobStatus = JobStatus.PENDING
|
||||
self.session = session
|
||||
self.id = id
|
||||
self.model = model
|
||||
self.prompt = prompt
|
||||
|
|
@ -182,7 +196,9 @@ class HordeJob:
|
|||
self.source_mask = source_mask
|
||||
self.r2_upload = r2_upload
|
||||
|
||||
async def submit(self, image: Image.Image, session: aiohttp.ClientSession):
|
||||
async def submit(self, image: Image.Image):
|
||||
self.status = JobStatus.SUBMITTING
|
||||
|
||||
bytesio = io.BytesIO()
|
||||
image.save(bytesio, format="WebP", quality=95)
|
||||
|
||||
|
|
@ -199,6 +215,8 @@ class HordeJob:
|
|||
continue
|
||||
generation = "R2"
|
||||
|
||||
self.status = JobStatus.UPLOADED
|
||||
|
||||
else:
|
||||
generation = base64.b64encode(bytesio.getvalue()).decode("utf8")
|
||||
|
||||
|
|
@ -211,7 +229,7 @@ class HordeJob:
|
|||
attempts = 10
|
||||
while attempts > 0:
|
||||
try:
|
||||
r = await session.post("/api/v2/generate/submit", json=post_data)
|
||||
r = await self.session.post("/api/v2/generate/submit", json=post_data)
|
||||
|
||||
try:
|
||||
res = await r.json()
|
||||
|
|
@ -229,7 +247,11 @@ class HordeJob:
|
|||
continue
|
||||
|
||||
if r.ok:
|
||||
return res.get("reward", None)
|
||||
self.status = JobStatus.SUBMITTED
|
||||
reward = res.get("reward", None)
|
||||
if reward:
|
||||
self.status = JobStatus.DONE
|
||||
return reward
|
||||
else:
|
||||
print(
|
||||
"Failed to submit job with status code"
|
||||
|
|
@ -245,6 +267,8 @@ class HordeJob:
|
|||
await asyncio.sleep(self.retry_interval)
|
||||
continue
|
||||
|
||||
self.status = JobStatus.ERROR
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls,
|
||||
|
|
|
|||
Loading…
Reference in New Issue