import os
import re
from typing import List, Tuple

import langid
import nltk
from nltk.tokenize import PunktSentenceTokenizer
from torch.utils.data import Dataset

nltk.download("punkt")

# langid.set_languages(["de", "sl"])  # restrict langid to German and Slovene
langid.set_languages(["en", "fr"])  # restrict langid's candidate set to English and French

"""
General pipeline:
- Load OCR and GT
- Tokenize by sentences
- Remove whitespaces
- Filter entities (not needed here: there are no scientific entities in the dataset)
- Identify and mask the incorrect words
"""


class Sentence:
    def __init__(self, text: str):
        self.text = text
        self.lang = self._recognize_lang()

    def _recognize_lang(self) -> str:
        # langid.classify returns a (language, score) tuple; keep only the language code
        return langid.classify(self.text)[0]

    def remove_whitespaces(self) -> "Sentence":
        self.text = self.text.replace(" ", "")
        return self

    def remove_special_symbols(self) -> "Sentence":
        # strips everything except alphanumerics (including any remaining whitespace)
        self.text = re.sub(r"[^a-zA-Z0-9]+", "", self.text)
        return self

    @property
    def _text(self) -> str:
        return self.text


class MiBioDs(Dataset):
    def __init__(
        self,
        root_ocr="../MiBio-OCR-dataset/ocr/",
        root_gt="../MiBio-OCR-dataset/gt/",
    ):
        self.root_ocr = root_ocr
        self.root_gt = root_gt
        self.tokenizer = PunktSentenceTokenizer()  # tokenize by sentences

        # Get sentences (whitespace/symbol removal for OCR happens inside _get_data)
        self.ocr_sentences = self._get_data(self.root_ocr)
        self.gt_sentences = self._get_data(self.root_gt)

        print(f"Number of sentences in OCR: {len(self.ocr_sentences)}")
        print(f"Number of sentences in GT: {len(self.gt_sentences)}")

    def _get_data(self, path) -> List[Sentence]:
        all_sentences = []
        # sorted() makes the file order deterministic so OCR and GT stay aligned
        for file in sorted(os.listdir(path)):
            if file.endswith(".txt"):
                # read from the directory being scanned, not a hard-coded root
                with open(os.path.join(path, file), "r") as f:
                    file_contents = f.read()
                file_contents = file_contents.replace("\n", " ")
                sentences = self.tokenizer.tokenize(file_contents)
                for s in sentences:
                    if "ocr" in path:  # this is an input dataset
                        all_sentences.append(
                            Sentence(text=s).remove_whitespaces().remove_special_symbols()
                        )
                    if "gt" in path:  # this is a ground-truth dataset
                        all_sentences.append(Sentence(text=s))
        return all_sentences

    def __len__(self):
        return len(self.ocr_sentences)

    def __getitem__(self, idx) -> Tuple[str, str]:
        return self.ocr_sentences[idx]._text, self.gt_sentences[idx]._text


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    ds = MiBioDs()
    dl = DataLoader(ds, batch_size=1, shuffle=True)
    for ocr, gt in dl:
        print(ocr, gt)
        print("------")
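
# ---------------------------------------------------------------------------
# Minimal sketch of the last pipeline step ("identify and mask the incorrect
# words"), which MiBioDs does not implement yet. This is a hypothetical
# illustration, not part of the dataset class: it assumes sentences that still
# contain whitespace (i.e., before remove_whitespaces) and a BERT-style
# "[MASK]" placeholder. The function name and the mask token are assumptions.
import difflib


def mask_incorrect_words(ocr: str, gt: str, mask_token: str = "[MASK]") -> str:
    """Replace OCR tokens that disagree with the ground truth by mask_token.

    difflib.SequenceMatcher aligns the two token sequences; every OCR token
    that falls outside an 'equal' opcode is treated as incorrect and masked.
    """
    ocr_tokens, gt_tokens = ocr.split(), gt.split()
    masked = list(ocr_tokens)
    matcher = difflib.SequenceMatcher(a=ocr_tokens, b=gt_tokens)
    for tag, i1, i2, _, _ in matcher.get_opcodes():
        if tag != "equal":  # 'replace'/'delete': these OCR tokens disagree with GT
            for i in range(i1, i2):
                masked[i] = mask_token
    return " ".join(masked)


# Example: mask_incorrect_words("Tbe qnick fox", "The quick fox")
# -> "[MASK] [MASK] fox"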