From a89c1b812f86e469f4bf4bbf5bb1059a6a1ca39c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Spagnolo?=
Date: Wed, 26 Jul 2023 11:32:44 +0200
Subject: [PATCH] Push

---
 .gitignore    |  1 +
 dataloader.py | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++
 ft.py         |  1 +
 3 files changed, 85 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 dataloader.py
 create mode 100644 ft.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..21d0b89
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.venv/
diff --git a/dataloader.py b/dataloader.py
new file mode 100644
index 0000000..391d4e0
--- /dev/null
+++ b/dataloader.py
@@ -0,0 +1,83 @@
+import tensorflow as tf
+import nlpaug.augmenter.word as naw
+
+
+class DataLoader:
+    def __init__(self, path, buffer_size, batch_size, max_length, test_ratio=0.2):
+        self.path = path
+        self.buffer_size = buffer_size
+        self.batch_size = batch_size
+        self.max_length = max_length
+        self.test_ratio = test_ratio
+        self.aug = naw.SynonymAug(aug_src="wordnet")
+
+    def _split_input_target(self, sequence):
+        parts = tf.strings.split(sequence, "\t")
+        # Parse the index as a tensor; Python int() fails on symbolic tensors in map().
+        index = tf.strings.to_number(parts[0], out_type=tf.int32)
+        sentence = tf.strings.reduce_join(parts[1:], separator=" ")
+        return sentence, index
+
+    def augment_data(self, sentence, index):
+        aug_sentence = self.aug.augment(sentence.numpy().decode())
+        # Newer nlpaug versions return a list of augmented strings; unwrap it.
+        if isinstance(aug_sentence, list):
+            aug_sentence = aug_sentence[0]
+        return sentence, aug_sentence, index
+
+    def tf_augment_data(self, sentence, index):
+        sentence, aug_sentence, index = tf.py_function(
+            self.augment_data, [sentence, index], [tf.string, tf.string, tf.int32]
+        )
+        return sentence, aug_sentence, index
+
+    def load_dataset(self):
+        lines_dataset = tf.data.TextLineDataset(self.path)
+        dataset = lines_dataset.map(self._split_input_target)
+        dataset = dataset.map(self.tf_augment_data)
+
+        # Split dataset into train and test. TextLineDataset reports unknown
+        # cardinality, so count the lines explicitly.
+        dataset_size = lines_dataset.reduce(0, lambda count, _: count + 1).numpy()
+        test_size = int(dataset_size * self.test_ratio)
+        train_size = dataset_size - test_size
+        train_dataset = dataset.take(train_size)
+        test_dataset = dataset.skip(train_size)
+
+        # Shuffle and batch
+        train_dataset = train_dataset.shuffle(self.buffer_size).batch(self.batch_size)
+        test_dataset = test_dataset.shuffle(self.buffer_size).batch(self.batch_size)
+
+        return train_dataset, test_dataset
+
+
+def test():
+    # Hyperparameters
+    buffer_size = 10000
+    batch_size = 64
+    max_length = 100  # Or any other value depending on your data
+
+    # Create DataLoader
+    data_loader = DataLoader(
+        "../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
+        buffer_size,
+        batch_size,
+        max_length,
+    )
+
+    # Load the datasets
+    train_dataset, test_dataset = data_loader.load_dataset()
+
+    # Test the data loader on the training dataset
+    print("First batch from the training dataset:")
+    for sentences, augmented, indices in train_dataset.take(1):
+        print(f"Indices: {indices}, Sentences: {sentences}, Augmented: {augmented}")
+
+    # Test the data loader on the test dataset
+    # print("\nFirst 5 batches from the test dataset:")
+    # for sentences, augmented, indices in test_dataset.take(5):
+    #     print(f"Indices: {indices}, Sentences: {sentences}")
+
+
+if __name__ == "__main__":
+    test()
diff --git a/ft.py b/ft.py
new file mode 100644
index 0000000..21ab5da
--- /dev/null
+++ b/ft.py
@@ -0,0 +1 @@
+#! /usr/bin/env python3