#! /usr/bin/env python3
"""Generate text with a (optionally fine-tuned) BART-base model over a TSV dataset."""
import os

import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
# NOTE(review): AdamW is not used in this file (torch.optim.AdamW is used
# instead); recent transformers releases may no longer export it, in which
# case this import fails -- consider dropping it.
from transformers import AdamW, BartForConditionalGeneration, BartTokenizer


class DL(Dataset):
    """Map-style dataset over a headerless, tab-separated file.

    Each row is expected to have exactly two columns (an id and a text
    field); ``__getitem__`` returns only the text field.
    """

    def __init__(self, path, max_length, buffer_size):
        # NOTE(review): read_csv's default quote and NA handling apply, so raw
        # text containing '"' or values such as "NA" may be altered; consider
        # quoting=csv.QUOTE_NONE and keep_default_na=False if that matters.
        self.data = pd.read_csv(path, delimiter="\t", header=None)
        # Stored for callers; not used inside this class.
        self.max_length = max_length
        self.buffer_size = buffer_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Unpacking raises ValueError unless the row has exactly two columns.
        _, text = self.data.iloc[idx]
        return text


class FT:
    """Wrap BART-base, its tokenizer and an AdamW optimizer for generation."""

    def __init__(self):
        # Enable cuDNN autotuning.
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load tokenizer and model.
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
        self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        self.model.to(self.device)

        # Optimizer is created so its state can be restored from the checkpoint.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)

        # Epoch stored in the loaded checkpoint; stays None if none is found.
        self.current_epoch = None
        self.load_checkpoint()

        # Gradient scaler for mixed precision; None when AMP is unavailable.
        try:
            from torch.cuda.amp import GradScaler
        except ImportError:
            self.scaler = None
        else:
            self.scaler = GradScaler()

    def test_model(
        self, dataloader, max_input_length=512, max_output_length=200, num_beams=4
    ):
        """Generate text for every batch of input strings in ``dataloader``.

        Args:
            dataloader: iterable yielding batches (lists) of input strings.
            max_input_length: inputs are truncated to this many tokens.
            max_output_length: ``max_length`` passed to ``generate``.
            num_beams: beam width for beam search.

        Returns:
            list[str]: decoded predictions, one per input string, in order.
        """
        self.model.eval()
        predictions = []  # Generated text outputs, in dataloader order.
        print("Testing model...")
        with torch.no_grad():
            for batch in tqdm(dataloader):
                inputs = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=max_input_length,
                ).to(self.device)
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_output_length,
                    num_beams=num_beams,
                    early_stopping=True,  # Stop once all beams are finished.
                )
                # Decode every sequence in the batch. Decoding only outputs[0]
                # dropped all but the first input whenever batch_size > 1.
                predictions.extend(
                    self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                )
        return predictions

    def load_checkpoint(self, checkpoint_path="./checkpoints/ft_14.pt"):
        """Restore model and optimizer state from ``checkpoint_path`` if present.

        The checkpoint is read as a dict with "epoch", "model_state_dict" and
        "optimizer_state_dict" keys. If the file does not exist, a message is
        printed and the pretrained weights are left unchanged.
        """
        if not os.path.exists(checkpoint_path):
            print(f"No checkpoint found at {checkpoint_path}; using pretrained weights")
            return
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.current_epoch = checkpoint["epoch"]
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def validate(self):
        """Generate for each row of ./dataset/gt.txt, print and return the results."""
        dataset = DL(
            path="./dataset/gt.txt",
            max_length=300,
            buffer_size=100,
        )
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
        outputs = self.test_model(dataloader)
        print(outputs)
        return outputs


if __name__ == "__main__":
    validator = FT()
    validator.validate()