#!/usr/bin/env python3
"""Fine-tune facebook/bart-base as a sequence-to-sequence model on a local dataset."""

import os

import torch
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# `dl` is a local helper module. `load_dataset` is assumed to return a
# (train_dataloader, test_dataloader) pair whose batches are
# (target_texts, source_texts) tuples, matching how batches are indexed below.
from dl import load_dataset


class FT:
    def __init__(self):
        torch.backends.cudnn.benchmark = True
        self.model_name = "facebook/bart-base"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load tokenizer and model.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.model.to(self.device)

        # Optimizer and scheduler: halve the learning rate when the test loss
        # stops improving for more than one epoch.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, mode="min", patience=1, factor=0.5
        )

        # Mixed precision only helps on CUDA; with enabled=False the scaler and
        # autocast become no-ops, so the same code path runs on CPU.
        self.use_amp = self.device.type == "cuda"
        self.scaler = GradScaler(enabled=self.use_amp)

    def _prepare_batch(self, batch):
        """Tokenize a (target_texts, source_texts) batch and move it to the device."""
        inputs = self.tokenizer(
            batch[1],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)
        labels = self.tokenizer(
            batch[0],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        label_ids = labels["input_ids"].to(self.device)
        # Padding positions must be set to -100 so the loss ignores them.
        label_ids[label_ids == self.tokenizer.pad_token_id] = -100
        return inputs, label_ids

    def train_model(self, dataloader):
        self.model.train()
        total_loss = 0.0
        print("Training model...")
        for batch in tqdm(dataloader):
            self.optimizer.zero_grad()
            inputs, label_ids = self._prepare_batch(batch)
            with autocast(enabled=self.use_amp):
                outputs = self.model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    labels=label_ids,
                )
                loss = outputs.loss
            # GradScaler calls are no-ops when AMP is disabled.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            total_loss += loss.item()
        avg_train_loss = total_loss / len(dataloader)
        return avg_train_loss

    def test_model(self, dataloader):
        self.model.eval()
        total_loss = 0.0
        print("Testing model...")
        with torch.no_grad():
            for batch in tqdm(dataloader):
                inputs, label_ids = self._prepare_batch(batch)
                outputs = self.model(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    labels=label_ids,
                )
                total_loss += outputs.loss.item()
        avg_test_loss = total_loss / len(dataloader)
        return avg_test_loss

    def save_checkpoint(self, epoch):
        os.makedirs("./checkpoints/ft", exist_ok=True)
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            f"./checkpoints/ft/ft_{epoch}.pt",
        )

    def train(self):
        train_dataloader, test_dataloader = load_dataset(
            "./dataset/hason_out.json",
            200,
            200,
            1,
            test_ratio=0.2,
        )
        num_epochs = 100
        last_lr = self.optimizer.param_groups[0]["lr"]
        for epoch in range(num_epochs):
            avg_train_loss = self.train_model(train_dataloader)
            print(f"Train loss for epoch {epoch + 1}: {avg_train_loss:.4f}")

            self.save_checkpoint(epoch)
            print("Checkpoint saved!")

            avg_test_loss = self.test_model(test_dataloader)
            print(f"Test loss for epoch {epoch + 1}: {avg_test_loss:.4f}")
            self.scheduler.step(avg_test_loss)

            # Report when ReduceLROnPlateau lowers the learning rate.
            current_lr = self.optimizer.param_groups[0]["lr"]
            if current_lr != last_lr:
                print(f"Learning rate reduced from {last_lr} to {current_lr}")
            last_lr = current_lr


if __name__ == "__main__":
    trainer = FT()
    trainer.train()