Add the docs parsing, basically prediction of new docs

main
Gašper Spagnolo 2023-08-04 13:20:56 +02:00
parent 568ffd9205
commit 797ba0afea
No known key found for this signature in database
GPG Key ID: 2EA0738CC1EFEEB7
2 changed files with 146 additions and 3 deletions

ft.py (1 addition, 3 deletions)

@@ -17,9 +17,7 @@ class FT:
         # load tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            self.model_name
-        )
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
         self.model.to(self.device)

generate_docs.py (new file, 145 additions)

@@ -0,0 +1,145 @@
import os
import re
import argparse
import xml.etree.ElementTree as ET
from xml.dom import minidom
from xml.etree.ElementTree import Element
from typing import List, Optional

import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class XMLProcessor:
    def __init__(self, args: argparse.Namespace):
self.input_dir: str = args.input_dir
self.output_dir: str = args.output_dir
self.use_cpu: bool = args.cpu
self.checkpoint_path: str = args.checkpoint_path
        self.model: Optional[AutoModelForSeq2SeqLM] = None
        self.tokenizer: Optional[AutoTokenizer] = None
        self.device: Optional[torch.device] = None
        self.documents: dict = {}

    def prepare_model(self) -> None:
model_name: str = "facebook/bart-base"
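        # --cpu forces CPU; otherwise use CUDA when available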
self.device = (
torch.device("cpu")
if self.use_cpu
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()  # inference only: disable dropout for deterministic outputs
        if self.checkpoint_path:
            if os.path.exists(self.checkpoint_path):
                print("Loading checkpoint from", self.checkpoint_path)
                # map_location lets a GPU-trained checkpoint load on CPU-only runs
                cp = torch.load(self.checkpoint_path, map_location=self.device)
                self.model.load_state_dict(cp["model_state_dict"])

    def predict(self, text: str) -> str:
input_ids = self.tokenizer(text, return_tensors="pt").input_ids
input_ids = input_ids.to(self.device)
outputs = self.model.generate(input_ids, max_new_tokens=100)
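        # decode the single generated sequence, dropping BOS/EOS and padding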
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def contains_number(self, string: str) -> bool:
pattern = r"\d+"
match = re.search(pattern, string)
# Return True if a number is found, otherwise False
        return bool(match)

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
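        # collect .xml paths, descending into subdirectories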
for entry in os.scandir(input_dir):
if entry.is_file() and entry.name.endswith(".xml"):
xmls.append(entry.path)
elif entry.is_dir():
self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        return self.find_xmls_recursively(self.input_dir, [])

    def process_xml_file(self, xml_path: str, save_path: str) -> None:
tree = ET.parse(xml_path)
root = tree.getroot()
ns = {"default": "http://www.tei-c.org/ns/1.0"}
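        # first pass: run prediction over free-text <note> elements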
        for note in root.findall(".//default:note", ns):
            sentence = note.text
            if not sentence:
                continue  # skip empty notes
            note.text = self.predict(sentence)
for seg in root.findall(".//default:seg", ns):
for s in seg.findall(".//default:s", ns):
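                # rebuild the plain-text sentence from <w> (word) and <pc> (punctuation) children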
sentence_parts = []
                for child in s:
                    if child.text is None:
                        continue  # guard against empty elements
                    if child.tag == "{" + ns["default"] + "}w":
                        sentence_parts.append(child.text)
                    elif child.tag == "{" + ns["default"] + "}pc":
                        # attach sentence-internal punctuation to the preceding token
                        if child.text in {".", ",", ";", ":"}:
                            if sentence_parts:
                                sentence_parts[-1] = sentence_parts[-1] + child.text
                            else:
                                sentence_parts.append(child.text)
                        else:
                            sentence_parts.append(child.text)
                sentence = " ".join(sentence_parts)
                predicted_sentence = self.predict(sentence)
                # clear() also drops attributes (e.g. xml:id), so keep and restore them
                attributes = dict(s.attrib)
                s.clear()
                s.attrib.update(attributes)
                # re-tokenize the predicted sentence (not the input) into <w>/<pc> elements
                tokens = re.findall(r"\w+|\S", predicted_sentence)
                for token in tokens:
                    if token in {".", ",", ";", ":"}:
                        pc = Element("{" + ns["default"] + "}pc")
                        pc.text = token
                        s.append(pc)
                    else:
                        w = Element("{" + ns["default"] + "}w")
                        w.text = token
                        s.append(w)
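        # serialize with pretty-printing, then drop the blank lines minidom adds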
xml_str = ET.tostring(root, encoding="unicode", method="xml")
dom = minidom.parseString(xml_str)
pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n")
pretty_xml_str = os.linesep.join(
[s for s in pretty_xml_str.splitlines() if s.strip()]
)
        out_path = os.path.join(save_path, os.path.basename(xml_path))
        with open(out_path, "w", encoding="utf-8") as xml_file:
            xml_file.write(pretty_xml_str)

    def main(self) -> None:
self.prepare_model()
xmls = self.find_xmls()
os.makedirs(self.output_dir, exist_ok=True)
        for xml in tqdm(xmls):
            # skip files already processed in an earlier run
            if os.path.exists(os.path.join(self.output_dir, os.path.basename(xml))):
                continue
            self.process_xml_file(xml, self.output_dir)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process XML files.")
    parser.add_argument(
        "--input_dir", type=str, required=True,
        help="Input directory containing XML files.",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output directory to save processed files.",
    )
parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
parser.add_argument(
"--checkpoint_path",
type=str,
help="Path for the checkpoint of weights to use.",
)
args = parser.parse_args()
processor = XMLProcessor(args)
processor.main()
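
A minimal invocation sketch, assuming the fine-tuned BART checkpoint from ft.py was saved with a "model_state_dict" key (the paths below are hypothetical):

python generate_docs.py --input_dir ./data/tei_xml --output_dir ./data/tei_xml_predicted --checkpoint_path ./checkpoints/ft_bart.pt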