xml-ocr/parse.py

153 lines
5.5 KiB
Python
Raw Normal View History

2023-07-05 00:56:05 +02:00
import os
import xml.etree.ElementTree as ET
from xml.dom import minidom
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
import torch
from tqdm import tqdm
import argparse
from typing import List
import warnings
class XMLProcessor:
    """Annotate TEI XML files with XLM-RoBERTa based perplexity scores.

    Every ``note`` and ``s`` element in the TEI namespace receives a
    ``perplexity`` attribute.  Each ``seg`` element receives the mean score of
    the ``s`` elements below it, and the document root receives the mean over
    all notes and scored segments.  Annotated copies are written to the output
    directory, and the document-level score of every processed file is kept in
    ``self.documents`` keyed by the input path.
    """

    # Name of the pretrained checkpoint loaded by ``prepare_model``.
    MODEL_NAME = "xlm-roberta-base"

    # XLM-Roberta's maximum sequence length.  The same number is used as the
    # character budget when splitting long texts and as the tokenizer's
    # truncation limit.
    MAX_LENGTH = 512

    def __init__(self, args: argparse.Namespace):
        """Store the run configuration; the model itself is loaded lazily.

        Args:
            args: Parsed command-line arguments providing ``input_dir``,
                ``output_dir`` and the boolean ``cpu`` flag.
        """
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        self.model: XLMRobertaForMaskedLM = None
        self.tokenizer: XLMRobertaTokenizer = None
        self.device: torch.device = None
        # Maps input XML path -> document-level perplexity.  It must exist
        # before ``process_xml_file`` runs, because that method writes to it.
        self.documents: dict = {}

    def prepare_model(self) -> None:
        """Load the pretrained model and tokenizer and move the model to the device.

        The CPU is used when it was requested explicitly or when CUDA is not
        available; otherwise the model runs on CUDA.
        """
        self.model = XLMRobertaForMaskedLM.from_pretrained(self.MODEL_NAME)
        if self.use_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")
        self.model = self.model.to(self.device)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(self.MODEL_NAME)

    @staticmethod
    def _iter_chunks(text: str, max_length: int):
        """Yield non-empty pieces of ``text`` of at most ``max_length`` characters.

        Pieces are cut at the last space inside the character budget; when
        there is no usable space the text is cut hard at ``max_length``.
        """
        remaining = text.strip()
        while len(remaining) > max_length:
            split_index = remaining[:max_length].rfind(" ")
            # -1 means "no space at all"; 0 would produce an empty chunk.
            # Either way fall back to a hard cut so every pass makes progress.
            if split_index <= 0:
                split_index = max_length
            chunk = remaining[:split_index].strip()
            remaining = remaining[split_index:].strip()
            if chunk:
                yield chunk
        if remaining:
            yield remaining

    def _score_chunk(self, chunk: str) -> float:
        """Return ``exp(loss)`` of the model for one chunk of text.

        The input ids are passed as labels as well, so the loss is the one
        the model reports for reproducing its own input.
        """
        inputs = self.tokenizer(
            chunk, truncation=True, max_length=self.MAX_LENGTH, return_tensors="pt"
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs["input_ids"])
        return torch.exp(outputs.loss).item()

    def compute_perplexity(self, text: str) -> float:
        """Return the mean per-chunk score for ``text``.

        Long texts are split into chunks of at most ``MAX_LENGTH`` characters,
        each chunk is scored separately, and the scores are averaged.

        Args:
            text: Text to score.  May be ``None``.

        Returns:
            The average of ``exp(loss)`` over all chunks, or ``0.0`` (after
            emitting a warning) when ``text`` is ``None``, empty or consists
            only of whitespace.
        """
        if text is None or not text.strip():
            warnings.warn("Empty text")
            return 0.0

        scores = [
            self._score_chunk(chunk)
            for chunk in self._iter_chunks(text, self.MAX_LENGTH)
        ]
        return sum(scores) / len(scores) if scores else 0.0

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the paths of all ``.xml`` files below ``input_dir`` to ``xmls``.

        Args:
            input_dir: Directory to scan; sub-directories are visited too.
            xmls: List that collects the results; it is modified in place.

        Returns:
            The same ``xmls`` list that was passed in.
        """
        for entry in os.scandir(input_dir):
            if entry.is_file() and entry.name.endswith(".xml"):
                xmls.append(entry.path)
            elif entry.is_dir():
                self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``.xml`` files below ``self.input_dir``."""
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        """Score one TEI document and write the annotated copy to ``save_path``.

        Args:
            xml_path: Path of the XML file to read.
            save_path: Directory that receives the annotated file; the file
                keeps the base name of ``xml_path``.

        Side effects:
            Stores the document-level score in ``self.documents[xml_path]``
            when at least one note or segment was scored.
        """
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {"default": "http://www.tei-c.org/ns/1.0"}

        document_perplexity = 0.0
        num_segments = 0

        # Notes are scored as a whole and count as one unit each.
        for note in root.findall(".//default:note", ns):
            perplexity = self.compute_perplexity(note.text)
            note.set("perplexity", str(perplexity))
            document_perplexity += perplexity
            num_segments += 1

        # Segments are scored as the mean over their sentences.
        for seg in root.findall(".//default:seg", ns):
            segment_perplexity = 0.0
            num_sentences = 0
            for s in seg.findall(".//default:s", ns):
                # ``w.text`` is None for empty elements; joining None would
                # raise a TypeError, so such words are skipped.
                sentence = " ".join(
                    w.text for w in s.findall(".//default:w", ns) if w.text
                )
                perplexity = self.compute_perplexity(sentence)
                s.set("perplexity", str(perplexity))
                segment_perplexity += perplexity
                num_sentences += 1
            if num_sentences != 0:
                seg_mean = segment_perplexity / num_sentences
                seg.set("perplexity", str(seg_mean))
                document_perplexity += seg_mean
                num_segments += 1

        if num_segments != 0:
            document_mean = document_perplexity / num_segments
            root.set("perplexity", str(document_mean))
            self.documents[xml_path] = float(document_mean)

        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        dom = minidom.parseString(xml_str)
        pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n")
        # Drop the blank lines minidom emits.  Join with "\n" rather than
        # os.linesep: the file is opened in text mode, which already
        # translates "\n" to the platform line ending on write.
        pretty_xml_str = "\n".join(
            line for line in pretty_xml_str.splitlines() if line.strip()
        )

        target = os.path.join(save_path, os.path.basename(xml_path))
        with open(target, "w", encoding="utf-8") as xml_file:
            xml_file.write(pretty_xml_str)

    def main(self) -> None:
        """Load the model and process every XML file found below the input directory.

        Output files are named after the base name of their input file, and an
        input is skipped when a file of that name already exists in the output
        directory.  Inputs that share a base name therefore map to the same
        output file, and only the first one encountered is processed.
        """
        self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        for xml in tqdm(xmls):
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the processor.
    parser = argparse.ArgumentParser(description="Process XML files.")
    # Both directories are mandatory.  Without them ``None`` would be passed
    # on and the run would only fail later with an unrelated TypeError.
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Input directory containing XML files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    args = parser.parse_args()

    processor = XMLProcessor(args)
    processor.main()