Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
import timm
Expand All @@ -19,9 +19,9 @@

def _get_architecture(
arch_name: str,
weights: str or WeightsEnum = "DEFAULT",
weights: str | WeightsEnum = "DEFAULT",
**kwargs: dict,
) -> list[nn.Sequential, ...] | nn.Sequential:
) -> torch.nn.ModuleList | nn.Sequential:
"""Retrieve a CNN model architecture.

This function fetches a Convolutional Neural Network (CNN) model architecture,
Expand All @@ -38,7 +38,7 @@ def _get_architecture(
Key-word arguments.

Returns:
list[nn.Sequential, ...] | nn.Sequential:
list[nn.Sequential] | nn.Sequential:
A list of PyTorch network layers wrapped with `nn.Sequential`.

Raises:
Expand Down Expand Up @@ -94,7 +94,7 @@ def _get_timm_architecture(
arch_name: str,
*,
pretrained: bool,
) -> list[nn.Sequential, ...] | nn.Sequential:
) -> torch.nn.ModuleList | nn.Sequential:
"""Retrieve a timm model architecture.

This function fetches a model architecture from the timm library, specifically for
Expand Down Expand Up @@ -124,6 +124,7 @@ def _get_timm_architecture(
model = timm.create_model(arch_name, pretrained=pretrained)
return nn.Sequential(*list(model.children())[:-1])

arch_map: dict[str, dict] = {}
arch_map = {
# UNI tile encoder: https://huggingface.co/MahmoodLab/UNI
"UNI": {
Expand Down Expand Up @@ -306,7 +307,9 @@ def __init__(self: CNNModel, backbone: str, num_classes: int = 1) -> None:

# pylint: disable=W0221
# because abc is generic, this is actual definition
def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
def forward(
self: CNNModel, *args: tuple[Any, ...], **kwargs: dict
) -> torch.Tensor | None:
"""Pass input data through the model.

Args:
Expand All @@ -318,6 +321,9 @@ def forward(self: CNNModel, imgs: torch.Tensor) -> torch.Tensor:
The output logits after passing through the model.

"""
imgs = args[0]
if imgs is None:
return None
feat = self.feat_extract(imgs)
gap_feat = self.pool(feat)
gap_feat = torch.flatten(gap_feat, 1)
Expand Down Expand Up @@ -431,7 +437,9 @@ def __init__(

# pylint: disable=W0221
# because abc is generic, this is actual definition
def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor:
def forward(
self: TimmModel, *args: tuple[Any, ...], **kwargs: dict
) -> torch.Tensor | None:
"""Pass input data through the model.

Args:
Expand All @@ -443,6 +451,9 @@ def forward(self: TimmModel, imgs: torch.Tensor) -> torch.Tensor:
The output logits after passing through the model.

"""
imgs = args[0]
if imgs is None:
return None
feat = self.feat_extract(imgs)
feat = torch.flatten(feat, 1)
logit = self.classifier(feat)
Expand Down Expand Up @@ -552,7 +563,9 @@ def __init__(self: CNNBackbone, backbone: str) -> None:

# pylint: disable=W0221
# because abc is generic, this is actual definition
def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
def forward(
self: CNNBackbone, *args: tuple[Any, ...], **kwargs: dict
) -> torch.Tensor | None:
"""Pass input data through the model.

Args:
Expand All @@ -564,6 +577,9 @@ def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
The extracted features.

"""
imgs = args[0]
if imgs is None:
return None
feat = self.feat_extract(imgs)
gap_feat = self.pool(feat)
return torch.flatten(gap_feat, 1)
Expand Down Expand Up @@ -645,7 +661,9 @@ def __init__(self: TimmBackbone, backbone: str, *, pretrained: bool) -> None:

# pylint: disable=W0221
# because abc is generic, this is actual definition
def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
def forward(
self: TimmBackbone, *args: tuple[Any, ...], **kwargs: dict
) -> torch.Tensor | None:
"""Pass input data through the model.

Args:
Expand All @@ -657,6 +675,9 @@ def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
The extracted features.

"""
imgs = args[0]
if imgs is None:
return None
feats = self.feat_extract(imgs)
return torch.flatten(feats, 1)

Expand Down
9 changes: 5 additions & 4 deletions tiatoolbox/models/models_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,17 @@ def forward(
@abstractmethod
def infer_batch(
model: torch.nn.Module,
batch_data: np.ndarray,
batch_data: torch.Tensor,
device: str,
) -> None:
) -> None | dict[str, np.ndarray] | list[dict[str, np.ndarray]]:
"""Run inference on an input batch.

Contains logic for forward operation as well as I/O aggregation.

Args:
model (nn.Module):
PyTorch defined model.
batch_data (np.ndarray):
batch_data (torch.Tensor):
A batch of data generated by
`torch.utils.data.DataLoader`.
device (str):
Expand Down Expand Up @@ -227,4 +227,5 @@ def load_weights_from_file(self: ModelABC, weights: str | Path) -> torch.nn.Modu
# ! assume to be saved in single GPU mode
# always load on to the CPU
saved_state_dict = torch.load(weights, map_location="cpu")
return super().load_state_dict(saved_state_dict, strict=True)
super().load_state_dict(saved_state_dict, strict=True)
return self
Loading