fine-tune/generate_docs.py

146 lines
5.4 KiB
Python

import os
import xml.etree.ElementTree as ET
from xml.dom import minidom
import torch
from tqdm import tqdm
import argparse
from typing import List
import warnings
import re
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW
from xml.etree.ElementTree import Element
class XMLProcessor:
    """Rewrite TEI XML files with a BART sequence-to-sequence model.

    For every ``*.xml`` file found under ``input_dir``, the leading text of each
    ``<note>`` and each ``<s>`` sentence inside a ``<seg>`` is passed through the
    model. The predictions replace the originals, and the file is written,
    pretty-printed, to ``output_dir`` under its original base name.
    """

    # Namespace of every element this class looks up or creates.
    TEI_NS = "http://www.tei-c.org/ns/1.0"
    _W_TAG = "{" + TEI_NS + "}w"
    _PC_TAG = "{" + TEI_NS + "}pc"
    # Punctuation glued to the preceding word when building the model input.
    ATTACHED_PUNCTUATION = frozenset({".", ",", ";", ":"})
    # Splits model output into words and single non-space, non-word characters.
    _TOKEN_RE = re.compile(r"\w+|\S")
    _WORD_RE = re.compile(r"\w+")

    def __init__(self, args: argparse.Namespace):
        """Store the run configuration; the model is loaded by ``prepare_model``.

        Args:
            args: parsed CLI namespace; ``input_dir``, ``output_dir``, ``cpu``
                and ``checkpoint_path`` are read from it.
        """
        self.input_dir: str = args.input_dir
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        self.checkpoint_path: str = args.checkpoint_path
        self.model: AutoModelForSeq2SeqLM = None
        self.tokenizer: AutoTokenizer = None
        self.device: torch.device = None
        # Not read anywhere within this class.
        self.documents: dict = {}

    def prepare_model(self, model_name: str = "facebook/bart-base") -> None:
        """Load tokenizer and model, pick the device, and apply the checkpoint.

        Args:
            model_name: Hugging Face model id used for the tokenizer and the
                base weights; ``checkpoint_path`` must match this architecture.
        """
        if self.use_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        if not self.checkpoint_path:
            return
        if not os.path.exists(self.checkpoint_path):
            # Otherwise the run would silently use the untuned base weights.
            warnings.warn(
                f"Checkpoint {self.checkpoint_path!r} not found; "
                f"using the base {model_name} weights.",
                stacklevel=2,
            )
            return
        print("Loading checkpoint from", self.checkpoint_path)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        # Inference only: keep dropout disabled.
        self.model.eval()

    def predict(self, text: str, max_new_tokens: int = 100) -> str:
        """Return the model's rewrite of *text*.

        Args:
            text: input string for the tokenizer.
            max_new_tokens: upper bound on the number of generated tokens.

        Returns:
            The first generated sequence, decoded without special tokens.
        """
        # Cap over-long inputs at the tokenizer's model_max_length rather than
        # exceeding the model's position-embedding limit.
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True)
        input_ids = encoded.input_ids.to(self.device)
        outputs = self.model.generate(input_ids, max_new_tokens=max_new_tokens)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def contains_number(self, string: str) -> bool:
        """Return True if *string* contains at least one decimal digit."""
        return re.search(r"\d", string) is not None

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``*.xml`` file under *input_dir* to *xmls*.

        Entries are visited in sorted name order, so the processing order is
        reproducible, and sub-directories are descended depth-first.

        Returns:
            The same *xmls* list.
        """
        # Close the directory handle before recursing so deep trees do not
        # keep one handle open per level.
        with os.scandir(input_dir) as it:
            entries = sorted(it, key=lambda entry: entry.name)
        for entry in entries:
            if entry.is_file() and entry.name.endswith(".xml"):
                xmls.append(entry.path)
            elif entry.is_dir():
                self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``*.xml`` files under ``input_dir``."""
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
        """Rewrite the notes and sentences of one TEI file and save the result.

        Only the leading text of each ``<note>`` is rewritten; its child
        elements are left untouched. Empty or whitespace-only notes and
        sentences are skipped.

        Args:
            xml_path: TEI XML file to read.
            save_path: directory that receives a file with the same base name.

        Raises:
            xml.etree.ElementTree.ParseError: if *xml_path* is not well-formed.
        """
        # Without this, ElementTree prefixes every tag with a generated "ns0:".
        # Note: this updates ElementTree's process-wide prefix registry.
        ET.register_namespace("", self.TEI_NS)
        root = ET.parse(xml_path).getroot()
        ns = {"default": self.TEI_NS}
        for note in root.findall(".//default:note", ns):
            if (note.text or "").strip():
                note.text = self.predict(note.text)
        for seg in root.findall(".//default:seg", ns):
            for s in seg.findall(".//default:s", ns):
                self._rewrite_sentence(s)
        self._write_pretty(root, os.path.join(save_path, os.path.basename(xml_path)))

    def _sentence_text(self, s: Element) -> str:
        """Build the plain-text model input from the ``<w>``/``<pc>`` children of *s*."""
        parts: List[str] = []
        for child in s:
            # Other children (and empty <w>/<pc>) contribute nothing here; note
            # that _rewrite_sentence discards them when it rebuilds the sentence.
            text = "".join(child.itertext())
            if not text:
                continue
            if child.tag == self._W_TAG:
                parts.append(text)
            elif child.tag == self._PC_TAG:
                if text in self.ATTACHED_PUNCTUATION and parts:
                    # "word," rather than "word ,".
                    parts[-1] += text
                elif text in self.ATTACHED_PUNCTUATION:
                    parts.append(text)
                else:
                    # NOTE(review): with the " ".join below this leaves a double
                    # space after other punctuation; kept unchanged in case the
                    # fine-tuning data was built the same way - confirm.
                    parts.append(text + " ")
        return " ".join(parts)

    def _rewrite_sentence(self, s: Element) -> None:
        """Replace the children of *s* with ``<w>``/``<pc>`` tokens of the prediction."""
        sentence = self._sentence_text(s)
        if not sentence.strip():
            return  # nothing to send to the model; leave the element as is
        predicted_sentence = self.predict(sentence)
        # Element.clear() also drops attributes (e.g. xml:id) and the tail.
        attrib = dict(s.attrib)
        tail = s.tail
        s.clear()
        s.attrib.update(attrib)
        s.tail = tail
        for token in self._TOKEN_RE.findall(predicted_sentence):
            # Words become <w>; every single non-word character becomes <pc>.
            tag = self._W_TAG if self._WORD_RE.fullmatch(token) else self._PC_TAG
            ET.SubElement(s, tag).text = token

    def _write_pretty(self, root: Element, out_path: str) -> None:
        """Write *root* to *out_path* as pretty-printed UTF-8 without blank lines."""
        xml_str = ET.tostring(root, encoding="unicode", method="xml")
        pretty_xml_str = minidom.parseString(xml_str).toprettyxml(indent=" ", newl="\n")
        # toprettyxml also re-indents whitespace already in the document,
        # leaving whitespace-only lines behind.
        lines = [line for line in pretty_xml_str.splitlines() if line.strip()]
        # Join with "\n": text mode already translates it to the platform line
        # ending (joining with os.linesep would give "\r\r\n" on Windows).
        with open(out_path, "w", encoding="utf-8") as xml_file:
            xml_file.write("\n".join(lines))

    def main(self) -> None:
        """Load the model and process every input XML not yet in ``output_dir``."""
        self.prepare_model()
        xmls = self.find_xmls()
        os.makedirs(self.output_dir, exist_ok=True)
        for xml in tqdm(xmls):
            # Resume support: files already written are skipped.
            # NOTE(review): outputs are keyed by base name only, so inputs with
            # the same name in different sub-directories map to one output and
            # only the first one processed is written.
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process XML files.")
    # XMLProcessor.main() needs both directories. Marking them required makes
    # argparse report a missing one up front, instead of the run failing (or
    # scanning an unintended directory) after the model has been loaded.
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Input directory containing XML files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory to save processed files.",
    )
    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        help="Path for the checkpoint of weights to use.",
    )
    args = parser.parse_args()
    processor = XMLProcessor(args)
    processor.main()