Let's use multilingual model

main
Gašper Spagnolo 2023-08-01 13:50:29 +02:00
parent aba634b82d
commit 70d48eb412
No known key found for this signature in database
GPG Key ID: 2EA0738CC1EFEEB7
3 changed files with 31 additions and 17 deletions

ft.py

@@ -1,5 +1,6 @@
#! /usr/bin/env python3
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
from dl import load_dataset
from tqdm import tqdm
@@ -8,18 +9,23 @@ import os
class FT:
def __init__(self):
# Enable cudnn optimizations
torch.backends.cudnn.benchmark = True
self.model_name = "facebook/mbart-large-50"
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load tokenizer and model
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(
self.model_name, load_in_8bit=True
)
# set up optimizer
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
self.scheduler = ReduceLROnPlateau(
self.optimizer, mode="min", patience=1, factor=0.5
) # halve the LR when the monitored loss plateaus
try:
from torch.cuda.amp import GradScaler, autocast
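mBART-50 is a multilingual checkpoint, so the tokenizer is normally also given source and target language codes, which this hunk does not set. A minimal sketch of the new loading path, assuming bitsandbytes/accelerate are installed; the sl_SI/en_XX codes are illustrative, not taken from this commit:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/mbart-large-50"
# src_lang/tgt_lang pick the translation direction (illustrative codes)
tokenizer = AutoTokenizer.from_pretrained(model_name, src_lang="sl_SI", tgt_lang="en_XX")
# load_in_8bit needs bitsandbytes; device_map="auto" places the quantized weights on the GPU
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True, device_map="auto")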
@@ -122,12 +128,13 @@ class FT:
def train(self):
train_dataloader, test_dataloader = load_dataset(
"./dataset/hason_out.json",
100,
100,
32,
200,
200,
1,
test_ratio=0.2,
)
num_epochs = 3
num_epochs = 100
last_lr = None
for epoch in range(num_epochs):
avg_train_loss = self.train_model(train_dataloader)
print(f"Train loss for epoch {epoch+1}: {avg_train_loss}")
@@ -137,6 +144,12 @@ class FT:
avg_test_loss = self.test_model(test_dataloader)
print(f"Test loss for epoch {epoch+1}: {avg_test_loss}")
# Check if the learning rate has changed
current_lr = self.optimizer.param_groups[0]["lr"]
if last_lr and current_lr != last_lr:
print(f"Learning rate reduced from {last_lr} to {current_lr}")
last_lr = current_lr
if __name__ == "__main__":
trainer = FT()
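The diff does not show a scheduler.step call, but ReduceLROnPlateau only reduces the learning rate when it is fed the monitored metric each epoch, so the new last_lr logging implies a loop roughly like this (a sketch, assuming the test loss is the monitored metric):

for epoch in range(num_epochs):
    avg_train_loss = self.train_model(train_dataloader)
    avg_test_loss = self.test_model(test_dataloader)
    # ReduceLROnPlateau halves the LR after `patience` epochs without improvement
    self.scheduler.step(avg_test_loss)
    current_lr = self.optimizer.param_groups[0]["lr"]
    if last_lr is not None and current_lr != last_lr:
        print(f"Learning rate reduced from {last_lr} to {current_lr}")
    last_lr = current_lr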

run.py

@@ -6,6 +6,7 @@ import os
from torch.utils.data import Dataset
import pandas as pd
class DL(Dataset):
def __init__(self, path, max_length, buffer_size):
self.data = pd.read_csv(path, delimiter="\t", header=None)
@@ -35,7 +36,6 @@ class FT:
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
self.load_checkpoint()
try:
from torch.cuda.amp import GradScaler, autocast
@@ -72,13 +72,15 @@ class FT:
num_beams=4, # Number of beams for beam search (optional)
early_stopping=True, # Stop generation when all beams are finished (optional)
)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_text = self.tokenizer.decode(
outputs[0], skip_special_tokens=True
)
predictions.append(generated_text)
return predictions
def load_checkpoint(self):
checkpoint_path = "./checkpoints/ft_0.pt"
checkpoint_path = "./checkpoints/ft_14.pt"
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
self.current_epoch = checkpoint["epoch"]
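With mBART-50 the target language is normally selected at generation time via forced_bos_token_id, which the generate call above does not pass. A hedged sketch (en_XX is an illustrative code, not taken from this commit):

outputs = self.model.generate(
    input_ids,  # encoded source text (assumed variable name)
    num_beams=4,
    early_stopping=True,
    # force the first generated token to be the target-language code
    forced_bos_token_id=self.tokenizer.lang_code_to_id["en_XX"],
)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)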
@@ -100,7 +102,6 @@ class FT:
print(outputs)
if __name__ == "__main__":
validator = FT()
validator.validate()
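load_checkpoint restores the epoch from ./checkpoints/ft_14.pt, but the matching save side is not part of this commit. A sketch of a checkpoint dict consistent with the key visible above; the state-dict key names and the ft_{epoch}.pt naming are assumptions:

checkpoint = {
    "epoch": epoch,
    "model_state_dict": self.model.state_dict(),          # assumed key name
    "optimizer_state_dict": self.optimizer.state_dict(),  # assumed key name
}
torch.save(checkpoint, f"./checkpoints/ft_{epoch}.pt")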