import logging
|
|
import torch
|
|
from annotator.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule
|
|
|
|
|
|
def get_training_model_class(kind):
    """Resolve a trainer-module class from its configuration name.

    Args:
        kind: Trainer identifier taken from the training config; only
            ``'default'`` is recognized here.

    Returns:
        The training-module class corresponding to ``kind``.

    Raises:
        ValueError: If ``kind`` does not name a known trainer module.
    """
    # Guard clause: reject unknown kinds up front, then return the only
    # supported implementation.
    if kind != 'default':
        raise ValueError(f'Unknown trainer module {kind}')
    return DefaultInpaintingTrainingModule
|
|
|
|
|
|
def make_training_model(config):
    """Build the training module described by ``config``.

    Reads ``config.training_model`` for the trainer ``kind`` plus its
    constructor keyword arguments, and derives ``use_ddp`` from the
    trainer's ``accelerator`` setting.

    Args:
        config: Experiment configuration object (attribute-style access;
            ``config.training_model`` must be convertible to a dict).

    Returns:
        An instance of the resolved training-module class, constructed
        with ``config`` and the remaining ``training_model`` kwargs.
    """
    kind = config.training_model.kind

    # Everything in the training_model section except 'kind' is forwarded
    # to the module constructor as keyword arguments.
    kwargs = {key: value
              for key, value in dict(config.training_model).items()
              if key != 'kind'}

    # DDP is enabled exactly when the trainer was configured with the
    # 'ddp' accelerator.
    accelerator = config.trainer.kwargs.get('accelerator', None)
    kwargs['use_ddp'] = accelerator == 'ddp'

    logging.info(f'Make training model {kind}')

    return get_training_model_class(kind)(config, **kwargs)
|
|
|
|
|
|
def load_checkpoint(train_config, path, map_location='cuda', strict=True):
    """Load a trained generator from a checkpoint file.

    Builds the training module described by ``train_config``, then loads
    the serialized state dict at ``path`` into its ``generator``.

    Args:
        train_config: Training configuration used to reconstruct the model.
        path: Filesystem path of the checkpoint to load.
        map_location: Device mapping passed to ``torch.load``
            (defaults to ``'cuda'``).
        strict: Passed through to ``load_state_dict``; when True, the
            checkpoint keys must match the model exactly.

    Returns:
        The generator module with the checkpoint weights loaded.
    """
    generator = make_training_model(train_config).generator
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # from trusted sources.
    checkpoint_state = torch.load(path, map_location=map_location)
    generator.load_state_dict(checkpoint_state, strict=strict)
    return generator
|