| | from __future__ import annotations |
| |
|
| | import logging |
| | import os |
| | from collections.abc import Hashable, Mapping |
| | from typing import Any, Callable, Sequence |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn |
| | from ignite.engine import Engine |
| | from ignite.metrics import Metric |
| | from monai.config import KeysCollection |
| | from monai.engines import SupervisedTrainer |
| | from monai.engines.utils import get_devices_spec |
| | from monai.inferers import Inferer |
| | from monai.transforms.transform import MapTransform, Transform |
| | from torch.nn.parallel import DataParallel, DistributedDataParallel |
| | from torch.optim.optimizer import Optimizer |
| |
|
| | |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | |
| |
|
| |
|
def get_device_list(n_gpu):
    """Resolve a GPU id or list of GPU ids into a list of torch devices.

    Args:
        n_gpu: a single GPU id or a list of GPU ids to use.

    Returns:
        A list of devices matching the requested ids. If the highest requested
        id exceeds the number of visible GPUs, devices are assigned starting
        from 0 so the result has the same length as ``n_gpu``.
    """
    if not isinstance(n_gpu, list):
        n_gpu = [n_gpu]
    device_list = get_devices_spec(n_gpu)
    if torch.cuda.device_count() >= max(n_gpu):
        # Keep only the explicitly requested devices.
        # NOTE(review): this compares torch.device objects against raw ids —
        # confirm get_devices_spec yields values comparable to n_gpu entries.
        device_list = [d for d in device_list if d in n_gpu]
    else:
        # Use the module logger (not the root logger) with lazy %-formatting.
        logger.info(
            "Highest GPU ID provided in 'n_gpu' is larger than number of GPUs available, "
            "assigning GPUs starting from 0 to match n_gpu length of %d",
            len(n_gpu),
        )
        device_list = device_list[: len(n_gpu)]
    return device_list
