fine-tune/dl.py

#! /usr/bin/env python3
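"""Data loading for fine-tuning.

TextDataset yields (sentence, noisy_sentence) pairs, where the noisy version
is produced with nlpaug OCR-style character substitutions and random word
deletion; load_dataset wraps the dataset in train/test DataLoaders.
"""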
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from nlpaug.augmenter.char import OcrAug
from nlpaug.augmenter.word import RandomWordAug
import pandas as pd


class TextDataset(Dataset):
    def __init__(self, path, max_length, buffer_size):
        # Each row of the corpus file is "<id>\t<sentence>".
        self.data = pd.read_csv(path, delimiter="\t", header=None)
        self.max_length = max_length
        self.buffer_size = buffer_size

        # Augmentations: OCR-style character substitutions plus random word deletion
        self.aug_char = OcrAug(
            name="OCR_Aug",
            aug_char_min=2,
            aug_char_max=10,
            aug_char_p=0.3,
            aug_word_p=0.3,
            aug_word_min=1,
            aug_word_max=10,
        )
        self.aug_delete = RandomWordAug(
            action="delete", name="RandomWord_Aug", aug_min=0, aug_max=1, aug_p=0.1
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        index, sentence = self.data.iloc[idx]
        # augment() returns a list of augmented strings in recent nlpaug
        # versions, hence the [0] after both augmenters have run.
        aug_sentence = self.aug_char.augment(sentence)
        aug_sentence = self.aug_delete.augment(aug_sentence)
        aug_sentence = aug_sentence[0]
        return sentence, aug_sentence
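
# Illustration only (hypothetical output, since both augmenters are random):
#   sentence:     "Das ist ein kurzer Beispielsatz."
#   aug_sentence: "Das ist eim kur2er Beispielsatz."   <- OCR-style character noise
#   aug_sentence: "Das ist kur2er Beispielsatz."       <- after random word deletion
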

def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
    # Create dataset
    dataset = TextDataset(path, max_length, buffer_size)

    # Calculate split sizes
    total_size = len(dataset)
    test_size = int(total_size * test_ratio)
    train_size = total_size - test_size

    # Split dataset into train and test
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_dataloader, test_dataloader
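
# Note: random_split draws a fresh random split on every call. For a
# reproducible split, a seeded generator could be passed (a sketch, not part
# of the original interface):
#   train_dataset, test_dataset = random_split(
#       dataset, [train_size, test_size],
#       generator=torch.Generator().manual_seed(42),
#   )
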

def test():
    train_dataloader, test_dataloader = load_dataset(
        "../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
        100,
        100,
        32,
        test_ratio=0.2,
    )
    for batch in train_dataloader:
        for sentence, aug_sentence in zip(batch[0], batch[1]):
            print(f"sentence: {sentence} | aug_sentence: {aug_sentence}")


if __name__ == "__main__":
    test()