"""Rewrite the text of TEI XML documents with a seq2seq model (facebook/bart-base).

The text of every ``<note>``, and every ``<s>`` sentence found inside a ``<seg>``,
is passed through the model and the prediction is written back into the tree.
Processed documents are pretty-printed into an output directory.
"""

import argparse
import os
import re
import warnings
import xml.etree.ElementTree as ET
from typing import List, Optional
from xml.dom import minidom
from xml.etree.ElementTree import Element

import torch
from tqdm import tqdm

# NOTE(review): AdamW is not used in this file, and ``transformers.AdamW`` is
# deprecated (removed in recent transformers releases) — drop it from this
# import if it raises ImportError with your pinned transformers version.
from transformers import AdamW, AutoModelForSeq2SeqLM, AutoTokenizer

TEI_NS = "http://www.tei-c.org/ns/1.0"
# Punctuation glued to the preceding word when a sentence is flattened for the
# model, and emitted as <pc> (rather than <w>) when the sentence is rebuilt.
ATTACHED_PUNCTUATION = frozenset({".", ",", ";", ":"})
# Splits model output into runs of word characters and single non-space symbols.
TOKEN_PATTERN = re.compile(r"\w+|\S")


class XMLProcessor:
    """Run the notes and sentences of every TEI file in a directory tree through a seq2seq model."""

    def __init__(self, args: argparse.Namespace):
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        self.checkpoint_path: Optional[str] = args.checkpoint_path
        # The three attributes below are populated by prepare_model().
        self.model: Optional[AutoModelForSeq2SeqLM] = None
        self.tokenizer: Optional[AutoTokenizer] = None
        self.device: Optional[torch.device] = None
        self.documents: dict = {}  # NOTE(review): not read or written anywhere else in this file

    def prepare_model(self) -> None:
        """Load tokenizer and model, move the model to the device, and apply the checkpoint.

        The CPU is used when ``--cpu`` was given or CUDA is unavailable. If
        ``checkpoint_path`` is set and exists, the checkpoint's
        ``"model_state_dict"`` entry replaces the pretrained weights. If the path
        does not exist, a warning is emitted and the pretrained weights are kept.
        """
        model_name: str = "facebook/bart-base"
        if self.use_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)

        if self.checkpoint_path:
            if os.path.exists(self.checkpoint_path):
                print("Loading checkpoint from", self.checkpoint_path)
                # map_location lets a checkpoint saved on GPU load on a CPU-only host.
                # torch.load unpickles the file: only load checkpoints you trust.
                cp = torch.load(self.checkpoint_path, map_location=self.device)
                self.model.load_state_dict(cp["model_state_dict"])
            else:
                warnings.warn(
                    f"Checkpoint not found, using pretrained weights: {self.checkpoint_path}"
                )
        # The model is only used for generation; make sure dropout is disabled.
        self.model.eval()

    def predict(self, text: str, max_new_tokens: int = 100) -> str:
        """Return the model's generated rewrite of ``text``.

        Inputs longer than the tokenizer's ``model_max_length`` are truncated
        instead of raising. At most ``max_new_tokens`` tokens are generated.
        """
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = encoded.input_ids.to(self.device)
        attention_mask = encoded.attention_mask.to(self.device)
        outputs = self.model.generate(
            input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def contains_number(self, string: str) -> bool:
        """Return True if ``string`` contains at least one digit (``\\d``, Unicode-aware)."""
        return re.search(r"\d+", string) is not None

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``.xml`` file under ``input_dir`` to ``xmls``; return ``xmls``.

        Subdirectories are descended recursively. The file-name check is case-sensitive.
        """
        for entry in os.scandir(input_dir):
            if entry.is_file() and entry.name.endswith(".xml"):
                xmls.append(entry.path)
            elif entry.is_dir():
                self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``.xml`` files under ``input_dir``."""
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        """Rewrite one TEI file with the model and save it as ``save_path/<basename>``.

        Every ``<note>``'s leading text is replaced by the model's prediction.
        Every ``<s>`` under a ``<seg>`` is flattened to text, rewritten by the
        model, and its children are rebuilt as ``<w>``/``<pc>`` tokens of the
        prediction.
        """
        # Serialize TEI elements under a default xmlns instead of "ns0:" prefixes.
        ET.register_namespace("", TEI_NS)
        root = ET.parse(xml_path).getroot()
        ns = {"default": TEI_NS}

        self._rewrite_notes(root, ns)
        self._rewrite_sentences(root, ns)
        self._write_pretty(root, os.path.join(save_path, os.path.basename(xml_path)))

    def _rewrite_notes(self, root: Element, ns: dict) -> None:
        """Replace the leading text of each <note> with the model's prediction."""
        for note in root.findall(".//default:note", ns):
            # Only note.text is rewritten; text inside child elements is untouched.
            # Empty/whitespace-only notes are skipped instead of sending None/"" to the model.
            if note.text and note.text.strip():
                note.text = self.predict(note.text)

    def _rewrite_sentences(self, root: Element, ns: dict) -> None:
        """Run every <s> found inside a <seg> through the model and rebuild it from the output."""
        for seg in root.findall(".//default:seg", ns):
            for s in seg.findall(".//default:s", ns):
                sentence = self._sentence_text(s, ns)
                if not sentence:
                    continue  # no <w>/<pc> text: leave the element untouched
                predicted_sentence = self.predict(sentence)
                # Rebuild from the prediction (not the input) so the model output is kept.
                self._rebuild_sentence(s, predicted_sentence, ns)

    @staticmethod
    def _sentence_text(s: Element, ns: dict) -> str:
        """Flatten the direct <w>/<pc> children of ``s`` into a single space-separated string.

        Punctuation in ATTACHED_PUNCTUATION is glued to the preceding token; any
        other <pc> becomes its own token. Children with no text are skipped.
        """
        w_tag = "{" + ns["default"] + "}w"
        pc_tag = "{" + ns["default"] + "}pc"
        parts: List[str] = []
        for child in s:
            text = (child.text or "").strip()
            if not text:
                continue
            if child.tag == w_tag:
                parts.append(text)
            elif child.tag == pc_tag:
                if text in ATTACHED_PUNCTUATION and parts:
                    parts[-1] += text
                else:
                    parts.append(text)
        return " ".join(parts)

    @staticmethod
    def _rebuild_sentence(s: Element, sentence: str, ns: dict) -> None:
        """Replace the children of ``s`` with <w>/<pc> elements tokenized from ``sentence``.

        The attributes and tail of ``s`` are preserved; Element.clear() alone would
        drop them. Any other children of ``s``, and the attributes of its old
        <w>/<pc> children, are discarded.
        """
        attrib = dict(s.attrib)
        tail = s.tail
        s.clear()
        s.attrib.update(attrib)
        s.tail = tail

        w_tag = "{" + ns["default"] + "}w"
        pc_tag = "{" + ns["default"] + "}pc"
        for token in TOKEN_PATTERN.findall(sentence):
            child = ET.SubElement(s, pc_tag if token in ATTACHED_PUNCTUATION else w_tag)
            child.text = token

    @staticmethod
    def _write_pretty(root: Element, out_path: str) -> None:
        """Serialize ``root``, pretty-print it, and write it to ``out_path`` as UTF-8."""
        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        pretty_xml_str = minidom.parseString(xml_str).toprettyxml(indent=" ", newl="\n")
        # toprettyxml keeps the document's existing whitespace text nodes,
        # which show up as blank lines; drop them.
        lines = [line for line in pretty_xml_str.splitlines() if line.strip()]
        # Join with "\n": text mode already translates it to os.linesep, so
        # joining with os.linesep would produce "\r\r\n" on Windows.
        with open(out_path, "w", encoding="utf-8") as xml_file:
            xml_file.write("\n".join(lines) + "\n")

    def main(self) -> None:
        """Load the model and process every XML under input_dir into output_dir.

        Files whose base name already exists in output_dir are skipped, so a
        rerun only handles the remaining files.
        """
        self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        for xml in tqdm(xmls):
            # NOTE(review): output is flat (base name only), so inputs sharing a base
            # name in different subdirectories collide and only the first is processed.
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process XML files.")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Input directory containing XML files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path for the checkpoint of weights to use.",
    )
    args = parser.parse_args()
    processor = XMLProcessor(args)
    processor.main()