diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 12f2bf8bb..08a9d2d16 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import numpy as np import timm @@ -19,9 +19,9 @@ def _get_architecture( arch_name: str, - weights: str or WeightsEnum = "DEFAULT", + weights: str | WeightsEnum = "DEFAULT", **kwargs: dict, -) -> list[nn.Sequential, ...] | nn.Sequential: +) -> torch.nn.ModuleList | nn.Sequential: """Retrieve a CNN model architecture. This function fetches a Convolutional Neural Network (CNN) model architecture, @@ -38,7 +38,7 @@ def _get_architecture( Key-word arguments. Returns: - list[nn.Sequential, ...] | nn.Sequential: + list[nn.Sequential] | nn.Sequential: A list of PyTorch network layers wrapped with `nn.Sequential`. Raises: @@ -94,7 +94,7 @@ def _get_timm_architecture( arch_name: str, *, pretrained: bool, -) -> list[nn.Sequential, ...] | nn.Sequential: +) -> torch.nn.ModuleList | nn.Sequential: """Retrieve a timm model architecture. This function fetches a model architecture from the timm library, specifically for @@ -124,6 +124,7 @@ def _get_timm_architecture( model = timm.create_model(arch_name, pretrained=pretrained) return nn.Sequential(*list(model.children())[:-1]) + arch_map: dict[str, dict] = {} arch_map = { # UNI tile encoder: https://huggingface.co/MahmoodLab/UNI "UNI": { @@ -306,7 +307,9 @@ def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None: # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: CNNModel, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -318,6 +321,9 @@ def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor: The output logits after passing through the model. """ + imgs = args[0] + if imgs is None: + return None feat = self.feat_extract(imgs) gap_feat = self.pool(feat) gap_feat = torch.flatten(gap_feat, 1) @@ -431,7 +437,9 @@ def __init__( # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: TimmModel, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -443,6 +451,9 @@ def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor: The output logits after passing through the model. """ + imgs = args[0] + if imgs is None: + return None feat = self.feat_extract(imgs) feat = torch.flatten(feat, 1) logit = self.classifier(feat) @@ -552,7 +563,9 @@ def __init__(self: CNNBackbone, backbone: str) -> None: # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: CNNBackbone, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -564,6 +577,9 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor: The extracted features. """ + imgs = args[0] + if imgs is None: + return None feat = self.feat_extract(imgs) gap_feat = self.pool(feat) return torch.flatten(gap_feat, 1) @@ -645,7 +661,9 @@ def __init__(self: TimmBackbone, backbone: str, *, pretrained: bool) -> None: # pylint: disable=W0221 # because abc is generic, this is actual definition - def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor: + def forward( + self: TimmBackbone, *args: tuple[Any, ...], **kwargs: dict + ) -> torch.Tensor | None: """Pass input data through the model. Args: @@ -657,6 +675,9 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor: The extracted features. """ + imgs = args[0] + if imgs is None: + return None feats = self.feat_extract(imgs) return torch.flatten(feats, 1) diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index a8a8f7262..53da54818 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -80,9 +80,9 @@ def forward( @abstractmethod def infer_batch( model: torch.nn.Module, - batch_data: np.ndarray, + batch_data: torch.Tensor, device: str, - ) -> None: + ) -> None | dict[str, np.ndarray] | list[dict[str, np.ndarray]]: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -90,7 +90,7 @@ def infer_batch( Args: model (nn.Module): PyTorch defined model. - batch_data (np.ndarray): + batch_data (torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. device (str): @@ -227,4 +227,5 @@ def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Modu # ! assume to be saved in single GPU mode # always load on to the CPU saved_state_dict = torch.load(weights, map_location="cpu") - return super().load_state_dict(saved_state_dict, strict=True) + super().load_state_dict(saved_state_dict, strict=True) + return self