"""Annotate TEI XML files with XLM-RoBERTa based perplexity scores.

Every ``note`` and ``s`` element in the TEI namespace receives a
``perplexity`` attribute; ``seg`` elements and the document root receive the
mean of the scores below them.
"""

import argparse
import os
import warnings
import xml.etree.ElementTree as ET
from typing import Iterator, List, Optional
from xml.dom import minidom

import torch
from tqdm import tqdm
from transformers import XLMRobertaForMaskedLM, XLMRobertaTokenizer


class XMLProcessor:
    """Score the text of TEI XML files and write annotated copies."""

    # Used both as the character budget when splitting long text and as the
    # token budget handed to the tokenizer's truncation.
    MAX_LENGTH = 512

    # Prefix map for the element lookups in process_xml_file.
    _NS = {"default": "http://www.tei-c.org/ns/1.0"}

    def __init__(self, args: argparse.Namespace):
        """Store the command-line options; the model is loaded later.

        Args:
            args: Namespace providing ``input_dir``, ``output_dir`` and ``cpu``.
        """
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        # Populated by prepare_model().
        self.model: Optional[XLMRobertaForMaskedLM] = None
        self.tokenizer: Optional[XLMRobertaTokenizer] = None
        self.device: Optional[torch.device] = None

    def prepare_model(self) -> None:
        """Load the ``xlm-roberta-base`` model and tokenizer onto the device.

        The CPU is used when it was requested explicitly or when CUDA is not
        available; otherwise the model is moved to the CUDA device.
        """
        model_name: str = "xlm-roberta-base"
        self.model = XLMRobertaForMaskedLM.from_pretrained(model_name)
        if self.use_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")
        self.model = self.model.to(self.device)
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)

    @staticmethod
    def _split_text(text: str, max_length: int) -> Iterator[str]:
        """Yield pieces of *text*, cutting at the last space within the budget.

        While more than ``max_length`` characters remain, the text is cut at
        the last space inside the first ``max_length`` characters, or at
        exactly ``max_length`` characters when that window has no space.
        Both sides of a cut are stripped. The final remainder is yielded
        unchanged.
        """
        while len(text) > max_length:
            split_index = text[:max_length].rfind(" ")
            if split_index == -1:
                # No space to split on: hard cut at the budget.
                split_index = max_length
            chunk = text[:split_index].strip()
            text = text[split_index:].strip()
            # A cut at index 0 (the only space is the leading one) produces
            # an empty chunk; scoring it would skew the average.
            if chunk:
                yield chunk
        if text:
            yield text

    def _chunk_perplexity(self, chunk: str) -> float:
        """Return ``exp(loss)`` of the model on a single chunk of text."""
        inputs = self.tokenizer(
            chunk,
            truncation=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        # Inference only: do not record an autograd graph for each chunk.
        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs["input_ids"])
        return torch.exp(outputs.loss).item()

    def compute_perplexity(self, text: str) -> float:
        """Return the mean per-chunk ``exp(loss)`` for *text*.

        The text is split into chunks of at most ``MAX_LENGTH`` characters,
        each chunk is scored with the input ids used as labels, and the
        per-chunk values are averaged.

        Args:
            text: Text to score. ``None`` and ``""`` are accepted.

        Returns:
            The average score, or ``0.0`` (after a warning) for empty input.
        """
        if not text:
            warnings.warn("Empty text")
            return 0.0

        scores = [
            self._chunk_perplexity(chunk)
            for chunk in self._split_text(text, self.MAX_LENGTH)
        ]
        return sum(scores) / len(scores) if scores else 0.0

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``.xml`` file under *input_dir* to *xmls*.

        Args:
            input_dir: Directory to scan; subdirectories are visited too.
            xmls: Accumulator list, extended in place.

        Returns:
            The same *xmls* list that was passed in.
        """
        with os.scandir(input_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(".xml"):
                    xmls.append(entry.path)
                elif entry.is_dir():
                    self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``.xml`` files below ``self.input_dir``."""
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        """Annotate one XML file and write the result into *save_path*.

        Each ``note`` is scored from its own text and each ``s`` from the
        space-joined text of its ``w`` descendants. A ``seg`` gets the mean
        of its sentences, and the root gets the mean over all notes and all
        segs that contain at least one sentence.

        Args:
            xml_path: Path of the XML file to read.
            save_path: Directory that receives a file with the same basename.
        """
        root = ET.parse(xml_path).getroot()
        ns = self._NS

        document_perplexity = 0.0
        num_segments = 0

        for note in root.findall(".//default:note", ns):
            perplexity = self.compute_perplexity(note.text)
            note.set("perplexity", str(perplexity))
            document_perplexity += perplexity
            num_segments += 1

        for seg in root.findall(".//default:seg", ns):
            segment_perplexity = 0.0
            num_sentences = 0
            for s in seg.findall(".//default:s", ns):
                # A <w> without text has text None, which str.join rejects.
                words = [
                    w.text
                    for w in s.findall(".//default:w", ns)
                    if w.text is not None
                ]
                perplexity = self.compute_perplexity(" ".join(words))
                s.set("perplexity", str(perplexity))
                segment_perplexity += perplexity
                num_sentences += 1
            if num_sentences != 0:
                seg_mean = segment_perplexity / num_sentences
                seg.set("perplexity", str(seg_mean))
                document_perplexity += seg_mean
                num_segments += 1

        if num_segments != 0:
            root.set("perplexity", str(document_perplexity / num_segments))

        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        pretty_xml_str = minidom.parseString(xml_str).toprettyxml(
            indent=" ", newl="\n"
        )
        # Drop the whitespace-only lines left by pretty-printing. Join with
        # "\n" because the text-mode file below already translates newlines
        # to the platform convention; joining with os.linesep would be
        # translated a second time on Windows.
        pretty_xml_str = "\n".join(
            line for line in pretty_xml_str.splitlines() if line.strip()
        )

        target = os.path.join(save_path, os.path.basename(xml_path))
        with open(target, "w", encoding="utf-8") as xml_file:
            xml_file.write(pretty_xml_str)

    def main(self) -> None:
        """Load the model and annotate every XML file not yet in the output.

        Outputs are named by basename only, so an input whose basename
        already exists in ``self.output_dir`` is skipped. This covers files
        from an earlier run as well as same-named files found in different
        input subdirectories.
        """
        self.prepare_model()
        xml_paths = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        for xml_path in tqdm(xml_paths):
            target = os.path.join(self.output_dir, os.path.basename(xml_path))
            if os.path.exists(target):
                continue
            self.process_xml_file(xml_path, self.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process XML files.")
    parser.add_argument(
        "--input_dir", type=str, help="Input directory containing XML files."
    )
    # Required: without it the run fails in os.makedirs only after the model
    # has already been loaded.
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    args = parser.parse_args()

    processor = XMLProcessor(args)
    processor.main()