Skip to content

StreamingModule

Bases: LightningModule

Source code in lightstream\modules\streaming.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
class StreamingModule(L.LightningModule):
    """Base LightningModule that builds and drives a streaming network.

    The constructor wraps ``stream_network`` with ``StreamingConstructor``, optionally reusing a tile cache
    stored on disk, and exposes helpers for the streaming forward and backward passes.

    Parameters
    ----------
    stream_network :
        The network handed to ``StreamingConstructor``.
    tile_size :
        Tile size handed to ``StreamingConstructor``; also used to derive the default tile cache filename.
    train_streaming_layers : bool
        When False, the parameters of the streaming module get ``requires_grad = False`` and
        ``backward_streaming`` becomes a no-op.
    **kwargs
        ``tile_cache_dir`` (default: current working directory) and ``tile_cache_fname`` (default: derived
        from the tile size) are consumed here. All remaining keyword arguments are forwarded to
        ``StreamingConstructor``.
    """

    def __init__(self, stream_network, tile_size, train_streaming_layers=True, **kwargs):
        super().__init__()
        self.train_streaming_layers = train_streaming_layers

        # StreamingCNN options
        self._tile_size = tile_size
        self.tile_cache_dir = kwargs.pop("tile_cache_dir", Path.cwd())
        self.tile_cache_fname = kwargs.pop("tile_cache_fname", None)

        # Load the tile cache state dict if present
        tile_cache = self.load_tile_cache_if_needed()

        # Initialize the streaming network; whatever is left in kwargs goes to the constructor untouched
        self._constructor_opts = kwargs
        self.constructor = StreamingConstructor(
            stream_network, self.tile_size, tile_cache=tile_cache, **self._constructor_opts
        )
        self.copy_to_gpu = self.constructor.copy_to_gpu
        self.stream_network = self.constructor.prepare_streaming_model()

        self.save_tile_cache_if_needed()
        self.params = self.get_trainable_params()

    @property
    def tile_size(self):
        """The tile size this module was configured with."""
        return self._tile_size

    @tile_size.setter
    def tile_size(self, new_tile_size):
        # Only the stored value changes; the streaming network is not rebuilt here.
        self._tile_size = new_tile_size

    def freeze_streaming_normalization_layers(self):
        """Put every BatchNorm2d and LayerNorm layer of the streaming module into eval mode.

        Do not use normalization layers within lightstream, only local ops are allowed.
        """
        norm_layers = [
            layer
            for layer in self.stream_network.stream_module.modules()
            if isinstance(layer, (torch.nn.BatchNorm2d, torch.nn.LayerNorm))
        ]

        for layer in norm_layers:
            layer.eval()

    def on_train_epoch_start(self) -> None:
        """on_train_epoch_start hook

        Freezes the normalization layers of the streaming module at the start of every training epoch.

        Do not replace this method outright. If you need this hook in your own pipeline, override it and call
        the parent implementation using super().on_train_epoch_start().

        """
        self.freeze_streaming_normalization_layers()

    def prepare_start_for_streaming(self):
        """Align device, normalization tensors and dtype of the streaming network with this module.

        The dtype follows ``self.trainer.precision``: float16 for ``"16-mixed"``, bfloat16 for
        ``"bf16-mixed"``, and ``self.dtype`` for every other setting.
        """

        # Update streaming to put all the inputs/tensors on the right device
        self.stream_network.device = self.device
        self.stream_network.mean = self.stream_network.mean.to(self.device, non_blocking=True)
        self.stream_network.std = self.stream_network.std.to(self.device, non_blocking=True)

        precision = self.trainer.precision
        if precision == "16-mixed":
            self.stream_network.dtype = torch.float16
        elif precision == "bf16-mixed":
            # bf16 mixed precision works with bfloat16 tensors, not float16
            self.stream_network.dtype = torch.bfloat16
        else:
            self.stream_network.dtype = self.dtype

    def on_validation_start(self):
        """on_validation_start hook

        Do not replace this method outright. If you need this hook in your own pipeline, override it and call
        the parent implementation using super().on_validation_start().

        """
        self.prepare_start_for_streaming()

    def on_train_start(self):
        """on_train_start hook

        Do not replace this method outright. If you need this hook in your own pipeline, override it and call
        the parent implementation using super().on_train_start().

        """
        self.prepare_start_for_streaming()

    def on_test_start(self):
        """on_test_start hook

        Do not replace this method outright. If you need this hook in your own pipeline, override it and call
        the parent implementation using super().on_test_start().

        """
        self.prepare_start_for_streaming()

    def on_predict_start(self):
        """on_predict_start hook

        Do not replace this method outright. If you need this hook in your own pipeline, override it and call
        the parent implementation using super().on_predict_start().

        """
        self.prepare_start_for_streaming()

    def disable_streaming_hooks(self):
        """Disable streaming hooks and replace streamingconv2d with conv2d modules

        This will still use the StreamingCNN backward and forward functions, but the memory gains from gradient
        checkpointing will be turned off.
        """
        self.stream_network.disable()

    def enable_streaming_hooks(self):
        """Enable streaming hooks and use streamingconv2d modules"""
        self.stream_network.enable()

    def forward_streaming(self, x):
        """Run the forward pass of the streaming network.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor in [1,C,H,W] format

        Returns
        -------
        out: torch.Tensor
            The output of the streaming model

        """
        return self.stream_network.forward(x)

    def backward_streaming(self, image, gradient):
        """Perform the backward pass using the streaming network

        Backward only if ``train_streaming_layers`` is turned on.
        This method is primarily a convenience function

        Parameters
        ----------
        image: torch.Tensor
            The input image in [1,C,H,W] format
        gradient: torch.Tensor
            The gradient of the next layer in the model to continue backpropagation with

        Returns
        -------
        None

        """

        # If requires_grad is set to false, .backward() in streaming causes errors or overhead, so use a bool
        if self.train_streaming_layers:
            self.stream_network.backward(image, gradient)

    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def backward(self, loss: Tensor, *args: Any, **kwargs: Any) -> None:
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def configure_tile_stride(self):
        """
        Helper function that returns the tile stride during streaming.

        Streaming assumes that the input image is perfectly divisible by the network output stride or the
        tile stride. This function will return the tile stride, which can then be used within data processing
        pipelines to pad/crop images to a multiple of the tile stride.

        The stride is the tile size minus the gradient lost on the left and right of a tile, rounded down to a
        multiple of the last entry of the network output stride.

        Returns
        -------
        tile_stride: numpy.ndarray
            the tile stride.

        """
        stride = self.tile_size - (
            self.stream_network.tile_gradient_lost.left + self.stream_network.tile_gradient_lost.right
        )
        # Round down to a whole multiple of the output stride
        stride = stride // self.stream_network.output_stride[-1]
        stride *= self.stream_network.output_stride[-1]
        return stride.detach().cpu().numpy()

    def get_trainable_params(self):
        """Get trainable parameters of the streaming network

        If self.train_streaming_layers is True, the parameters of the streaming module are returned.
        Otherwise they are frozen (``requires_grad = False``) and an empty list is returned.

        Returns
        -------
        params : list
            The parameters of the streaming module, or an empty list when it is not trained.

        """

        if self.train_streaming_layers:
            return list(self.stream_network.stream_module.parameters())

        print("WARNING: Streaming network will not be trained")
        for param in self.stream_network.stream_module.parameters():
            param.requires_grad = False
        # Return a list on both paths so callers can always iterate or concatenate the result
        return []

    def _remove_streaming_network(self):
        """Converts the streaming network into a non-streaming network

        The former streaming encoder can be addressed as self.stream_network
        This function is currently untested and breaks the class, since there is no way to rebuild the streaming
        network other than calling a new class directly.

        """

        # Convert streamingConv2D into regular Conv2D and turn off streaming hooks
        self.disable_streaming_hooks()
        temp = self.stream_network.stream_module

        # torch modules cannot be overridden normally, so delete and reassign
        del self.stream_network
        self.stream_network = temp

    def _tile_cache_path(self):
        """Return the tile cache file path, filling in the default filename when none was set."""
        if self.tile_cache_fname is None:
            self.tile_cache_fname = "tile_cache_" + "1_3_" + str(self.tile_size) + "_" + str(self.tile_size)
        return Path(self.tile_cache_dir) / Path(self.tile_cache_fname)

    def save_tile_cache_if_needed(self, overwrite=False):
        """
        Writes the tile cache to a file, so it does not have to be recomputed

        The tile cache is normally calculated for each run.
        However, this can take a long time. By writing it to a file it can be reloaded without the need
        for recomputation. Only the process with global rank 0 writes the file.

        Limitations:
        This only works for the exact same model and for a single tile size. If the streaming part of the model
        changes, or if the tile size is changed, it will no longer work.

        Parameters
        ----------
        overwrite : bool
            Whether an existing tile cache file may be replaced

        Raises
        ------
        NotADirectoryError
            If ``self.tile_cache_dir`` does not exist

        """
        write_path = self._tile_cache_path()

        if Path(self.tile_cache_dir).exists():
            if write_path.exists() and not overwrite:
                print("previous tile cache found and overwrite is false, not saving")

            elif self.global_rank == 0:
                print(f"writing streaming cache file to {str(write_path)}")
                torch.save(self.stream_network.get_tile_cache(), str(write_path))

            else:
                print("")
        else:
            raise NotADirectoryError(f"Did not find {self.tile_cache_dir} or does not exist")

    def load_tile_cache_if_needed(self, use_tile_cache: bool = True):
        """
        Load the tile cache for the model from ``self.tile_cache_dir``

        Parameters
        ----------
        use_tile_cache : bool
            Whether to use the tile cache file and load it into the streaming module

        Returns
        ---------
        state_dict : torch.state_dict | None
            The state dict if present
        """
        tile_cache_loc = self._tile_cache_path()

        if tile_cache_loc.exists() and use_tile_cache:
            print("Loading tile cache from", tile_cache_loc)
            # NOTE(review): torch.load unpickles the file; only point tile_cache_dir at trusted locations.
            # The map_location callable returns the storage unchanged instead of moving it to the saved device.
            state_dict = torch.load(str(tile_cache_loc), map_location=lambda storage, loc: storage)
        else:
            print("No tile cache found, calculating it now")
            state_dict = None

        return state_dict

