Skip to content

StreamingModule

Source code in lightstream\modules\constructor.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class StreamingConstructor:
    def __init__(
        self,
        model: torch.nn.modules,
        tile_size: int,
        verbose: bool = False,
        deterministic: bool = False,
        saliency: bool = False,
        copy_to_gpu: bool = False,
        statistics_on_cpu: bool = False,
        normalize_on_gpu: bool = False,
        mean: list[float, float, float] | None = None,
        std: list[float, float, float] | None = None,
        tile_cache: dict | None = None,
        add_keep_modules: list[torch.nn.modules] | None = None,
        before_streaming_init_callbacks: list[Callable[[torch.nn.modules], None], ...] | None = None,
        after_streaming_init_callbacks: list[Callable[[torch.nn.modules], None], ...] | None = None,
    ):
        self.model = model
        self.model_copy = deepcopy(self.model)
        self.state_dict = self.save_parameters()

        self.tile_size = tile_size
        self.verbose = verbose
        self.deterministic = deterministic
        self.saliency = saliency
        self.copy_to_gpu = copy_to_gpu
        self.statistics_on_cpu = statistics_on_cpu
        self.normalize_on_gpu = normalize_on_gpu
        self.mean = mean
        self.std = std
        self.tile_cache = tile_cache

        self.before_streaming_init_callbacks = before_streaming_init_callbacks or []
        self.after_streaming_init_callbacks = after_streaming_init_callbacks or []

        self.keep_modules = [
            torch.nn.Conv1d,
            torch.nn.Conv2d,
            torch.nn.Conv3d,
            torch.nn.AvgPool1d,
            torch.nn.AvgPool2d,
            torch.nn.AvgPool3d,
            torch.nn.MaxPool1d,
            torch.nn.MaxPool2d,
            torch.nn.MaxPool3d,
        ]

        if add_keep_modules is not None:
            self.add_modules_to_keep(add_keep_modules)

        if not self.statistics_on_cpu:
            # Move to cuda manually if statistics are computed on gpu
            device = torch.device("cuda")
            self.model.to(device)

    def add_modules_to_keep(self, module_list: list):
        """Add extra layers to keep during streaming tile calculations

        Modules in the keep_modules list will not be set to nn.Identity() during streaming initialization
        Parameters
        ----------
        module_list : list
            A list of torch modules to add to the keep_modules list.
        """

        self.keep_modules.extend(module_list)

    def prepare_streaming_model(self):
        """Run pre and postprocessing for tile lost calculations
        Returns
        -------
        sCNN : torch.nn.modules
            The streaming module
        """

        # If tile cache is available, it has already been initialized successfully once
        if self.tile_cache:
            return self.create_streaming_model()

        print("")
        # Prepare for streaming tile statistics calculations
        print("Converting modules to nn.Identity()")
        self.convert_to_identity(self.model)
        # execute any callbacks that further preprocess the model
        print("Executing pre-streaming initialization callbacks (if any):")
        self._execute_before_callbacks()

        print("Initializing streaming model")
        sCNN = self.create_streaming_model()

        # check self.stream_network, and reload the proper weights
        print("Restoring model weights")
        self.restore_model_layers(self.model_copy, sCNN.stream_module)
        sCNN.stream_module.load_state_dict(self.state_dict)

        print("Executing post-streaming initialization callbacks (if any):")
        self._execute_after_callbacks()
        return sCNN

    def _execute_before_callbacks(self):
        for cb_func in self.before_streaming_init_callbacks:
            print(f"Executing callback function {cb_func}")
            cb_func(self.model)
        print("")

    def _execute_after_callbacks(self):
        for cb_func in self.after_streaming_init_callbacks:
            print(f"Executing callback function {cb_func}")
            cb_func(self.model)
        print("")

    def create_streaming_model(self):
        return StreamingCNN(
            self.model,
            tile_shape=(1, 3, self.tile_size, self.tile_size),
            deterministic=self.deterministic,
            saliency=self.saliency,
            copy_to_gpu=self.copy_to_gpu,
            verbose=self.verbose,
            statistics_on_cpu=self.statistics_on_cpu,
            normalize_on_gpu=self.normalize_on_gpu,
            mean=self.mean,
            std=self.std,
            state_dict=self.tile_cache,
        )

    def save_parameters(self):
        state_dict = self.model.state_dict()
        state_dict = deepcopy(state_dict)
        return state_dict

    def convert_to_identity(self, model: torch.nn.modules):
        """Convert non-conv and non-local pooling layers to identity

        Parameters
        ----------
        model : torch.nn.Sequential
            The model to substitute
        """

        for n, module in model.named_children():
            if len(list(module.children())) > 0:
                # compound module, go inside it
                self.convert_to_identity(module)
                continue

            # if new module is assigned to a variable, e.g. new = nn.Identity(), then it's considered a duplicate in
            # module.named_children used later. Instead, we use in-place assignment, so each new module is unique
            if not isinstance(module, tuple(self.keep_modules)):
                try:
                    n = int(n)
                    model[n] = torch.nn.Identity()
                except ValueError:
                    setattr(model, str(n), torch.nn.Identity())

    def restore_model_layers(self, model_ref, model_rep):
        """Restore model layers from Identity to what they were before

        This function requires an exact copy of the model (model_ref) before its layers were set to nn.Identity()
        (model_rep)

        Parameters
        ----------
        model_ref : torch.nn.modules
            The copy of the model before it was set to nn.Identity()
        model_rep : torch.nn.modules
            The stream_module attribute within the streaming model that were set to nn.Identity
        """

        for ref, rep in zip(model_ref.named_children(), model_rep.named_children()):
            n_ref, module_ref = ref
            n_rep, module_rep = rep

            if len(list(module_ref.children())) > 0:
                # compound module, go inside it
                self.restore_model_layers(module_ref, module_rep)
                continue

            if isinstance(module_rep, torch.nn.Identity):
                # simple module
                try:
                    n_ref = int(n_ref)
                    model_rep[n_rep] = model_ref[n_ref]
                except (ValueError, TypeError):
                    try:
                        setattr(model_rep, n_rep, model_ref[int(n_ref)])
                    except (ValueError, TypeError):
                        # Try setting it through block dot operations
                        setattr(model_rep, n_rep, getattr(model_ref, n_ref))

