LEts use multilingual model
parent
aba634b82d
commit
70d48eb412
31
ft.py
31
ft.py
|
@ -1,5 +1,6 @@
|
||||||
#! /usr/bin/env python3
|
#! /usr/bin/env python3
|
||||||
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW
|
||||||
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
import torch
|
import torch
|
||||||
from dl import load_dataset
|
from dl import load_dataset
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
@ -8,18 +9,23 @@ import os
|
||||||
|
|
||||||
class FT:
|
class FT:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Enable cudnn optimizations
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
self.model_name = "facebook/mbart-large-50"
|
||||||
|
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# load tokenizer and model
|
# load tokenizer and model
|
||||||
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||||
self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
|
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||||
self.model.to(self.device)
|
self.model_name, load_in_8bit=True
|
||||||
|
)
|
||||||
|
|
||||||
# set up optimizer
|
# set up optimizer
|
||||||
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
|
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
|
||||||
|
self.scheduler = ReduceLROnPlateau(
|
||||||
|
self.optimizer, mode="min", patience=1, factor=0.5
|
||||||
|
) # add this line
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.cuda.amp import GradScaler, autocast
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
|
@ -122,12 +128,13 @@ class FT:
|
||||||
def train(self):
|
def train(self):
|
||||||
train_dataloader, test_dataloader = load_dataset(
|
train_dataloader, test_dataloader = load_dataset(
|
||||||
"./dataset/hason_out.json",
|
"./dataset/hason_out.json",
|
||||||
100,
|
200,
|
||||||
100,
|
200,
|
||||||
32,
|
1,
|
||||||
test_ratio=0.2,
|
test_ratio=0.2,
|
||||||
)
|
)
|
||||||
num_epochs = 3
|
num_epochs = 100
|
||||||
|
last_lr = None
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
avg_train_loss = self.train_model(train_dataloader)
|
avg_train_loss = self.train_model(train_dataloader)
|
||||||
print(f"Train loss for epoch {epoch+1}: {avg_train_loss}")
|
print(f"Train loss for epoch {epoch+1}: {avg_train_loss}")
|
||||||
|
@ -137,6 +144,12 @@ class FT:
|
||||||
avg_test_loss = self.test_model(test_dataloader)
|
avg_test_loss = self.test_model(test_dataloader)
|
||||||
print(f"Test loss for epoch {epoch+1}: {avg_test_loss}")
|
print(f"Test loss for epoch {epoch+1}: {avg_test_loss}")
|
||||||
|
|
||||||
|
# Check if the learning rate has changed
|
||||||
|
current_lr = self.optimizer.param_groups[0]["lr"]
|
||||||
|
if last_lr and current_lr != last_lr:
|
||||||
|
print(f"Learning rate reduced from {last_lr} to {current_lr}")
|
||||||
|
last_lr = current_lr
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
trainer = FT()
|
trainer = FT()
|
||||||
|
|
9
run.py
9
run.py
|
@ -6,6 +6,7 @@ import os
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
class DL(Dataset):
|
class DL(Dataset):
|
||||||
def __init__(self, path, max_length, buffer_size):
|
def __init__(self, path, max_length, buffer_size):
|
||||||
self.data = pd.read_csv(path, delimiter="\t", header=None)
|
self.data = pd.read_csv(path, delimiter="\t", header=None)
|
||||||
|
@ -35,7 +36,6 @@ class FT:
|
||||||
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
|
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
|
||||||
self.load_checkpoint()
|
self.load_checkpoint()
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from torch.cuda.amp import GradScaler, autocast
|
from torch.cuda.amp import GradScaler, autocast
|
||||||
|
|
||||||
|
@ -72,13 +72,15 @@ class FT:
|
||||||
num_beams=4, # Number of beams for beam search (optional)
|
num_beams=4, # Number of beams for beam search (optional)
|
||||||
early_stopping=True, # Stop generation when all beams are finished (optional)
|
early_stopping=True, # Stop generation when all beams are finished (optional)
|
||||||
)
|
)
|
||||||
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
generated_text = self.tokenizer.decode(
|
||||||
|
outputs[0], skip_special_tokens=True
|
||||||
|
)
|
||||||
predictions.append(generated_text)
|
predictions.append(generated_text)
|
||||||
|
|
||||||
return predictions
|
return predictions
|
||||||
|
|
||||||
def load_checkpoint(self):
|
def load_checkpoint(self):
|
||||||
checkpoint_path = "./checkpoints/ft_0.pt"
|
checkpoint_path = "./checkpoints/ft_14.pt"
|
||||||
if os.path.exists(checkpoint_path):
|
if os.path.exists(checkpoint_path):
|
||||||
checkpoint = torch.load(checkpoint_path)
|
checkpoint = torch.load(checkpoint_path)
|
||||||
self.current_epoch = checkpoint["epoch"]
|
self.current_epoch = checkpoint["epoch"]
|
||||||
|
@ -100,7 +102,6 @@ class FT:
|
||||||
print(outputs)
|
print(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
validator = FT()
|
validator = FT()
|
||||||
validator.validate()
|
validator.validate()
|
||||||
|
|
Loading…
Reference in New Issue