backward_streaming(image, gradient)

Perform the backward pass using the streaming network

Backward only if streaming is turned on. This method is primarily a convenience function

Parameters:

Name Type Description Default
image

The input image in [1,C,H,W] format

required
gradient

The gradient of the next layer in the model to continue backpropagation with

required
Source code in lightstream\modules\streaming.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def backward_streaming(self, image, gradient):
    """Run the streaming backward pass, but only when the streaming layers are trained.

    Thin convenience wrapper around ``self.stream_network.backward``.

    Parameters
    ----------
    image: torch.Tensor
        The input image in [1,C,H,W] format
    gradient: torch.Tensor
        The gradient coming from the next layer, used to continue backpropagation

    Returns
    -------
    None

    """
    # With requires_grad switched off, .backward() in streaming errors out or adds overhead, so bail out early
    if not self.train_streaming_layers:
        return

    network = self.stream_network
    network.backward(image, gradient)

configure_tile_stride()

Helper function that returns the tile stride during streaming.

Streaming assumes that the input image is perfectly divisible by the network output stride or the tile stride. This function returns the tile stride, which can then be used within data processing pipelines to pad/crop images to a multiple of the tile stride.

Examples:

Returns:

Name Type Description
tile_stride ndarray

the tile stride.

Source code in lightstream\modules\streaming.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def configure_tile_stride(self):
    """
    Compute the tile stride used during streaming.

    Streaming expects the input image to be an exact multiple of the network output stride or of the tile
    stride. The value returned here can be used in data pipelines to pad or crop images accordingly.

    The stride is the tile size minus the gradient lost at the left and right of a tile, rounded down to a
    multiple of the last entry of the network output stride.

    Returns
    -------
    tile_stride: numpy.ndarray
        the tile stride.

    """
    network = self.stream_network
    lost = network.tile_gradient_lost.left + network.tile_gradient_lost.right
    step = network.output_stride[-1]

    usable = self.tile_size - lost
    # Floor to a whole number of output-stride steps, then scale back up
    usable = usable // step
    usable *= step
    return usable.detach().cpu().numpy()

