xml-ocr/parse.py

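"""Annotate TEI XML files with XLM-RoBERTa perplexity scores.

(Summary inferred from the code below.) Every <note> and <s> element receives
a "perplexity" attribute; each <seg> stores the average over its sentences,
and the document root stores the average over all scored notes and segments.
Per-document averages are also kept in XMLProcessor.documents, keyed by
input path.
"""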
import os
import xml.etree.ElementTree as ET
from xml.dom import minidom
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
import torch
from tqdm import tqdm
import argparse
from typing import List
import warnings
import re


class XMLProcessor:
    def __init__(self, args: argparse.Namespace):
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        self.model: XLMRobertaForMaskedLM = None
        self.tokenizer: XLMRobertaTokenizer = None
        self.device: torch.device = None
        self.documents: dict = {}

    def prepare_model(self) -> None:
        model_name: str = "xlm-roberta-base"
        self.model = XLMRobertaForMaskedLM.from_pretrained(model_name)
        self.device = (
            torch.device("cpu")
            if self.use_cpu
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # disable dropout so scores are deterministic
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)

    def contains_number(self, string: str) -> bool:
        # Regular expression to match any number.
        pattern = r"\d+"
        # Search for the pattern in the string.
        match = re.search(pattern, string)
        # True if a number is found, otherwise False.
        return bool(match)

    def compute_perplexity(self, text: str) -> float:
        max_length = 512  # XLM-RoBERTa's maximum sequence length
        total_perplexity = 0.0
        total_count = 0
        if text is None or len(text) == 0:
            warnings.warn("Empty text")
            return 0.0
        # Short strings containing digits (dates, page numbers, ...) are not
        # meaningful to score, so they are skipped.
        if len(text) < 20 and self.contains_number(text):
            warnings.warn("Text contains number")
            return 0.0
        with torch.no_grad():
            # Chunk overlong text at whitespace; the comparison is on character
            # length, so the tokenizer's truncation catches any token overflow.
            while len(text) > max_length:
                split_index = text[:max_length].rfind(" ")
                if split_index == -1:  # no space to split on, just truncate
                    split_index = max_length
                sentence = text[:split_index].strip()
                text = text[split_index:].strip()
                inputs = self.tokenizer(
                    sentence, truncation=True, max_length=max_length, return_tensors="pt"
                )
                inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
                outputs = self.model(**inputs, labels=inputs["input_ids"])
                total_perplexity += torch.exp(outputs.loss).item()
                total_count += 1
            # Process the remaining text.
            if len(text) > 0:
                inputs = self.tokenizer(
                    text, truncation=True, max_length=max_length, return_tensors="pt"
                )
                inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
                outputs = self.model(**inputs, labels=inputs["input_ids"])
                total_perplexity += torch.exp(outputs.loss).item()
                total_count += 1
        # Average the per-chunk perplexities.
        return total_perplexity / total_count if total_count > 0 else 0.0
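    # Note: with labels=input_ids and no tokens masked, the model sees each
    # token it is asked to predict, so the value is a rough fluency signal
    # rather than a true masked-LM pseudo-perplexity; scores are comparable
    # across sentences scored the same way, not across scoring schemes.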

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        for entry in os.scandir(input_dir):
            if entry.is_file() and entry.name.endswith(".xml"):
                xmls.append(entry.path)
            elif entry.is_dir():
                self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {"default": "http://www.tei-c.org/ns/1.0"}
        document_perplexity = 0.0
        num_segments = 0
        for note in root.findall(".//default:note", ns):
            sentence = note.text
            perplexity = self.compute_perplexity(sentence)
            note.set("perplexity", str(perplexity))
            document_perplexity += perplexity
            num_segments += 1
        for seg in root.findall(".//default:seg", ns):
            segment_perplexity = 0.0
            num_sentences = 0
            for s in seg.findall(".//default:s", ns):
                # Reassemble the sentence from <w> (word) and <pc> (punctuation)
                # children, attaching closing punctuation to the preceding token.
                sentence_parts = []
                for child in s:
                    if child.tag == "{" + ns["default"] + "}w":
                        sentence_parts.append(child.text)
                    elif child.tag == "{" + ns["default"] + "}pc":
                        if child.text in {".", ",", ";", ":"}:
                            if sentence_parts:
                                sentence_parts[-1] = sentence_parts[-1] + child.text
                            else:
                                sentence_parts.append(child.text)
                        else:
                            sentence_parts.append(child.text + " ")
                sentence = " ".join(sentence_parts)
                perplexity = self.compute_perplexity(sentence)
                # compute_perplexity returns 0.0 for empty sentences and for
                # short sentences containing a number; skip those.
                if perplexity == 0.0:
                    continue
                s.set("perplexity", str(perplexity))
                segment_perplexity += perplexity
                num_sentences += 1
            if num_sentences != 0:
                seg.set("perplexity", str(segment_perplexity / num_sentences))
                document_perplexity += segment_perplexity / num_sentences
                num_segments += 1
        if num_segments != 0:
            root.set("perplexity", str(document_perplexity / num_segments))
            self.documents[xml_path] = document_perplexity / num_segments
        # Pretty-print the annotated tree and drop the blank lines minidom inserts.
        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        dom = minidom.parseString(xml_str)
        pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n")
        pretty_xml_str = os.linesep.join(
            [s for s in pretty_xml_str.splitlines() if s.strip()]
        )
        out_path = os.path.join(save_path, os.path.basename(xml_path))
        with open(out_path, "w", encoding="utf-8") as xml_file:
            xml_file.write(pretty_xml_str)

    def main(self) -> None:
        self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        for xml in tqdm(xmls):
            # Skip files already present in the output directory, so an
            # interrupted run can be resumed.
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process XML files.")
    parser.add_argument(
        "--input_dir", type=str, help="Input directory containing XML files."
    )
    parser.add_argument(
        "--output_dir", type=str, help="Output directory to save processed files."
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    args = parser.parse_args()
    processor = XMLProcessor(args)
    processor.main()
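
# Example invocation (directory names are illustrative, not from the repo):
#   python parse.py --input_dir tei_in --output_dir tei_out
#   python parse.py --input_dir tei_in --output_dir tei_out --cpu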