# NOTE: the exact import locations of ImageNetClassifier and split_resnet are
# assumptions; adjust them to match your project layout.
import torch
from torchmetrics import MetricCollection
from torchvision.models import resnet18, resnet34, resnet50

from lightstream.modules.imagenet_template import ImageNetClassifier
from lightstream.models.resnet.resnet import split_resnet


class StreamingResNet(ImageNetClassifier):
    # Minimal ResNet tile size, based on tile statistics calculations:
    # resnet18: 960
    model_choices = {"resnet18": resnet18, "resnet34": resnet34, "resnet50": resnet50}

    def __init__(
        self,
        model_name: str,
        tile_size: int,
        loss_fn: torch.nn.Module,
        train_streaming_layers: bool = True,
        metrics: MetricCollection | None = None,
        **kwargs,
    ):
        assert (
            model_name in StreamingResNet.model_choices
        ), f"Invalid model name {model_name!r}, expected one of {list(StreamingResNet.model_choices)}"
        network = StreamingResNet.model_choices[model_name](weights="DEFAULT")
        stream_network, head = split_resnet(network, num_classes=kwargs.pop("num_classes", 1000))
        self._get_streaming_options(**kwargs)
        super().__init__(
            stream_network,
            head,
            tile_size,
            loss_fn,
            train_streaming_layers=train_streaming_layers,
            metrics=metrics,
            **self.streaming_options,
        )

    def _get_streaming_options(self, **kwargs):
        """Set streaming defaults, but overwrite them with values from kwargs if present."""
        # torch.nn.BatchNorm2d must be added to the keep modules; omitting it
        # triggers an in-place ops error in the backward hooks, see:
        # https://discuss.pytorch.org/t/register-full-backward-hook-for-residual-connection/146850
        streaming_options = {
            "verbose": True,
            "copy_to_gpu": False,
            "statistics_on_cpu": True,
            "normalize_on_gpu": True,
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "add_keep_modules": [torch.nn.BatchNorm2d],
        }
        self.streaming_options = {**streaming_options, **kwargs}
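

# A minimal usage sketch, not part of the class: the tile_size, loss_fn, and
# num_classes values below are illustrative assumptions. It also demonstrates
# that extra kwargs (e.g. verbose) are forwarded into the streaming options.
if __name__ == "__main__":
    model = StreamingResNet(
        model_name="resnet18",
        tile_size=1920,  # must be at least the variant's minimal tile size (960 for resnet18)
        loss_fn=torch.nn.CrossEntropyLoss(),
        num_classes=2,  # popped in __init__ and passed to split_resnet
        verbose=False,  # overrides the streaming default via _get_streaming_options
    )
    print(model.streaming_options)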