Add the docs parsing, basically prediciton of new docs
parent
568ffd9205
commit
797ba0afea
4
ft.py
4
ft.py
|
@ -17,9 +17,7 @@ class FT:
|
|||
|
||||
# load tokenizer and model
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
self.model_name
|
||||
)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
||||
|
||||
self.model.to(self.device)
|
||||
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
from typing import List
|
||||
import warnings
|
||||
import re
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AdamW
|
||||
from xml.etree.ElementTree import Element
|
||||
|
||||
|
||||
class XMLProcessor:
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.input_dir: str = args.input_dir
|
||||
self.output_dir: str = args.output_dir
|
||||
self.use_cpu: bool = args.cpu
|
||||
self.checkpoint_path: str = args.checkpoint_path
|
||||
self.model: AutoModelForSeq2SeqLM = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.device: torch.device = None
|
||||
self.documents: dict = {}
|
||||
|
||||
def prepare_model(self) -> None:
|
||||
model_name: str = "facebook/bart-base"
|
||||
self.device = (
|
||||
torch.device("cpu")
|
||||
if self.use_cpu
|
||||
else torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
||||
|
||||
self.model.to(self.device)
|
||||
|
||||
if self.checkpoint_path:
|
||||
if os.path.exists(self.checkpoint_path):
|
||||
print("Loading checkpoint from", self.checkpoint_path)
|
||||
cp = torch.load(self.checkpoint_path)
|
||||
self.model.load_state_dict(cp["model_state_dict"])
|
||||
|
||||
def predict(self, text: str) -> str:
|
||||
input_ids = self.tokenizer(text, return_tensors="pt").input_ids
|
||||
input_ids = input_ids.to(self.device)
|
||||
outputs = self.model.generate(input_ids, max_new_tokens=100)
|
||||
return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
|
||||
def contains_number(self, string: str) -> bool:
|
||||
pattern = r"\d+"
|
||||
match = re.search(pattern, string)
|
||||
# Return True if a number is found, otherwise False
|
||||
return bool(match)
|
||||
|
||||
def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
|
||||
for entry in os.scandir(input_dir):
|
||||
if entry.is_file() and entry.name.endswith(".xml"):
|
||||
xmls.append(entry.path)
|
||||
elif entry.is_dir():
|
||||
self.find_xmls_recursively(entry.path, xmls)
|
||||
return xmls
|
||||
|
||||
def find_xmls(self) -> List[str]:
|
||||
return self.find_xmls_recursively(self.input_dir, [])
|
||||
|
||||
def process_xml_file(self, xml_path: str, save_path: str) -> None:
|
||||
tree = ET.parse(xml_path)
|
||||
root = tree.getroot()
|
||||
ns = {"default": "http://www.tei-c.org/ns/1.0"}
|
||||
for note in root.findall(".//default:note", ns):
|
||||
sentence = note.text
|
||||
predicted_sentence = self.predict(sentence)
|
||||
note.text = predicted_sentence
|
||||
|
||||
for seg in root.findall(".//default:seg", ns):
|
||||
for s in seg.findall(".//default:s", ns):
|
||||
sentence_parts = []
|
||||
|
||||
for child in s:
|
||||
if child.tag == "{" + ns["default"] + "}w":
|
||||
sentence_parts.append(child.text)
|
||||
elif child.tag == "{" + ns["default"] + "}pc":
|
||||
if child.text in {".", ",", ";", ":"}:
|
||||
if sentence_parts:
|
||||
sentence_parts[-1] = sentence_parts[-1] + child.text
|
||||
else:
|
||||
sentence_parts.append(child.text)
|
||||
else:
|
||||
sentence_parts.append(child.text + " ")
|
||||
|
||||
sentence = " ".join(sentence_parts)
|
||||
predicted_sentence = self.predict(sentence)
|
||||
|
||||
s.clear()
|
||||
tokens = re.findall(r"\w+|\S", sentence)
|
||||
for token in tokens:
|
||||
if token in {".", ",", ";", ":"}:
|
||||
pc = Element("{" + ns["default"] + "}pc")
|
||||
pc.text = token
|
||||
s.append(pc)
|
||||
else:
|
||||
w = Element("{" + ns["default"] + "}w")
|
||||
w.text = token
|
||||
s.append(w)
|
||||
|
||||
xml_str = ET.tostring(root, encoding="unicode", method="xml")
|
||||
dom = minidom.parseString(xml_str)
|
||||
pretty_xml_str = dom.toprettyxml(indent=" ", newl="\n")
|
||||
pretty_xml_str = os.linesep.join(
|
||||
[s for s in pretty_xml_str.splitlines() if s.strip()]
|
||||
)
|
||||
with open(
|
||||
save_path + "/" + os.path.basename(xml_path), "w", encoding="utf-8"
|
||||
) as xml_file:
|
||||
xml_file.write(pretty_xml_str)
|
||||
|
||||
def main(self) -> None:
|
||||
self.prepare_model()
|
||||
xmls = self.find_xmls()
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
for xml in tqdm(xmls):
|
||||
if os.path.exists(self.output_dir + "/" + os.path.basename(xml)):
|
||||
continue
|
||||
self.process_xml_file(xml, self.output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Process XML files.")
|
||||
parser.add_argument(
|
||||
"--input_dir", type=str, help="Input directory containing XML files."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, help="Output directory to save processed files."
|
||||
)
|
||||
parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
type=str,
|
||||
help="Path for the checkpoint of weights to use.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
processor = XMLProcessor(args)
|
||||
processor.main()
|
Loading…
Reference in New Issue