From 797ba0afea509026a91715508ea71b4a20c1d38b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C5=A1per=20Spagnolo?= Date: Fri, 4 Aug 2023 13:20:56 +0200 Subject: [PATCH] Add the docs parsing, basically prediciton of new docs --- ft.py | 4 +- generate_docs.py | 145 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 generate_docs.py diff --git a/ft.py b/ft.py index 6fc32c2..fcb97b1 100644 --- a/ft.py +++ b/ft.py @@ -17,9 +17,7 @@ class FT: # load tokenizer and model self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - self.model = AutoModelForSeq2SeqLM.from_pretrained( - self.model_name - ) + self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name) self.model.to(self.device) diff --git a/generate_docs.py b/generate_docs.py new file mode 100644 index 0000000..cc53a07 --- /dev/null +++ b/generate_docs.py @@ -0,0 +1,145 @@ +import os +import xml.etree.ElementTree as ET +from xml.dom import minidom +import torch +from tqdm import tqdm +import argparse +from typing import List +import warnings +import re +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW +from xml.etree.ElementTree import Element + + +class XMLProcessor: + def __init__(self, args: argparse.Namespace): + self.input_dir: str = args.input_dir + self.output_dir: str = args.output_dir + self.use_cpu: bool = args.cpu + self.checkpoint_path: str = args.checkpoint_path + self.model: AutoModelForSeq2SeqLM = None + self.tokenizer: AutoTokenizer = None + self.device: torch.device = None + self.documents: dict = {} + + def prepare_model(self) -> None: + model_name: str = "facebook/bart-base" + self.device = ( + torch.device("cpu") + if self.use_cpu + else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + + self.model.to(self.device) + + if self.checkpoint_path: + if os.path.exists(self.checkpoint_path): + print("Loading checkpoint from", self.checkpoint_path) + cp = torch.load(self.checkpoint_path) + self.model.load_state_dict(cp["model_state_dict"]) + + def predict(self, text: str) -> str: + input_ids = self.tokenizer(text, return_tensors="pt").input_ids + input_ids = input_ids.to(self.device) + outputs = self.model.generate(input_ids, max_new_tokens=100) + return self.tokenizer.decode(outputs[0], skip_special_tokens=True) + + def contains_number(self, string: str) -> bool: + pattern = r"\d+" + match = re.search(pattern, string) + # Return True if a number is found, otherwise False + return bool(match) + + def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]: + for entry in os.scandir(input_dir): + if entry.is_file() and entry.name.endswith(".xml"): + xmls.append(entry.path) + elif entry.is_dir(): + self.find_xmls_recursively(entry.path, xmls) + return xmls + + def find_xmls(self) -> List[str]: + return self.find_xmls_recursively(self.input_dir, []) + + def process_xml_file(self, xml_path: str, save_path: str) -> None: + tree = ET.parse(xml_path) + root = tree.getroot() + ns = {"default": "http://www.tei-c.org/ns/1.0"} + for note in root.findall(".//default:note", ns): + sentence = note.text + predicted_sentence = self.predict(sentence) + note.text = predicted_sentence + + for seg in root.findall(".//default:seg", ns): + for s in seg.findall(".//default:s", ns): + sentence_parts = [] + + for child in s: + if child.tag == "{" + ns["default"] + "}w": + sentence_parts.append(child.text) + elif child.tag == "{" + ns["default"] + "}pc": + if child.text in {".", ",", ";", ":"}: + if sentence_parts: + sentence_parts[-1] = sentence_parts[-1] + child.text + else: + sentence_parts.append(child.text) + else: + sentence_parts.append(child.text + " ") + + sentence = " ".join(sentence_parts) + predicted_sentence = self.predict(sentence) + + s.clear() + tokens = re.findall(r"\w+|\S", sentence) + for token in tokens: + if token in {".", ",", ";", ":"}: + pc = Element("{" + ns["default"] + "}pc") + pc.text = token + s.append(pc) + else: + w = Element("{" + ns["default"] + "}w") + w.text = token + s.append(w) + + xml_str = ET.tostring(root, encoding="unicode", method="xml") + dom = minidom.parseString(xml_str) + pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n") + pretty_xml_str = os.linesep.join( + [s for s in pretty_xml_str.splitlines() if s.strip()] + ) + with open( + save_path + "/" + os.path.basename(xml_path), "w", encoding="utf-8" + ) as xml_file: + xml_file.write(pretty_xml_str) + + def main(self) -> None: + self.prepare_model() + xmls = self.find_xmls() + os.makedirs(self.output_dir, exist_ok=True) + + for xml in tqdm(xmls): + if os.path.exists(self.output_dir + "/" + os.path.basename(xml)): + continue + self.process_xml_file(xml, self.output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process XML files.") + parser.add_argument( + "--input_dir", type=str, help="Input directory containing XML files." + ) + parser.add_argument( + "--output_dir", type=str, help="Output directory to save processed files." + ) + parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.") + parser.add_argument( + "--checkpoint_path", + type=str, + help="Path for the checkpoint of weights to use.", + ) + args = parser.parse_args() + + processor = XMLProcessor(args) + processor.main()