class StreamingConvnext(ImageNetClassifier):
    """ImageNet classifier built on a ConvNeXt backbone, trained with streaming.

    The backbone's ``features`` stage is used as the streaming network, and
    ``avgpool`` followed by ``classifier`` forms the head. Streaming options
    are assembled by ``_get_streaming_options`` and forwarded as keyword
    arguments to ``ImageNetClassifier.__init__``.
    """

    # Supported backbone constructors, keyed by the accepted ``model_name``.
    model_choices = {"convnext_tiny": convnext_tiny, "convnext_small": convnext_small}

    def __init__(
        self,
        model_name: str,
        tile_size: int,
        loss_fn: torch.nn.functional,
        train_streaming_layers: bool = True,
        use_stochastic_depth: bool = False,
        metrics: MetricCollection | None = None,
        **kwargs,
    ):
        """Build the ConvNeXt backbone and initialise the streaming classifier.

        Args:
            model_name: Key into ``model_choices`` selecting the backbone.
            tile_size: Forwarded to ``ImageNetClassifier``.
            loss_fn: Loss function forwarded to ``ImageNetClassifier``.
                NOTE(review): annotated as the ``torch.nn.functional`` module;
                presumably a loss callable is intended — confirm.
            train_streaming_layers: Forwarded to ``ImageNetClassifier``.
            use_stochastic_depth: Passed as ``training`` to
                ``_toggle_stochastic_depth`` on the streamed module once
                initialisation has finished.
            metrics: Optional metric collection forwarded to ``ImageNetClassifier``.
            **kwargs: Streaming options; each key overrides the matching
                default from ``_get_streaming_options``.

        Raises:
            ValueError: If ``model_name`` is not a key of ``model_choices``.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``;
        # ``type(self)`` lets subclasses extend ``model_choices``.
        choices = type(self).model_choices
        if model_name not in choices:
            raise ValueError(
                f"Unknown model_name {model_name!r}; expected one of {sorted(choices)}"
            )
        self.model_name = model_name
        self.use_stochastic_depth = use_stochastic_depth

        # weights="DEFAULT" presumably selects torchvision's default pretrained weights.
        network = choices[model_name](weights="DEFAULT")
        stream_network = network.features
        head = torch.nn.Sequential(network.avgpool, network.classifier)

        self._get_streaming_options(**kwargs)
        super().__init__(
            stream_network,
            head,
            tile_size,
            loss_fn,
            train_streaming_layers=train_streaming_layers,
            metrics=metrics,
            **self.streaming_options,
        )
        # By default, the after_streaming_init callback turns sd off, so the
        # requested stochastic-depth setting is re-applied here.
        _toggle_stochastic_depth(self.stream_network.stream_module, training=self.use_stochastic_depth)

    def _get_streaming_options(self, **kwargs):
        """Set ``self.streaming_options`` to the defaults, overridden by ``kwargs``.

        A key present in ``kwargs`` replaces the default value outright; e.g.
        passing ``after_streaming_init_callbacks`` replaces the default
        callback list rather than extending it.
        """
        defaults = {
            "verbose": True,
            "copy_to_gpu": False,
            "statistics_on_cpu": True,
            "normalize_on_gpu": True,
            # Standard ImageNet per-channel normalisation constants.
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "before_streaming_init_callbacks": [_set_layer_scale],
            "after_streaming_init_callbacks": [_toggle_stochastic_depth],
        }
        self.streaming_options = {**defaults, **kwargs}