# ocr_post_correction/MiBioDs.py
import os
from torch.utils.data import Dataset
import nltk
nltk.download("punkt")
from nltk.tokenize import PunktSentenceTokenizer
from typing import List, Tuple
import re
import langid
#langid.set_languages(["de", "sl"]) # set languages for langid
langid.set_languages(["en", "fr"]) # set languages for langid (in this case just the english language)
"""
General pipeline:
- Load OCR and GT
- Tokenize by sentences
- Remove whitespaces
- Filter entities (we don't need that part, there are no scientific entities in the dataset)
- Identifying and masking the incorrect words
"""
class Sentence:
def __init__(self, text: str):
self.text = text
self.lang = self._recognize_lang()
def _recognize_lang(self):
return langid.classify(self.text)[0]
def remove_whitespaces(self):
self.text = self.text.replace(" ", "")
return self
def remove_special_symbols(self):
self.text = re.sub(r"[^a-zA-Z0-9]+", "", self.text)
return self
@property
def _text(self):
return self.text
class MiBioDs(Dataset):
def __init__(
self, root_ocr="../MiBio-OCR-dataset/ocr/", root_gt="../MiBio-OCR-dataset/gt/"
):
self.root_ocr = root_ocr
self.root_gt = root_gt
self.tokenizer = PunktSentenceTokenizer() # tokenize by sentances
# Get sentences
self.ocr_sentences = self._get_data(self.root_ocr)
self.gt_sentences = self._get_data(self.root_gt)
# remove whitespaces
# self.ocr_sentences = self._remove_whitespaces(self.ocr_sentences)
# self.gt_sentences = self._remove_whitespaces(self.gt_sentences)
print(f"Number of sentences in OCR: {len(self.ocr_sentences)}")
print(f"Number of sentences in GT: {len(self.gt_sentences)}")
def _get_data(self, path) -> List[str]:
all_sentences = []
for file in os.listdir(path):
if file.endswith(".txt"):
with open(self.root_gt + file, "r") as f:
file_contents = f.read()
file_contents = file_contents.replace("\n", " ")
sentences = self.tokenizer.tokenize(file_contents)
for s in sentences:
if "ocr" in path: # this is a input dataset
all_sentences.append(Sentence(text=s).remove_whitespaces().remove_special_symbols())
if "gt" in path: # this is a ground truth dataset
all_sentences.append(Sentence(text=s))
return all_sentences
# def _remove_whitespaces(self, sentences: List[str]) -> List[str]:
# return [sentence.replace(" ", "") for sentence in sentences]
# def _remove_special_symbols(self, sentences: List[str]) -> List[str]:
# return [re.sub(r"[^a-zA-Z0-9]+", "", sentence) for sentence in sentences]
def __len__(self):
return len(self.ocr_sentences)
def __getitem__(self, idx) -> Tuple[str, str]:
return self.ocr_sentences[idx]._text, self.gt_sentences[idx]._text
if __name__ == "__main__":
ds = MiBioDs()
from torch.utils.data import DataLoader
dl = DataLoader(ds, batch_size=1, shuffle=True)
for ocr, gt in dl:
print(ocr, gt)
print("------")