disable_streaming_hooks()

Disable streaming hooks and replace streamingconv2d with conv2d modules

This will still use the StreamingCNN backward and forward functions, but the memory gains from gradient checkpointing will be turned off.

Source code in lightstream\modules\streaming.py
119
120
121
122
123
124
125
def disable_streaming_hooks(self):
    """Turn the streaming hooks off, swapping streamingconv2d for conv2d modules.

    The StreamingCNN forward and backward functions remain in use, but without the memory savings of
    gradient checkpointing.
    """
    network = self.stream_network
    network.disable()

enable_streaming_hooks()

Enable streaming hooks and use streamingconv2d modules

Source code in lightstream\modules\streaming.py
127
128
129
def enable_streaming_hooks(self):
    """Turn the streaming hooks on so that streamingconv2d modules are used."""
    network = self.stream_network
    network.enable()

forward_streaming(x)

Parameters:

Name Type Description Default
x Tensor

The input tensor in [1,C,H,W] format

required

Returns:

Name Type Description
out Tensor
The output of the streaming model
Source code in lightstream\modules\streaming.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def forward_streaming(self, x):
    """Run the forward pass of the streaming network.

    Parameters
    ----------
    x : torch.Tensor
        The input tensor in [1,C,H,W] format

    Returns
    -------
    out: torch.Tensor
        The output of the streaming model

    """
    network = self.stream_network
    out = network.forward(x)
    return out

