Let's use a multilingual model
parent
aba634b82d
commit
70d48eb412
6
dl.py
6
dl.py
|
@ -36,9 +36,9 @@ class TextDataset(Dataset):
|
|||
|
||||
def __getitem__(self, idx):
|
||||
sentence, corrected = self.data.iloc[idx]
|
||||
#aug_sentence = self.aug_char.augment(sentence)
|
||||
#aug_sentence = self.aug_delete.augment(aug_sentence)
|
||||
#aug_sentence = aug_sentence[0]
|
||||
# aug_sentence = self.aug_char.augment(sentence)
|
||||
# aug_sentence = self.aug_delete.augment(aug_sentence)
|
||||
# aug_sentence = aug_sentence[0]
|
||||
return corrected, sentence
|
||||
|
||||
|
||||
|
|
31
ft.py
31
ft.py
|
@ -1,5 +1,6 @@
|
|||
#! /usr/bin/env python3
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
import torch
|
||||
from dl import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
@ -8,18 +9,23 @@ import os
|
|||
|
||||
class FT:
|
||||
def __init__(self):
|
||||
# Enable cudnn optimizations
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
self.model_name = "facebook/mbart-large-50"
|
||||
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# load tokenizer and model
|
||||
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
||||
self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
|
||||
self.model.to(self.device)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
self.model_name, load_in_8bit=True
|
||||
)
|
||||
|
||||
# set up optimizer
|
||||
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
|
||||
self.scheduler = ReduceLROnPlateau(
|
||||
self.optimizer, mode="min", patience=1, factor=0.5
|
||||
) # add this line
|
||||
|
||||
try:
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
@ -122,12 +128,13 @@ class FT:
|
|||
def train(self):
|
||||
train_dataloader, test_dataloader = load_dataset(
|
||||
"./dataset/hason_out.json",
|
||||
100,
|
||||
100,
|
||||
32,
|
||||
200,
|
||||
200,
|
||||
1,
|
||||
test_ratio=0.2,
|
||||
)
|
||||
num_epochs = 3
|
||||
num_epochs = 100
|
||||
last_lr = None
|
||||
for epoch in range(num_epochs):
|
||||
avg_train_loss = self.train_model(train_dataloader)
|
||||
print(f"Train loss for epoch {epoch+1}: {avg_train_loss}")
|
||||
|
@ -137,6 +144,12 @@ class FT:
|
|||
avg_test_loss = self.test_model(test_dataloader)
|
||||
print(f"Test loss for epoch {epoch+1}: {avg_test_loss}")
|
||||
|
||||
# Check if the learning rate has changed
|
||||
current_lr = self.optimizer.param_groups[0]["lr"]
|
||||
if last_lr and current_lr != last_lr:
|
||||
print(f"Learning rate reduced from {last_lr} to {current_lr}")
|
||||
last_lr = current_lr
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trainer = FT()
|
||||
|
|
11
run.py
11
run.py
|
@ -6,6 +6,7 @@ import os
|
|||
from torch.utils.data import Dataset
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DL(Dataset):
|
||||
def __init__(self, path, max_length, buffer_size):
|
||||
self.data = pd.read_csv(path, delimiter="\t", header=None)
|
||||
|
@ -35,7 +36,6 @@ class FT:
|
|||
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
|
||||
self.load_checkpoint()
|
||||
|
||||
|
||||
try:
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
|
||||
|
@ -69,16 +69,18 @@ class FT:
|
|||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
max_length=200, # Set the maximum length of the generated output
|
||||
num_beams=4, # Number of beams for beam search (optional)
|
||||
num_beams=4, # Number of beams for beam search (optional)
|
||||
early_stopping=True, # Stop generation when all beams are finished (optional)
|
||||
)
|
||||
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
generated_text = self.tokenizer.decode(
|
||||
outputs[0], skip_special_tokens=True
|
||||
)
|
||||
predictions.append(generated_text)
|
||||
|
||||
return predictions
|
||||
|
||||
def load_checkpoint(self):
|
||||
checkpoint_path = "./checkpoints/ft_0.pt"
|
||||
checkpoint_path = "./checkpoints/ft_14.pt"
|
||||
if os.path.exists(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
self.current_epoch = checkpoint["epoch"]
|
||||
|
@ -100,7 +102,6 @@ class FT:
|
|||
print(outputs)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
validator = FT()
|
||||
validator.validate()
|
||||
|
|
Loading…
Reference in New Issue