| |
|
| |
|
def supervised_trainer_multi_gpu(
    max_epochs: int,
    train_data_loader,
    network: torch.nn.Module,
    optimizer: Optimizer,
    loss_function: Callable,
    device: Sequence[str | torch.device] | None = None,
    epoch_length: int | None = None,
    non_blocking: bool = False,
    iteration_update: Callable[[Engine, Any], Any] | None = None,
    inferer: Inferer | None = None,
    postprocessing: Transform | None = None,
    key_train_metric: dict[str, Metric] | None = None,
    additional_metrics: dict[str, Metric] | None = None,
    train_handlers: Sequence | None = None,
    amp: bool = False,
    distributed: bool = False,
):
    """Build a ``SupervisedTrainer`` whose network is wrapped for multi-GPU use.

    Args:
        device: devices to train on; when empty/None, all available devices
            are discovered via ``get_devices_spec``.
        distributed: wrap the network in ``DistributedDataParallel`` (requires
            exactly one device); otherwise multiple devices use ``DataParallel``.

    Raises:
        ValueError: if ``distributed`` is True and more than one device is given.

    Returns:
        A configured ``SupervisedTrainer`` placed on the first resolved device.
    """
    # Resolve the device list, falling back to auto-discovery when none given.
    dev_list = device if device else get_devices_spec(device)

    model = network
    if distributed:
        if len(dev_list) > 1:
            raise ValueError(f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {dev_list}.")
        model = DistributedDataParallel(model, device_ids=dev_list)
    elif len(dev_list) > 1:
        model = DataParallel(model, device_ids=dev_list)

    return SupervisedTrainer(
        device=dev_list[0],
        network=model,
        optimizer=optimizer,
        loss_function=loss_function,
        max_epochs=max_epochs,
        train_data_loader=train_data_loader,
        epoch_length=epoch_length,
        non_blocking=non_blocking,
        iteration_update=iteration_update,
        inferer=inferer,
        postprocessing=postprocessing,
        key_train_metric=key_train_metric,
        additional_metrics=additional_metrics,
        train_handlers=train_handlers,
        amp=amp,
    )
| |
|
| |
|
class SupervisedTrainerMGPU(SupervisedTrainer):
    """``SupervisedTrainer`` subclass that wraps the network for multi-GPU training.

    The network is wrapped in ``DistributedDataParallel`` when ``distributed``
    is True (exactly one device required), or in ``DataParallel`` when several
    devices are available; training runs on the first resolved device.
    """

    def __init__(
        self,
        max_epochs: int,
        train_data_loader,
        network: torch.nn.Module,
        optimizer: Optimizer,
        loss_function: Callable,
        device: Sequence[str | torch.device] | None = None,
        epoch_length: int | None = None,
        non_blocking: bool = False,
        iteration_update: Callable[[Engine, Any], Any] | None = None,
        inferer: Inferer | None = None,
        postprocessing: Transform | None = None,
        key_train_metric: dict[str, Metric] | None = None,
        additional_metrics: dict[str, Metric] | None = None,
        train_handlers: Sequence | None = None,
        amp: bool = False,
        distributed: bool = False,
    ):
        # Resolve the device list, falling back to auto-discovery when none given.
        self.devices_ = device if device else get_devices_spec(device)

        self.net = network
        if distributed:
            # DistributedDataParallel handles one device per process.
            if len(self.devices_) > 1:
                raise ValueError(
                    f"for distributed training, `devices` must contain only 1 GPU or CPU, but got {self.devices_}."
                )
            self.net = DistributedDataParallel(self.net, device_ids=self.devices_)
        elif len(self.devices_) > 1:
            self.net = DataParallel(self.net, device_ids=self.devices_)

        super().__init__(
            device=self.devices_[0],
            network=self.net,
            optimizer=optimizer,
            loss_function=loss_function,
            max_epochs=max_epochs,
            train_data_loader=train_data_loader,
            epoch_length=epoch_length,
            non_blocking=non_blocking,
            iteration_update=iteration_update,
            inferer=inferer,
            postprocessing=postprocessing,
            key_train_metric=key_train_metric,
            additional_metrics=additional_metrics,
            train_handlers=train_handlers,
            amp=amp,
        )
| |
|
| |
|
class AddLabelNamesd(MapTransform):
    """Attach the configured label-name dictionary to each data item under ``label_names``."""

    def __init__(
        self, keys: KeysCollection, label_names: dict[str, int] | None = None, allow_missing_keys: bool = False
    ):
        """
        Normalize label values according to label names dictionary

        Args:
            keys: The ``keys`` parameter will be used to get and set the actual data item to transform
            label_names: all label names
        """
        super().__init__(keys, allow_missing_keys)
        # Fall back to an empty mapping when no label names are supplied.
        self.label_names = {} if not label_names else label_names

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        out: dict = dict(data)
        out["label_names"] = self.label_names
        return out
| |
|
| |
|
class CopyFilenamesd(MapTransform):
    """Store the base name of the item's ``label`` entry under ``filename`` for later use."""

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
        """
        Copy Filenames for future use

        Args:
            keys: The ``keys`` parameter will be used to get and set the actual data item to transform
        """
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        out: dict = dict(data)
        # NOTE(review): always reads the "label" entry regardless of `keys`;
        # presumably d["label"] is a path string here — verify against caller.
        out["filename"] = os.path.basename(out["label"])
        return out
| |
|
| |
|
class SplitPredsLabeld(MapTransform):
    """
    Split preds and labels for individual evaluation

    For every non-background entry in ``d["label_names"]``, copies channel
    ``idx`` of the prediction and label into ``pred_<name>`` and
    ``label_<name>`` entries (each with a leading channel dimension restored).
    """

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> dict[Hashable, np.ndarray]:
        d: dict = dict(data)
        for key in self.key_iterator(d):
            if key == "pred":
                # Iterate label names in insertion order; index = channel.
                for idx, key_label in enumerate(d["label_names"]):
                    if key_label != "background":
                        # [None] re-adds the channel dim dropped by indexing.
                        d[f"pred_{key_label}"] = d[key][idx, ...][None]
                        d[f"label_{key_label}"] = d["label"][idx, ...][None]
            else:
                # Redundant `elif key != "pred"` replaced with `else`: the
                # condition was always true on this branch.
                logger.info("This is only for pred key")
        return d
| |
|