freeze_streaming_normalization_layers()

Do not use normalization layers within lightstream, only local ops are allowed

Source code in lightstream\modules\streaming.py
47
48
49
50
51
52
53
54
55
56
def freeze_streaming_normalization_layers(self):
    """Switch all BatchNorm2d and LayerNorm layers of the streaming module to eval mode.

    Normalization layers should not be used within lightstream; only local ops are allowed.
    """
    norm_types = (torch.nn.BatchNorm2d, torch.nn.LayerNorm)
    every_module = self.stream_network.stream_module.modules()

    # Collect first, then switch modes, so the module tree is fully walked before any state changes
    to_freeze = list(filter(lambda module: isinstance(module, norm_types), every_module))
    for module in to_freeze:
        module.eval()

get_trainable_params()

Get trainable parameters for the entire model

If self.train_streaming_layers is True, then the parameters of the streaming network will be trained. Otherwise, the parameters will be left untrained (no gradients will be collected).

Source code in lightstream\modules\streaming.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def get_trainable_params(self):
    """Get trainable parameters of the streaming network

    If self.train_streaming_layers is True, the parameters of the streaming module are returned.
    Otherwise they are frozen (``requires_grad = False``) and an empty list is returned.

    Returns
    -------
    params : list
        The parameters of the streaming module, or an empty list when it is not trained.

    """

    if self.train_streaming_layers:
        return list(self.stream_network.stream_module.parameters())

    print("WARNING: Streaming network will not be trained")
    for param in self.stream_network.stream_module.parameters():
        param.requires_grad = False
    # Return a list on both paths so callers can always iterate or concatenate the result
    return []

load_tile_cache_if_needed(use_tile_cache=True)

Load the tile cache for the model from the read_dir

Parameters:

Name Type Description Default
use_tile_cache bool

Whether to use the tile cache file and load it into the streaming module

True

Returns:

Name Type Description
state_dict state_dict | None

The state dict if present

Source code in lightstream\modules\streaming.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
def load_tile_cache_if_needed(self, use_tile_cache: bool = True):
    """
    Read the tile cache file from ``self.tile_cache_dir`` when it exists and its use is requested.

    When no filename has been set, a default one derived from the tile size is stored on the instance first.

    Parameters
    ----------
    use_tile_cache : bool
        Whether to use the tile cache file and load it into the streaming module

    Returns
    ---------
    state_dict : torch.state_dict | None
        The loaded state dict, or None when nothing was loaded
    """
    if self.tile_cache_fname is None:
        size_tag = str(self.tile_size)
        self.tile_cache_fname = "_".join(["tile_cache", "1", "3", size_tag, size_tag])

    cache_file = Path(self.tile_cache_dir) / Path(self.tile_cache_fname)

    if not (cache_file.exists() and use_tile_cache):
        print("No tile cache found, calculating it now")
        return None

    print("Loading tile cache from", cache_file)
    # The map_location callable hands back the storage as-is instead of restoring the saved device
    return torch.load(str(cache_file), map_location=lambda storage, loc: storage)

on_predict_start()

on_predict_start hook

