import sys
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
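# --- Hypothetical setup sketch (not part of the original script) -----------------
# The training loop below assumes a model, loss, dataloaders, a metrics_dict, and
# some early-stopping bookkeeping already exist. The concrete choices here (toy
# data, binary task, torchmetrics.Accuracy, the variable values) are assumptions
# made for illustration only.
from torchmetrics import Accuracy  # assumed metric library; matches the metric_fn(preds, labels) / .compute() / .reset() usage below

X = torch.randn(1000, 8)                        # toy features
y = (X[:, 0] > 0).float().unsqueeze(-1)         # toy binary labels
ds = TensorDataset(X, y)
dl_train = DataLoader(ds, batch_size=64, shuffle=True)
dl_val = DataLoader(ds, batch_size=64)

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
metrics_dict = {"acc": Accuracy(task="binary")}

epochs, patience = 20, 5
monitor, mode = "val_acc", "max"                # quantity watched for early stopping
ckpt_path = "checkpoint.pt"                     # where the best weights are saved
history = {}
# ---------------------------------------------------------------------------------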
        # metrics: per-step values plus running state for the epoch aggregate
        step_metrics = {"train_" + name: metric_fn(preds, labels).item()
                        for name, metric_fn in train_metrics_dict.items()}
        step_log = dict({"train_loss": loss.item()}, **step_metrics)

        total_loss += loss.item()
        step += 1
        if i != len(dl_train) - 1:
            loop.set_postfix(**step_log)
        else:
            # last batch: report epoch-level loss and metrics
            epoch_loss = total_loss / step
            epoch_metrics = {"train_" + name: metric_fn.compute().item()
                             for name, metric_fn in train_metrics_dict.items()}
            epoch_log = dict({"train_loss": epoch_loss}, **epoch_metrics)
            loop.set_postfix(**epoch_log)
    for name, metric_fn in train_metrics_dict.items():
        metric_fn.reset()

    for name, metric in epoch_log.items():
        history[name] = history.get(name, []) + [metric]
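    # Note on metric_fn(preds, labels) vs metric_fn.compute(): with torchmetrics-style
    # metric objects (an assumption about metrics_dict), calling the metric on a batch
    # both updates its internal state and returns that batch's value, while .compute()
    # aggregates everything accumulated since the last .reset(). Rough sketch:
    #
    #   acc = Accuracy(task="binary")
    #   acc(torch.tensor([0.9, 0.2]), torch.tensor([1, 1]))  # batch value 0.5
    #   acc(torch.tensor([0.8, 0.7]), torch.tensor([1, 1]))  # batch value 1.0
    #   acc.compute()                                        # epoch value 0.75
    #   acc.reset()                                          # clear state for the next epoch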
    # 2. validate -------------------------------------------------
    net.eval()
    total_loss, step = 0, 0
    loop = tqdm(enumerate(dl_val), total=len(dl_val), file=sys.stdout)
    val_metrics_dict = deepcopy(metrics_dict)

    with torch.no_grad():
        for i, batch in loop:
            features, labels = batch

            # forward
            preds = net(features)
            loss = loss_fn(preds, labels)

            # metrics
            step_metrics = {"val_" + name: metric_fn(preds, labels).item()
                            for name, metric_fn in val_metrics_dict.items()}
            step_log = dict({"val_loss": loss.item()}, **step_metrics)

            total_loss += loss.item()
            step += 1
            if i != len(dl_val) - 1:
                loop.set_postfix(**step_log)
            else:
                # last batch: compute metrics over the whole validation set
                epoch_loss = total_loss / step
                epoch_metrics = {"val_" + name: metric_fn.compute().item()
                                 for name, metric_fn in val_metrics_dict.items()}
                epoch_log = dict({"val_loss": epoch_loss}, **epoch_metrics)
                loop.set_postfix(**epoch_log)
    for name, metric_fn in val_metrics_dict.items():
        metric_fn.reset()

    epoch_log["epoch"] = epoch
    for name, metric in epoch_log.items():
        history[name] = history.get(name, []) + [metric]
    # 3. early stopping -------------------------------------------------
    arr_scores = history[monitor]
    best_score_idx = np.argmax(arr_scores) if mode == "max" else np.argmin(arr_scores)

    if best_score_idx == len(arr_scores) - 1:
        torch.save(net.state_dict(), ckpt_path)
        print("<<<<<< reach best {0} : {1} >>>>>>".format(
            monitor, arr_scores[best_score_idx]), file=sys.stderr)

    if len(arr_scores) - best_score_idx > patience:
        print("<<<<<< {} without improvement in {} epochs, early stopping >>>>>>".format(
            monitor, patience), file=sys.stderr)
        break

# restore the best checkpoint and collect the training history
net.load_state_dict(torch.load(ckpt_path, weights_only=True))
dfhistory = pd.DataFrame(history)
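# Optional usage sketch (not part of the original loop): dfhistory has one row per
# epoch with the recorded train_/val_ losses and metrics, plus the epoch index, so
# the training curves can be plotted directly.
plt.figure()
plt.plot(dfhistory["epoch"], dfhistory["train_loss"], label="train_loss")
plt.plot(dfhistory["epoch"], dfhistory["val_loss"], label="val_loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()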