Skip to content

Convnext

Bases: ImageNetClassifier

Source code in `lightstream/models/convnext/convnext.py` (lines 22–70):
class StreamingConvnext(ImageNetClassifier):
    """Streaming ConvNeXt classifier built on top of ``ImageNetClassifier``.

    The selected backbone is instantiated with ``weights="DEFAULT"``. Its
    ``features`` module becomes the streaming network, and
    ``avgpool`` + ``classifier`` (wrapped in a ``torch.nn.Sequential``) form
    the head.
    """

    # Supported model names mapped to the constructors used to build them.
    model_choices = {"convnext_tiny": convnext_tiny, "convnext_small": convnext_small}

    def __init__(
        self,
        model_name: str,
        tile_size: int,
        loss_fn: torch.nn.Module,
        train_streaming_layers: bool = True,
        use_stochastic_depth: bool = False,
        metrics: MetricCollection | None = None,
        **kwargs,
    ):
        """Build the backbone, split it into stream/head, and initialise the parent.

        Args:
            model_name: Key of ``model_choices`` selecting the backbone.
            tile_size: Forwarded to ``ImageNetClassifier``.
            loss_fn: Forwarded to ``ImageNetClassifier``. NOTE(review): presumably a
                loss module such as ``torch.nn.CrossEntropyLoss()``; confirm
                against ``ImageNetClassifier``.
            train_streaming_layers: Forwarded to ``ImageNetClassifier``.
            use_stochastic_depth: After initialisation, passed as ``training=`` to
                ``_toggle_stochastic_depth`` on the streaming module.
            metrics: Forwarded to ``ImageNetClassifier``.
            **kwargs: Override the defaults built by ``_get_streaming_options``.
                The merged options are forwarded to ``ImageNetClassifier`` as
                keyword arguments.

        Raises:
            ValueError: If ``model_name`` is not a key of ``model_choices``.
        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O`` and would then surface as an opaque KeyError below.
        if model_name not in StreamingConvnext.model_choices:
            choices = ", ".join(StreamingConvnext.model_choices)
            raise ValueError(f"Unknown model_name {model_name!r}; expected one of: {choices}")

        self.model_name = model_name
        self.use_stochastic_depth = use_stochastic_depth

        network = StreamingConvnext.model_choices[model_name](weights="DEFAULT")
        stream_network = network.features
        head = torch.nn.Sequential(network.avgpool, network.classifier)
        self._get_streaming_options(**kwargs)

        super().__init__(
            stream_network,
            head,
            tile_size,
            loss_fn,
            train_streaming_layers=train_streaming_layers,
            # The trailing comma matters: without it the next line parses as
            # ``metrics ** self.streaming_options`` (power operator), which raises
            # TypeError and never forwards the streaming options.
            metrics=metrics,
            **self.streaming_options,
        )

        # By default, the after_streaming_init callback turns sd off;
        # re-apply the user's choice on the streaming module.
        _toggle_stochastic_depth(self.stream_network.stream_module, training=self.use_stochastic_depth)

    def _get_streaming_options(self, **kwargs):
        """Set ``self.streaming_options``: the defaults below, overridden by ``kwargs``.

        Overrides replace whole values. For example, passing
        ``before_streaming_init_callbacks`` replaces the default callback list
        rather than extending it.
        """
        streaming_options = {
            "verbose": True,
            "copy_to_gpu": False,
            "statistics_on_cpu": True,
            "normalize_on_gpu": True,
            # Per-channel normalisation stats (the standard ImageNet values).
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "before_streaming_init_callbacks": [_set_layer_scale],
            "after_streaming_init_callbacks": [_toggle_stochastic_depth],
        }
        self.streaming_options = {**streaming_options, **kwargs}