Do not replace this method outright. If you need this hook in your own pipeline, override it and call the parent implementation using super().on_predict_start().

Source code in lightstream\modules\streaming.py
109
110
111
112
113
114
115
116
def on_predict_start(self):
    """Prepare the streaming network when prediction starts.

    Do not replace this hook outright. If your pipeline needs it, override it and call the parent
    implementation with super().on_predict_start().

    """
    prepare = self.prepare_start_for_streaming
    prepare()

on_test_start()

on_test_start hook

Do not replace this method outright. If you need this hook in your own pipeline, override it and call the parent implementation using super().on_test_start().

Source code in lightstream\modules\streaming.py
 99
100
101
102
103
104
105
106
def on_test_start(self):
    """Prepare the streaming network when testing starts.

    Do not replace this hook outright. If your pipeline needs it, override it and call the parent
    implementation with super().on_test_start().

    """
    prepare = self.prepare_start_for_streaming
    prepare()

on_train_epoch_start()

on_train_epoch_start hook

Do not replace this method outright. If you need this hook in your own pipeline, override it and call the parent implementation using super().on_train_epoch_start().

Source code in lightstream\modules\streaming.py
58
59
60
61
62
63
64
65
def on_train_epoch_start(self) -> None:
    """Freeze the streaming normalization layers at the start of each training epoch.

    Do not replace this hook outright. If your pipeline needs it, override it and call the parent
    implementation with super().on_train_epoch_start().

    """
    freeze = self.freeze_streaming_normalization_layers
    freeze()

on_train_start()

on_train_start hook

Do not override this method. Instead, call the parent class using super().on_train_start if you want to add this hook into your pipelines

Source code in lightstream\modules\streaming.py
89
90
91
92
93
94
95
96
def on_train_start(self):
    """Prepare the streaming network when training starts.

    Do not replace this hook outright. If your pipeline needs it, override it and call the parent
    implementation with super().on_train_start().

    """
    prepare = self.prepare_start_for_streaming
    prepare()

on_validation_start()

on_validation_start hook

Do not replace this method outright. If you need this hook in your own pipeline, override it and call the parent implementation using super().on_validation_start().

Source code in lightstream\modules\streaming.py
80
81
82
83
84
85
86
87
def on_validation_start(self):
    """Prepare the streaming network when validation starts.

    Do not replace this hook outright. If your pipeline needs it, override it and call the parent
    implementation with super().on_validation_start().

    """
    prepare = self.prepare_start_for_streaming
    prepare()

save_tile_cache_if_needed(overwrite=False)

Writes the tile cache to a file, so it does not have to be recomputed

The tile cache is normally calculated for each run. However, this can take a long time. By writing it to a file it can be reloaded without the need for recomputation.

Limitations: This only works for the exact same model and for a single tile size. If the streaming part of the model changes, or if the tile size is changed, it will no longer work.

Source code in lightstream\modules\streaming.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def save_tile_cache_if_needed(self, overwrite=False):
    """
    Store the tile cache on disk so later runs can skip recomputing it.

    Computing the tile cache can take a long time; a saved copy can simply be reloaded. An existing file is
    kept unless ``overwrite`` is set, and only the process with global rank 0 writes.

    Limitations:
    The saved cache is only valid for the exact same model and a single tile size. It no longer works once
    the streaming part of the model or the tile size changes.

    Parameters
    ----------
    overwrite : bool
        Whether an existing tile cache file may be replaced

    Raises
    ------
    NotADirectoryError
        If ``self.tile_cache_dir`` does not exist

    """
    if self.tile_cache_fname is None:
        size_tag = str(self.tile_size)
        self.tile_cache_fname = "_".join(["tile_cache", "1", "3", size_tag, size_tag])

    cache_dir = Path(self.tile_cache_dir)
    target = cache_dir / Path(self.tile_cache_fname)

    if not cache_dir.exists():
        raise NotADirectoryError(f"Did not find {self.tile_cache_dir} or does not exist")

    if target.exists() and not overwrite:
        print("previous tile cache found and overwrite is false, not saving")
        return

    if self.global_rank == 0:
        print(f"writing streaming cache file to {str(target)}")
        torch.save(self.stream_network.get_tile_cache(), str(target))
        return

    # Other ranks do not write the file
    print("")