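"""Adversarial losses for GAN-based inpainting training.

BaseAdversarialLoss defines the interface; NonSaturatingWithR1 implements
the non-saturating GAN loss with an R1 gradient penalty, BCELoss trains the
discriminator to predict the inpainting mask, and make_discrim_loss is a
small factory over both.
"""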
from typing import Tuple, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseAdversarialLoss:
    def pre_generator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           generator: nn.Module, discriminator: nn.Module):
        """
        Prepare for generator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param generator: the generator module
        :param discriminator: the discriminator module
        :return: None
        """

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        """
        Prepare for discriminator step
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param generator: the generator module
        :param discriminator: the discriminator module
        :return: None
        """

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate generator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator to produce fake_batch
        :return: total generator loss along with some values that might be interesting to log
        """
        raise NotImplementedError()

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask: Optional[torch.Tensor] = None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Calculate discriminator loss
        :param real_batch: Tensor, a batch of real samples
        :param fake_batch: Tensor, a batch of samples produced by generator
        :param discr_real_pred: Tensor, discriminator output for real_batch
        :param discr_fake_pred: Tensor, discriminator output for fake_batch
        :param mask: Tensor, the mask that was fed to the generator to produce fake_batch
        :return: total discriminator loss along with some values that might be interesting to log
        """
        raise NotImplementedError()

    def interpolate_mask(self, mask, shape):
        # Resize the mask to match the spatial shape of the discriminator output,
        # if scaling is allowed; otherwise the shapes must already match.
        assert mask is not None
        assert self.allow_scale_mask or shape == mask.shape[-2:]
        if shape != mask.shape[-2:] and self.allow_scale_mask:
            if self.mask_scale_mode == 'maxpool':
                mask = F.adaptive_max_pool2d(mask, shape)
            else:
                mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode)
        return mask


def make_r1_gp(discr_real_pred, real_batch):
    # R1 gradient penalty: the squared L2 norm of the discriminator's gradient
    # w.r.t. the real inputs, averaged over the batch.
    if torch.is_grad_enabled():
        grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean()
    else:
        grad_penalty = 0
    real_batch.requires_grad = False

    return grad_penalty


class NonSaturatingWithR1(BaseAdversarialLoss):
    def __init__(self, gp_coef=5, weight=1, mask_as_fake_target=False, allow_scale_mask=False,
                 mask_scale_mode='nearest', extra_mask_weight_for_gen=0,
                 use_unmasked_for_gen=True, use_unmasked_for_discr=True):
        self.gp_coef = gp_coef
        self.weight = weight
        # use_unmasked_for_discr implies use_unmasked_for_gen;
        # otherwise we would teach only the discriminator to pay attention to very small differences
        assert use_unmasked_for_gen or (not use_unmasked_for_discr)
        # mask_as_fake_target implies use_unmasked_for_discr:
        # if we don't care about unmasked regions at all,
        # then it doesn't matter whether mask_as_fake_target is true or false
        assert use_unmasked_for_discr or (not mask_as_fake_target)
        self.use_unmasked_for_gen = use_unmasked_for_gen
        self.use_unmasked_for_discr = use_unmasked_for_discr
        self.mask_as_fake_target = mask_as_fake_target
        self.allow_scale_mask = allow_scale_mask
        self.mask_scale_mode = mask_scale_mode
        self.extra_mask_weight_for_gen = extra_mask_weight_for_gen

    def generator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                       discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                       mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        fake_loss = F.softplus(-discr_fake_pred)
        if (self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0) or \
                not self.use_unmasked_for_gen:  # == if masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            if not self.use_unmasked_for_gen:
                fake_loss = fake_loss * mask
            else:
                pixel_weights = 1 + mask * self.extra_mask_weight_for_gen
                fake_loss = fake_loss * pixel_weights

        return fake_loss.mean() * self.weight, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        # Enable gradients w.r.t. the real batch so the R1 penalty can be computed.
        real_batch.requires_grad = True

    def discriminator_loss(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                           discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor,
                           mask=None) \
            -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_loss = F.softplus(-discr_real_pred)
        grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef
        fake_loss = F.softplus(discr_fake_pred)

        if not self.use_unmasked_for_discr or self.mask_as_fake_target:
            # == if masked region should be treated differently
            mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:])
            # use_unmasked_for_discr=False only makes sense for fakes;
            # for reals there is no difference between the two regions
            fake_loss = fake_loss * mask
            if self.mask_as_fake_target:
                fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred)

        sum_discr_loss = real_loss + grad_penalty + fake_loss
        metrics = dict(discr_real_out=discr_real_pred.mean(),
                       discr_fake_out=discr_fake_pred.mean(),
                       discr_real_gp=grad_penalty)
        return sum_discr_loss.mean(), metrics


class BCELoss(BaseAdversarialLoss):
    def __init__(self, weight):
        self.weight = weight
        self.bce_loss = nn.BCEWithLogitsLoss()

    def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # The generator wants the discriminator to predict "no fake regions",
        # i.e. an all-zeros mask, for its outputs.
        real_mask_gt = torch.zeros_like(discr_fake_pred)
        fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight
        return fake_loss, dict()

    def pre_discriminator_step(self, real_batch: torch.Tensor, fake_batch: torch.Tensor,
                               generator: nn.Module, discriminator: nn.Module):
        real_batch.requires_grad = True

    def discriminator_loss(self,
                           mask: torch.Tensor,
                           discr_real_pred: torch.Tensor,
                           discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        real_mask_gt = torch.zeros_like(discr_real_pred)
        sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2
        metrics = dict(discr_real_out=discr_real_pred.mean(),
                       discr_fake_out=discr_fake_pred.mean(),
                       discr_real_gp=0)
        return sum_discr_loss, metrics


def make_discrim_loss(kind, **kwargs):
    if kind == 'r1':
        return NonSaturatingWithR1(**kwargs)
    elif kind == 'bce':
        return BCELoss(**kwargs)
    raise ValueError(f'Unknown adversarial loss kind {kind}')
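

# --- Minimal usage sketch (illustrative addition, not part of the original API) ---
# Assumes a patch-style discriminator mapping (N, 3, H, W) images to an
# (N, 1, h, w) logit map; the single conv below is a hypothetical stand-in.
if __name__ == '__main__':
    disc = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)
    gen = nn.Identity()  # placeholder; the loss methods never call it here
    adv_loss = make_discrim_loss('r1', gp_coef=5, weight=1)

    real = torch.rand(2, 3, 64, 64)
    fake = torch.rand(2, 3, 64, 64)

    # Discriminator step: enable grads on reals (required by the R1 penalty),
    # then score the real batch and the detached fake batch.
    adv_loss.pre_discriminator_step(real, fake, generator=gen, discriminator=disc)
    d_loss, d_metrics = adv_loss.discriminator_loss(real, fake, disc(real), disc(fake.detach()))

    # Generator step: non-saturating loss on the fake predictions.
    g_loss, g_metrics = adv_loss.generator_loss(real, fake, disc(real), disc(fake))
    print(f'discriminator loss {d_loss.item():.4f}, generator loss {g_loss.item():.4f}')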