import os
import xml.etree.ElementTree as ET
from xml.dom import minidom
from transformers import XLMRobertaTokenizer, XLMRobertaForMaskedLM
import torch
from tqdm import tqdm
import argparse
from typing import List
import warnings
import re
class XMLProcessor:
    """Annotate TEI XML files with XLM-RoBERTa based perplexity scores.

    Every ``<note>`` and every ``<s>`` sentence found in the input files gets a
    ``perplexity`` attribute; ``<seg>`` elements and the document root receive
    the mean of the scores below them. Annotated copies are written to the
    output directory under the original file name.
    """

    # Upper bound used both for the character-based chunking of raw text and
    # for the tokenizer's ``max_length`` truncation.
    MAX_LENGTH = 512
    # Texts shorter than this that contain a digit are not scored.
    SHORT_TEXT_LENGTH = 20
    TEI_NAMESPACE = "http://www.tei-c.org/ns/1.0"
    # Punctuation that is glued to the preceding token when rebuilding a sentence.
    ATTACHING_PUNCTUATION = frozenset({".", ",", ";", ":"})
    _NUMBER_PATTERN = re.compile(r"\d+")

    def __init__(self, args: argparse.Namespace):
        """Store the command-line options; the model is loaded by ``prepare_model``.

        Args:
            args: Namespace providing ``input_dir``, ``output_dir`` and ``cpu``.
        """
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        self.model: XLMRobertaForMaskedLM = None
        self.tokenizer: XLMRobertaTokenizer = None
        self.device: torch.device = None
        # Maps each processed XML path to its document-level score.
        self.documents: dict = {}

    def prepare_model(self) -> None:
        """Load the ``xlm-roberta-base`` model and tokenizer onto the chosen device.

        The CPU is used when it was requested explicitly or when CUDA is not
        available; otherwise the model is moved to the GPU.
        """
        model_name: str = "xlm-roberta-base"
        self.model = XLMRobertaForMaskedLM.from_pretrained(model_name)
        self.device = (
            torch.device("cpu")
            if self.use_cpu
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model = self.model.to(self.device)
        # Scoring is inference only; make sure dropout is disabled.
        self.model.eval()
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)

    def contains_number(self, string):
        """Return True when ``string`` contains at least one digit."""
        return bool(self._NUMBER_PATTERN.search(string))

    @staticmethod
    def _split_text(text: str, max_length: int):
        """Yield whitespace-stripped chunks of ``text`` of at most ``max_length`` characters."""
        while len(text) > max_length:
            split_index = text[:max_length].rfind(" ")
            # No usable space (none at all, or only a leading one): hard cut,
            # which also guarantees the chunk is not empty.
            if split_index <= 0:
                split_index = max_length
            chunk = text[:split_index].strip()
            text = text[split_index:].strip()
            if chunk:
                yield chunk
        if text:
            yield text

    def _chunk_perplexity(self, chunk: str) -> float:
        """Return ``exp(loss)`` of the model for one chunk, using its input ids as labels."""
        inputs = self.tokenizer(
            chunk, truncation=True, max_length=self.MAX_LENGTH, return_tensors="pt"
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        # No gradients are needed for scoring; skipping them saves memory.
        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs["input_ids"])
        return torch.exp(outputs.loss).item()

    def compute_perplexity(self, text: str) -> float:
        """Score ``text`` and return the mean per-chunk perplexity.

        Text longer than ``MAX_LENGTH`` characters is split on spaces into
        several chunks; each chunk contributes ``exp(loss)`` and the chunk
        scores are averaged, so short and long texts are on the same scale.

        Args:
            text: The text to score; may be ``None`` or empty.

        Returns:
            The averaged score, or ``0.0`` (with a warning) when the text is
            empty or is shorter than ``SHORT_TEXT_LENGTH`` characters and
            contains a digit. Callers treat ``0.0`` as "not scored".
        """
        if not text:
            warnings.warn("Empty text")
            return 0.0
        if len(text) < self.SHORT_TEXT_LENGTH and self.contains_number(text):
            warnings.warn("Text contains number")
            return 0.0

        scores = [
            self._chunk_perplexity(chunk)
            for chunk in self._split_text(text, self.MAX_LENGTH)
        ]
        return sum(scores) / len(scores) if scores else 0.0

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``.xml`` file below ``input_dir`` to ``xmls``.

        Args:
            input_dir: Directory to scan, including all of its subdirectories.
            xmls: Accumulator list; it is extended in place.

        Returns:
            The same ``xmls`` list that was passed in.
        """
        for entry in os.scandir(input_dir):
            if entry.is_file() and entry.name.endswith(".xml"):
                xmls.append(entry.path)
            elif entry.is_dir():
                self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``.xml`` files below the input directory."""
        return self.find_xmls_recursively(self.input_dir, [])

    def _build_sentence(self, s: ET.Element) -> str:
        """Rebuild the plain text of an ``<s>`` element from its ``<w>``/``<pc>`` children."""
        word_tag = "{" + self.TEI_NAMESPACE + "}w"
        punctuation_tag = "{" + self.TEI_NAMESPACE + "}pc"
        sentence_parts = []
        for child in s:
            token = child.text
            if not token:
                # An element without text would break the string handling below.
                continue
            if child.tag == word_tag:
                sentence_parts.append(token)
            elif child.tag == punctuation_tag:
                if token in self.ATTACHING_PUNCTUATION:
                    if sentence_parts:
                        sentence_parts[-1] = sentence_parts[-1] + token
                else:
                    # NOTE(review): other punctuation is kept as its own token
                    # followed by a space - confirm this is the intended
                    # detokenization.
                    sentence_parts.append(token + " ")
        return " ".join(sentence_parts)

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        """Score one TEI file and write the annotated copy into ``save_path``.

        Each ``<note>`` is scored from its text and always counts towards the
        document mean. Each ``<s>`` is scored from its rebuilt sentence;
        sentences with a score of ``0.0`` are skipped. A ``<seg>`` gets the
        mean of its scored sentences, and the root gets the mean over notes
        and scored segments. The document score is also stored in
        ``self.documents`` under ``xml_path``.

        Args:
            xml_path: Path of the XML file to read.
            save_path: Existing directory that receives the annotated file,
                named after the base name of ``xml_path``.
        """
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {"default": self.TEI_NAMESPACE}
        document_perplexity = 0.0
        num_segments = 0

        for note in root.findall(".//default:note", ns):
            perplexity = self.compute_perplexity(note.text)
            note.set("perplexity", str(perplexity))
            document_perplexity += perplexity
            num_segments += 1

        for seg in root.findall(".//default:seg", ns):
            segment_perplexity = 0.0
            num_sentences = 0
            for s in seg.findall(".//default:s", ns):
                perplexity = self.compute_perplexity(self._build_sentence(s))
                if perplexity == 0.0:
                    # Empty sentence or short sentence with a number: not scored.
                    continue
                s.set("perplexity", str(perplexity))
                segment_perplexity += perplexity
                num_sentences += 1
            if num_sentences != 0:
                seg.set("perplexity", str(segment_perplexity / num_sentences))
                document_perplexity += segment_perplexity / num_sentences
                num_segments += 1

        if num_segments != 0:
            root.set("perplexity", str(document_perplexity / num_segments))
            self.documents[xml_path] = float(document_perplexity / num_segments)

        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        dom = minidom.parseString(xml_str)
        pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n")
        # Drop the blank lines toprettyxml produces. Join with "\n" rather than
        # os.linesep: the file is opened in text mode, which already translates
        # newlines for the platform.
        pretty_xml_str = "\n".join(
            line for line in pretty_xml_str.splitlines() if line.strip()
        )
        target_path = os.path.join(save_path, os.path.basename(xml_path))
        with open(target_path, "w", encoding="utf-8") as xml_file:
            xml_file.write(pretty_xml_str)

    def main(self) -> None:
        """Process every XML file below the input directory that has no output yet.

        Outputs are keyed by base name only, so input files that share a name
        in different subdirectories map to the same output file and only the
        first one encountered is processed.
        """
        if self.model is None or self.tokenizer is None:
            self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        for xml in tqdm(xmls):
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)
# Command-line entry point: parse the options, load the model and annotate
# every XML file found below the input directory.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process XML files.")
    parser.add_argument(
        "--input_dir", type=str, help="Input directory containing XML files."
    )
    # Required: the output directory is created unconditionally, so a missing
    # value would otherwise fail later with an unhelpful TypeError.
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    args = parser.parse_args()
    processor = XMLProcessor(args)
    processor.prepare_model()
    processor.main()