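"""Annotate TEI XML files with XLM-RoBERTa perplexity scores.

Recursively finds every .xml file under --input_dir, scores each <note> and
sentence (<s>) element, writes per-element, per-segment, and per-document
averages into "perplexity" attributes, and saves pretty-printed copies to
--output_dir.
"""
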
import argparse
import os
import warnings
import xml.etree.ElementTree as ET
from typing import List, Optional
from xml.dom import minidom

import torch
from tqdm import tqdm
from transformers import XLMRobertaForMaskedLM, XLMRobertaTokenizer


class XMLProcessor:
    def __init__(self, args: argparse.Namespace):
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        # Model components are created lazily in prepare_model().
        self.model: Optional[XLMRobertaForMaskedLM] = None
        self.tokenizer: Optional[XLMRobertaTokenizer] = None
        self.device: Optional[torch.device] = None

    def prepare_model(self) -> None:
        model_name: str = "xlm-roberta-base"
        self.model = XLMRobertaForMaskedLM.from_pretrained(model_name)
        self.device = (
            torch.device("cpu")
            if self.use_cpu
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only; disables dropout
        self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)

    def compute_perplexity(self, text: str) -> float:
        """Return the mean exp-loss of `text`, scored in ~512-character chunks."""
        # 512 is XLM-RoBERTa's maximum sequence length in tokens; here it also
        # serves as a character budget for chunking, with truncation as a
        # safety net when a chunk still tokenizes to more than 512 tokens.
        max_length = 512

        total_perplexity = 0.0
        total_count = 0

        if not text:
            warnings.warn("Empty text")
            return 0.0

        while len(text) > max_length:
            # Split on the last space within the budget to avoid cutting words.
            split_index = text[:max_length].rfind(" ")
            if split_index == -1:  # no space to split on, just truncate
                split_index = max_length

            sentence = text[:split_index].strip()
            text = text[split_index:].strip()

            inputs = self.tokenizer(
                sentence, truncation=True, max_length=max_length, return_tensors="pt"
            )
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
            with torch.no_grad():  # inference only; no gradients needed
                outputs = self.model(**inputs, labels=inputs["input_ids"])

            # exp(mean cross-entropy) is taken as this chunk's perplexity.
            total_perplexity += torch.exp(outputs.loss).item()
            total_count += 1

        # Process the remaining text
        if len(text) > 0:
            inputs = self.tokenizer(
                text, truncation=True, max_length=max_length, return_tensors="pt"
            )
            inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
            with torch.no_grad():
                outputs = self.model(**inputs, labels=inputs["input_ids"])

            total_perplexity += torch.exp(outputs.loss).item()
            total_count += 1

        return total_perplexity / total_count if total_count > 0 else 0.0

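    # Note on the quantity above: the labels are the unmasked input ids, so
    # outputs.loss is the MLM cross-entropy computed while the model can see
    # every token it predicts. exp(loss) is therefore a rough fluency proxy
    # rather than a true perplexity, and each chunk contributes equally to
    # the mean regardless of its length.
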
    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        for entry in os.scandir(input_dir):
            if entry.is_file() and entry.name.endswith(".xml"):
                xmls.append(entry.path)
            elif entry.is_dir():
                self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        # Keep TEI as the default namespace so ElementTree does not emit
        # auto-generated "ns0:" prefixes when serializing.
        ET.register_namespace("", "http://www.tei-c.org/ns/1.0")
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {"default": "http://www.tei-c.org/ns/1.0"}
        document_perplexity = 0.0
        num_segments = 0

        # Score free-form <note> elements as single units.
        for note in root.findall(".//default:note", ns):
            perplexity = self.compute_perplexity(note.text)
            note.set("perplexity", str(perplexity))
            document_perplexity += perplexity
            num_segments += 1

        # Score each sentence (<s>) in a segment (<seg>) from its <w> tokens.
        for seg in root.findall(".//default:seg", ns):
            segment_perplexity = 0.0
            num_sentences = 0
            for s in seg.findall(".//default:s", ns):
                # Skip <w> elements with no text to avoid joining None values.
                sentence = " ".join(
                    [w.text for w in s.findall(".//default:w", ns) if w.text]
                )
                perplexity = self.compute_perplexity(sentence)
                s.set("perplexity", str(perplexity))
                segment_perplexity += perplexity
                num_sentences += 1

            if num_sentences != 0:
                seg.set("perplexity", str(segment_perplexity / num_sentences))
                document_perplexity += segment_perplexity / num_sentences
                num_segments += 1

        if num_segments != 0:
            root.set("perplexity", str(document_perplexity / num_segments))

        # Pretty-print, dropping the blank lines minidom inserts.
        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        dom = minidom.parseString(xml_str)
        pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n")
        pretty_xml_str = os.linesep.join(
            [s for s in pretty_xml_str.splitlines() if s.strip()]
        )
        with open(
            os.path.join(save_path, os.path.basename(xml_path)), "w", encoding="utf-8"
        ) as xml_file:
            xml_file.write(pretty_xml_str)

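    # Minimal sketch of the TEI input the method above expects (element names
    # are taken from the queries; the content is illustrative):
    #   <TEI xmlns="http://www.tei-c.org/ns/1.0">
    #     <note>free-form text scored as one unit</note>
    #     <seg>
    #       <s><w>One</w><w>token</w><w>each</w></s>
    #     </seg>
    #   </TEI>
    # Each <note> and <s> gains a "perplexity" attribute; <seg> and the root
    # element receive the corresponding averages.
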
    def main(self) -> None:
        self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)

        for xml in tqdm(xmls):
            # Skip files already processed in a previous run, so interrupted
            # jobs can be resumed.
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process XML files.")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Input directory containing XML files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    args = parser.parse_args()

    processor = XMLProcessor(args)
    processor.main()
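
# Example invocation (the script file name here is hypothetical):
#   python score_tei_perplexity.py --input_dir corpus/ --output_dir scored/ --cpu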