import os
import time
from collections import defaultdict

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb

from utils.log import get_logger
|
|
logger = get_logger(__name__)


| class GeneralModule(pl.LightningModule): |
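    """Base LightningModule with lightweight metric aggregation.

    Scalar metrics are accumulated into ``self._log`` (a dict of lists) via
    ``lg()``: "iter_"-prefixed keys are flushed every ``args.print_freq``
    steps by ``try_print_log()``, while stage-prefixed keys ("train_",
    "val_") are reduced across ranks and flushed at epoch end.
    """
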
| def __init__(self, args): |
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.stage = None  # 'train' or 'val'; expected to be set by subclasses before lg() is called

        self.iter_step = -1
        self._log = defaultdict(list)
        self.generator = np.random.default_rng()
        self.last_log_time = time.time()
|
| def try_print_log(self): |
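        """Every ``args.print_freq`` steps, gather the "iter_"-prefixed
        metrics from all ranks, log their means (stdout, Lightning, and
        optionally wandb), and clear them from the buffer."""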
        step = self.iter_step if self.args.validate else self.trainer.global_step
        if (step + 1) % self.args.print_freq == 0:
            print(os.environ["MODEL_DIR"])  # remind which run directory this process belongs to
            log = self._log
            log = {key: log[key] for key in log if "iter_" in key}
            log = self.gather_log(log, self.trainer.world_size)
            mean_log = self.get_log_mean(log)
            mean_log.update({
                'epoch': float(self.trainer.current_epoch),
                'step': float(self.trainer.global_step),
                'iter_step': float(self.iter_step),
            })
            if self.trainer.is_global_zero:
                print(str(mean_log))
                self.log_dict(mean_log, batch_size=1)
                if self.args.wandb:
                    wandb.log(mean_log)
            # clear the interval buffer so the next window starts fresh
            for key in list(log.keys()):
                if "iter_" in key:
                    del self._log[key]
|
| def lg(self, key, data): |
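        """Record a scalar under ``key``: "iter_<key>" feeds the periodic
        print-out and "<stage>_<key>" feeds the epoch-end summary."""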
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().item()
        log = self._log
        if self.args.validate or self.stage == 'train':
            log["iter_" + key].append(data)
        log[self.stage + "_" + key].append(data)
|
| def on_train_epoch_end(self): |
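        """Gather all "train_"-prefixed metrics across ranks, log their
        means, and clear them from the buffer."""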
        log = self._log
        log = {key: log[key] for key in log if "train_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({
            'epoch': float(self.trainer.current_epoch),
            'step': float(self.trainer.global_step),
            'iter_step': float(self.iter_step),
        })

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

        for key in list(log.keys()):
            if "train_" in key:
                del self._log[key]
|
| def on_validation_epoch_end(self): |
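        """Gather all "val_"-prefixed metrics across ranks, log their means,
        dump the raw values to a CSV under MODEL_DIR, and clear the buffer."""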
        self.generator = np.random.default_rng()  # fresh generator after each validation pass
        log = self._log
        log = {key: log[key] for key in log if "val_" in key}
        log = self.gather_log(log, self.trainer.world_size)
        mean_log = self.get_log_mean(log)
        mean_log.update({
            'epoch': float(self.trainer.current_epoch),
            'step': float(self.trainer.global_step),
            'iter_step': float(self.iter_step),
        })

        if self.trainer.is_global_zero:
            logger.info(str(mean_log))
            self.log_dict(mean_log, batch_size=1)
            if self.args.wandb:
                wandb.log(mean_log)

            # write the raw (gathered) per-step values on rank zero only
            path = os.path.join(
                os.environ["MODEL_DIR"], f"val_{self.trainer.global_step}.csv"
            )
            pd.DataFrame(log).to_csv(path)

        for key in list(log.keys()):
            if "val_" in key:
                del self._log[key]
|
| def gather_log(self, log, world_size): |
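        """Concatenate each rank's metric lists via ``all_gather_object``.
        Assumes every rank logged the same set of keys."""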
        if world_size == 1:
            return log
        log_list = [None] * world_size
        torch.distributed.all_gather_object(log_list, log)
        # concatenate every rank's list for each key present on this rank
        log = {key: sum([l[key] for l in log_list], []) for key in log}
        return log
|
| def get_log_mean(self, log): |
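        """Return the nan-ignoring mean of each metric list, skipping
        entries that are not numeric."""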
        out = {}
        for key in log:
            try:
                out[key] = np.nanmean(log[key])
            except (TypeError, ValueError):
                pass  # skip non-numeric entries (e.g. strings)
        return out
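

# Minimal usage sketch (hypothetical subclass; the `compute_loss` helper and
# the step bookkeeping shown here are illustrative, not part of this module):
#
#   class MyModule(GeneralModule):
#       def training_step(self, batch, batch_idx):
#           self.stage = 'train'
#           self.iter_step += 1
#           loss = self.compute_loss(batch)   # assumed helper
#           self.lg('loss', loss)             # records iter_loss and train_loss
#           self.try_print_log()
#           return loss
#
#       def validation_step(self, batch, batch_idx):
#           self.stage = 'val'
#           self.lg('loss', self.compute_loss(batch))  # recorded as val_loss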