#! /usr/bin/env python3
import torch
from torch.utils.data import Dataset, DataLoader, random_split

from nlpaug.augmenter.char import OcrAug
from nlpaug.augmenter.word import RandomWordAug

from sklearn.model_selection import train_test_split

import pandas as pd
class TextDataset(Dataset):
    """Sentence dataset yielding (clean, augmented) text pairs.

    Loads a tab-separated file of (id, sentence) rows and, on access,
    produces a noisy copy of the sentence: OCR-style character corruption
    followed by occasional random word deletion.
    """

    def __init__(self, path, max_length, buffer_size):
        # Tab-separated file with no header: column 0 is an id,
        # column 1 the sentence text.
        self.data = pd.read_csv(path, delimiter="\t", header=None)
        self.max_length = max_length    # stored for callers; not used internally
        self.buffer_size = buffer_size  # stored for callers; not used internally

        # Character-level OCR-style noise (substitutes visually similar glyphs).
        self.aug_char = OcrAug(
            name="OCR_Aug",
            aug_char_min=2,
            aug_char_max=10,
            aug_char_p=0.3,
            aug_word_p=0.3,
            aug_word_min=1,
            aug_word_max=10,
        )
        # Drop at most one word, 10% of the time.
        self.aug_delete = RandomWordAug(
            action="delete", name="RandomWord_Aug", aug_min=0, aug_max=1, aug_p=0.1
        )

    def __len__(self):
        """Number of sentences in the loaded file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return (clean_sentence, augmented_sentence) for row *idx*."""
        _, clean = self.data.iloc[idx]
        # Chain the two augmenters; nlpaug returns a one-element list.
        noisy = self.aug_delete.augment(self.aug_char.augment(clean))
        return clean, noisy[0]
def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
    """Build shuffled train/test DataLoaders over a TextDataset.

    Args:
        path: tab-separated sentence file handed to TextDataset.
        max_length: forwarded to TextDataset (stored there, not used).
        buffer_size: forwarded to TextDataset (stored there, not used).
        batch_size: batch size for both loaders.
        test_ratio: fraction of samples held out for the test split.

    Returns:
        A (train_dataloader, test_dataloader) tuple.
    """
    dataset = TextDataset(path, max_length, buffer_size)

    # Held-out size is floored, so the remainder always goes to training.
    n_test = int(len(dataset) * test_ratio)
    n_train = len(dataset) - n_test

    # NOTE(review): random_split is unseeded, so the split differs per run.
    train_subset, test_subset = random_split(dataset, [n_train, n_test])

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
def test():
    """Smoke test: load the German corpus and print (clean, augmented) pairs."""
    loaders = load_dataset(
        "../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
        100,
        100,
        32,
        test_ratio=0.2,
    )
    train_loader, _ = loaders
    for clean_batch, noisy_batch in train_loader:
        for sentence, aug_sentence in zip(clean_batch, noisy_batch):
            print(f"sentence: {sentence} | aug_sentence: {aug_sentence}")
# Script entry point: run the augmentation smoke test.
if __name__ == "__main__":
    test()