From 622827aab5e6de28591e86c50ca02d15f7875e19 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 11 Oct 2024 06:11:10 -0500 Subject: [PATCH] refactoring --- scripts/spartan/pmodels.py | 6 +++--- scripts/spartan/shared.py | 2 +- scripts/spartan/world.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/spartan/pmodels.py b/scripts/spartan/pmodels.py index 10a5e0a..48fe7a6 100644 --- a/scripts/spartan/pmodels.py +++ b/scripts/spartan/pmodels.py @@ -29,13 +29,13 @@ class Worker_Model(BaseModel): default=False ) state: Optional[Any] = Field(default=1, description="The last known state of this worker") - user: Optional[str] = Field(description="The username to be used when authenticating with this worker") - password: Optional[str] = Field(description="The password to be used when authenticating with this worker") + user: Optional[str] = Field(description="The username to be used when authenticating with this worker", default=None) + password: Optional[str] = Field(description="The password to be used when authenticating with this worker", default=None) pixel_cap: Optional[int] = Field(default=-1, description="Max amount of pixels to allow one worker to handle at the same time. -1 means there is no limit") class ConfigModel(BaseModel): workers: List[Dict[str, Worker_Model]] - benchmark_payload: Dict = Field( + benchmark_payload: Benchmark_Payload = Field( default=Benchmark_Payload, description='the payload used when benchmarking a node' ) diff --git a/scripts/spartan/shared.py b/scripts/spartan/shared.py index 5400464..ead9f65 100644 --- a/scripts/spartan/shared.py +++ b/scripts/spartan/shared.py @@ -65,7 +65,7 @@ samples = 3 # number of times to benchmark worker after warmup benchmarks are c class BenchmarkPayload(BaseModel): - validate_assignment = True + # validate_assignment = True prompt: str = Field(default="A herd of cows grazing at the bottom of a sunny valley") negative_prompt: str = Field(default="") steps: int = Field(default=20) diff --git a/scripts/spartan/world.py b/scripts/spartan/world.py index 9ff7ceb..3b432b9 100644 --- a/scripts/spartan/world.py +++ b/scripts/spartan/world.py @@ -674,7 +674,7 @@ class World: self.add_worker(**fields) - sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload) + sh.benchmark_payload = Benchmark_Payload(**config.benchmark_payload.dict()) self.job_timeout = config.job_timeout self.enabled = config.enabled self.enabled_i2i = config.enabled_i2i @@ -699,7 +699,7 @@ class World: ) with open(self.config_path, 'w+') as config_file: - config_file.write(config.json(indent=3)) + config_file.write(config.model_dump_json(indent=3)) logger.debug(f"config saved") def ping_remotes(self, indiscriminate: bool = False):