import os
import re
from typing import List, Tuple

import langid
import nltk
from nltk.tokenize import PunktSentenceTokenizer
from torch.utils.data import Dataset

nltk.download("punkt")  # make sure the Punkt resources are available

# langid.set_languages(["de", "sl"])  # set languages for langid
langid.set_languages(["en", "fr"])  # restrict langid to English and French
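
# Note (illustrative, not part of the original code): langid.classify() returns a
# (language, score) pair, e.g. langid.classify("the cat sat") -> ("en", ...).
# Sentence._recognize_lang below keeps only the language code, which could later
# be used to drop sentences that are not in English, e.g.:
#   english_only = [s for s in sentences if s.lang == "en"]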
"""
|
||
|
General pipeline:
|
||
|
- Load OCR and GT
|
||
|
- Tokenize by sentences
|
||
|
- Remove whitespaces
|
||
|
- Filter entities (we don't need that part, there are no scientific entities in the dataset)
|
||
|
- Identifying and masking the incorrect words
|
||
|
"""


class Sentence:
    def __init__(self, text: str):
        self.text = text
        self.lang = self._recognize_lang()

    def _recognize_lang(self):
        return langid.classify(self.text)[0]

    def remove_whitespaces(self):
        self.text = self.text.replace(" ", "")
        return self

    def remove_special_symbols(self):
        self.text = re.sub(r"[^a-zA-Z0-9]+", "", self.text)
        return self

    @property
    def _text(self):
        return self.text


class MiBioDs(Dataset):
    def __init__(
        self, root_ocr="../MiBio-OCR-dataset/ocr/", root_gt="../MiBio-OCR-dataset/gt/"
    ):
        self.root_ocr = root_ocr
        self.root_gt = root_gt
        self.tokenizer = PunktSentenceTokenizer()  # tokenize by sentences

        # Get sentences from both directories
        self.ocr_sentences = self._get_data(self.root_ocr)
        self.gt_sentences = self._get_data(self.root_gt)

        # Whitespace removal now happens per sentence inside _get_data
        # self.ocr_sentences = self._remove_whitespaces(self.ocr_sentences)
        # self.gt_sentences = self._remove_whitespaces(self.gt_sentences)

        print(f"Number of sentences in OCR: {len(self.ocr_sentences)}")
        print(f"Number of sentences in GT: {len(self.gt_sentences)}")

    def _get_data(self, path) -> List[Sentence]:
        all_sentences = []
        for file in sorted(os.listdir(path)):  # sort so OCR and GT sentences stay aligned
            if file.endswith(".txt"):
                with open(os.path.join(path, file), "r") as f:
                    file_contents = f.read()
                    file_contents = file_contents.replace("\n", " ")
                    sentences = self.tokenizer.tokenize(file_contents)
                    for s in sentences:
                        if "ocr" in path:  # this is an input (OCR) dataset
                            all_sentences.append(
                                Sentence(text=s).remove_whitespaces().remove_special_symbols()
                            )
                        if "gt" in path:  # this is a ground truth dataset
                            all_sentences.append(Sentence(text=s))

        return all_sentences

    # def _remove_whitespaces(self, sentences: List[str]) -> List[str]:
    #     return [sentence.replace(" ", "") for sentence in sentences]

    # def _remove_special_symbols(self, sentences: List[str]) -> List[str]:
    #     return [re.sub(r"[^a-zA-Z0-9]+", "", sentence) for sentence in sentences]

    def __len__(self):
        return len(self.ocr_sentences)

    def __getitem__(self, idx) -> Tuple[str, str]:
        return self.ocr_sentences[idx]._text, self.gt_sentences[idx]._text


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    ds = MiBioDs()
    dl = DataLoader(ds, batch_size=1, shuffle=True)

    for ocr, gt in dl:
        print(ocr, gt)
        print("------")