146 lines
5.4 KiB
Python
146 lines
5.4 KiB
Python
|
import os
|
||
|
import xml.etree.ElementTree as ET
|
||
|
from xml.dom import minidom
|
||
|
import torch
|
||
|
from tqdm import tqdm
|
||
|
import argparse
|
||
|
from typing import List
|
||
|
import warnings
|
||
|
import re
|
||
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW
|
||
|
from xml.etree.ElementTree import Element
|
||
|
|
||
|
|
||
|
class XMLProcessor:
    """Rewrite the text of TEI XML documents with a seq2seq model.

    Every ``.xml`` file found (recursively) under ``input_dir`` is parsed; the
    leading text of each TEI ``<note>`` and the tokens of each ``<s>`` inside a
    ``<seg>`` are passed through the model, and the rewritten document is
    pretty-printed into ``output_dir``.
    """

    # TEI P5 namespace; every element this class reads or writes lives in it.
    TEI_NS = "http://www.tei-c.org/ns/1.0"
    _W_TAG = "{" + TEI_NS + "}w"
    _PC_TAG = "{" + TEI_NS + "}pc"
    DEFAULT_MODEL_NAME = "facebook/bart-base"
    # Punctuation glued to the preceding word when building the model input,
    # e.g. <w>foo</w><pc>.</pc> -> "foo." instead of "foo .".
    ATTACHED_PUNCTUATION = frozenset({".", ",", ";", ":"})
    # Splits model output into word runs (-> <w>) and single non-space,
    # non-word characters (-> <pc>).
    _TOKEN_RE = re.compile(r"(?P<w>\w+)|(?P<pc>\S)")

    def __init__(self, args: argparse.Namespace):
        """Store the CLI options; the model is loaded later by prepare_model().

        Args:
            args: parsed CLI namespace providing ``input_dir``, ``output_dir``,
                ``cpu`` and ``checkpoint_path``. An optional ``model_name``
                attribute overrides the default Hugging Face model id.
        """
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        self.checkpoint_path: str = args.checkpoint_path
        self.model_name: str = (
            getattr(args, "model_name", None) or self.DEFAULT_MODEL_NAME
        )
        self.model: AutoModelForSeq2SeqLM = None
        self.tokenizer: AutoTokenizer = None
        self.device: torch.device = None
        self.documents: dict = {}

    def prepare_model(self) -> None:
        """Pick the device and load tokenizer, model and optional checkpoint.

        The checkpoint may be either a dict holding ``"model_state_dict"``
        (the format this script originally expected) or a bare state dict.
        A checkpoint path that does not exist triggers a warning and the
        un-fine-tuned weights are used.
        """
        if self.use_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.model.to(self.device)
        self.model.eval()  # inference only: make sure dropout is disabled

        if not self.checkpoint_path:
            return
        if not os.path.exists(self.checkpoint_path):
            warnings.warn(
                f"Checkpoint {self.checkpoint_path!r} not found; "
                f"using the base {self.model_name} weights instead."
            )
            return

        print("Loading checkpoint from", self.checkpoint_path)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host or with --cpu. SECURITY: torch.load unpickles the file, so only
        # load checkpoints from trusted sources.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)

    def predict(self, text: str, max_new_tokens: int = 100) -> str:
        """Run the model on *text* and return the decoded output.

        Args:
            text: input string for the model.
            max_new_tokens: cap on generated tokens; longer outputs are cut off.

        NOTE(review): inputs are not truncated, so text longer than the
        model's maximum source length will fail inside the model.
        """
        encoded = self.tokenizer(text, return_tensors="pt")
        input_ids = encoded.input_ids.to(self.device)
        # Passing the mask explicitly avoids generate() having to infer it.
        attention_mask = encoded.attention_mask.to(self.device)
        outputs = self.model.generate(
            input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def contains_number(self, string: str) -> bool:
        """Return True if *string* contains at least one decimal digit."""
        return re.search(r"\d", string) is not None

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``*.xml`` file under *input_dir* to *xmls*.

        Subdirectories are searched recursively. Returns the same *xmls* list.
        NOTE(review): ``is_dir()`` follows symlinks, so a symlink cycle would
        recurse without bound.
        """
        with os.scandir(input_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(".xml"):
                    xmls.append(entry.path)
                elif entry.is_dir():
                    self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return all XML files under ``input_dir``, sorted for a stable order."""
        return sorted(self.find_xmls_recursively(self.input_dir, []))

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        """Rewrite one TEI file with the model and save it into *save_path*.

        The output keeps the input's file name. Notes without leading text and
        ``<s>`` elements without any token text are left untouched.

        Raises:
            xml.etree.ElementTree.ParseError: if *xml_path* is not well-formed.
        """
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {"default": self.TEI_NS}

        for note in root.findall(".//default:note", ns):
            # note.text is only the text before the first child element; it
            # is None for empty notes and for notes that start with markup.
            if note.text is None or not note.text.strip():
                continue
            note.text = self.predict(note.text)

        # dict.fromkeys de-duplicates (by identity) <s> elements reachable
        # through nested <seg>s so that no sentence is rewritten twice, while
        # keeping document order.
        sentences = dict.fromkeys(
            s
            for seg in root.iterfind(".//default:seg", ns)
            for s in seg.iterfind(".//default:s", ns)
        )
        for s in sentences:
            sentence = self._sentence_text(s)
            if not sentence:
                continue
            self._replace_tokens(s, self.predict(sentence))

        self._write_pretty(root, os.path.join(save_path, os.path.basename(xml_path)))

    def _sentence_text(self, s: Element) -> str:
        """Flatten the children of a TEI ``<s>`` into one space-separated string."""
        parts: List[str] = []
        for child in s:
            # itertext() also picks up text nested in markup inside a token and
            # yields nothing (rather than None) for empty elements.
            text = "".join(child.itertext()).strip()
            if not text:
                continue
            if child.tag == self._PC_TAG and text in self.ATTACHED_PUNCTUATION and parts:
                parts[-1] += text
            else:
                parts.append(text)
        return " ".join(parts)

    def _replace_tokens(self, s: Element, text: str) -> None:
        """Replace the children of *s* with ``<w>``/``<pc>`` tokens from *text*.

        Unlike ``Element.clear()``, this keeps the attributes (e.g. ``xml:id``)
        and tail of *s*. Attributes of the old token elements are not carried
        over.
        """
        for child in list(s):
            s.remove(child)
        s.text = None
        for match in self._TOKEN_RE.finditer(text):
            tag = self._W_TAG if match.lastgroup == "w" else self._PC_TAG
            token = ET.SubElement(s, tag)
            token.text = match.group()

    def _write_pretty(self, root: Element, out_path: str) -> None:
        """Pretty-print *root* as UTF-8 to *out_path*, replacing it atomically."""
        # Serialize TEI as the default xmlns instead of ElementTree's generated
        # "ns0:" prefix. NOTE(review): registration is process-global and
        # assumes documents have no un-namespaced elements, which would
        # otherwise be written as if they were in the TEI namespace.
        ET.register_namespace("", self.TEI_NS)
        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        pretty = minidom.parseString(xml_str).toprettyxml(indent=" ", newl="\n")
        # toprettyxml re-indents the whitespace already present in the source,
        # leaving whitespace-only lines behind; drop them. Join with "\n"
        # (not os.linesep): text mode already maps "\n" to the platform
        # newline, and os.linesep would produce "\r\r\n" on Windows.
        pretty = "\n".join(line for line in pretty.splitlines() if line.strip())
        tmp_path = out_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as xml_file:
            xml_file.write(pretty)
        # Publish only complete files, so main()'s "output exists -> skip"
        # resume check never mistakes a half-written file for a finished one.
        os.replace(tmp_path, out_path)

    def main(self) -> None:
        """Load the model and rewrite every XML file not yet in ``output_dir``.

        Outputs mirror the subdirectory layout under ``input_dir``, so files
        with the same name in different folders no longer collide. Malformed
        XML files are skipped with a warning instead of aborting the batch.
        """
        self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        input_root = self.input_dir or "."

        for xml_path in tqdm(xmls):
            rel_dir = os.path.relpath(os.path.dirname(xml_path) or ".", input_root)
            save_dir = os.path.normpath(os.path.join(self.output_dir, rel_dir))
            if os.path.exists(os.path.join(save_dir, os.path.basename(xml_path))):
                continue  # already processed by an earlier run
            os.makedirs(save_dir, exist_ok=True)
            try:
                self.process_xml_file(xml_path, save_dir)
            except ET.ParseError as err:
                warnings.warn(f"Skipping malformed XML file {xml_path}: {err}")
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line entry point: parse options and run the whole batch.
    parser = argparse.ArgumentParser(description="Process XML files.")
    # Defaults to the current directory, which is what an omitted value
    # (None) already meant for os.scandir.
    parser.add_argument(
        "--input_dir",
        type=str,
        default=".",
        help="Input directory containing XML files.",
    )
    # Required: without it the run used to crash in os.makedirs(None), but
    # only after the model had already been loaded.
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path for the checkpoint of weights to use.",
    )
    # Base model the checkpoint was fine-tuned from; the default matches the
    # model name previously hard-coded in XMLProcessor.
    parser.add_argument(
        "--model_name",
        type=str,
        default="facebook/bart-base",
        help="Hugging Face model id the checkpoint was fine-tuned from.",
    )
    args = parser.parse_args()

    processor = XMLProcessor(args)
    processor.main()
|