add_modules_to_keep(module_list)

Add extra layers to keep during streaming tile calculations

Modules in the keep_modules list will not be set to nn.Identity() during streaming initialization

Parameters:

Name Type Description Default
module_list list

A list of torch modules to add to the keep_modules list.

required
Source code in lightstream\modules\constructor.py
76
77
78
79
80
81
82
83
84
85
86
def add_modules_to_keep(self, module_list: list):
    """Register additional layer types that must survive streaming initialization.

    Any layer whose type is listed in ``keep_modules`` is left untouched instead of being
    replaced by nn.Identity() while the streaming tile statistics are computed.

    Parameters
    ----------
    module_list : list
        Torch module classes to append, in order, to the keep_modules list.
    """
    for module_cls in module_list:
        self.keep_modules.append(module_cls)

convert_to_identity(model)

Convert every leaf layer that is not a convolution or local pooling layer (i.e. not an instance of a class in keep_modules) to nn.Identity()

Parameters:

Name Type Description Default
model Sequential

The model to substitute

required
Source code in lightstream\modules\constructor.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def convert_to_identity(self, model: torch.nn.Module):
    """Replace, in place, every leaf layer that is not an instance of keep_modules by nn.Identity()

    Compound modules (modules with children) are never replaced; the function recurses into them.

    Parameters
    ----------
    model : torch.nn.Module
        The model to substitute
    """

    keep = tuple(self.keep_modules)
    for name, module in model.named_children():
        if next(module.children(), None) is not None:
            # compound module, go inside it
            self.convert_to_identity(module)
            continue

        if not isinstance(module, keep):
            # Replace by child *name*: Module.__setattr__ re-registers the submodule under that name
            # for any container (Sequential, ModuleList, ModuleDict, plain Module). Positional
            # indexing (model[int(name)]) hits the wrong child when names don't match positions and
            # fails on containers without __setitem__.
            # A fresh Identity per slot keeps every replacement a distinct module, so none are
            # de-duplicated by named_children() later.
            setattr(model, name, torch.nn.Identity())

prepare_streaming_model()

Run pre- and post-processing around the streaming tile statistics calculations

Returns:

Name Type Description
sCNN modules

The streaming module

Source code in lightstream\modules\constructor.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def prepare_streaming_model(self):
    """Build the streaming model, doing the identity round-trip when no tile cache exists.

    Returns
    -------
    sCNN : torch.nn.modules
        The streaming module
    """
    # A populated tile cache means the streaming model was already set up once, so the
    # identity conversion and layer restoration can be skipped entirely.
    if self.tile_cache:
        return self.create_streaming_model()

    print("")
    # Strip non-kept layers before the streaming model computes its tile statistics.
    print("Converting modules to nn.Identity()")
    self.convert_to_identity(self.model)
    # User hooks may preprocess the converted model further.
    print("Executing pre-streaming initialization callbacks (if any):")
    self._execute_before_callbacks()

    print("Initializing streaming model")
    streaming_net = self.create_streaming_model()

    # Put the original layers back, then reload the weights saved at construction time.
    print("Restoring model weights")
    inner = streaming_net.stream_module
    self.restore_model_layers(self.model_copy, inner)
    inner.load_state_dict(self.state_dict)

    print("Executing post-streaming initialization callbacks (if any):")
    self._execute_after_callbacks()
    return streaming_net

restore_model_layers(model_ref, model_rep)

Restore model layers from Identity to what they were before

This function requires an exact copy of the model (model_ref) taken before its layers were set to nn.Identity() (model_rep)

Parameters:

Name Type Description Default
model_ref modules

The copy of the model before it was set to nn.Identity()

required
model_rep modules

The stream_module attribute of the streaming model, whose layers were set to nn.Identity()

required
Source code in lightstream\modules\constructor.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def restore_model_layers(self, model_ref: torch.nn.Module, model_rep: torch.nn.Module):
    """Restore model layers from Identity to what they were before

    This function requires an exact copy of the model (model_ref) taken before its layers were set to
    nn.Identity() (model_rep). Children are paired by position, so both models must have the same
    structure.

    Parameters
    ----------
    model_ref : torch.nn.Module
        The copy of the model before it was set to nn.Identity()
    model_rep : torch.nn.Module
        The stream_module attribute of the streaming model, whose layers were set to nn.Identity()
    """

    for (_, module_ref), (name_rep, module_rep) in zip(
        model_ref.named_children(), model_rep.named_children()
    ):
        if next(module_ref.children(), None) is not None:
            # compound module, go inside it
            self.restore_model_layers(module_ref, module_rep)
            continue

        if isinstance(module_rep, torch.nn.Identity):
            # module_ref is the layer that occupied this slot originally; put it back under the
            # replica's child name. Module.__setattr__ handles every container type. The previous
            # integer-index fallbacks hit the wrong child when names don't match positions, and
            # ended in getattr(model_ref, <int>), which raises TypeError.
            setattr(model_rep, name_rep, module_ref)