xml-ocr/rate.py

151 lines
4.6 KiB
Python

import os
import xml.etree.ElementTree as ET
import argparse
from typing import List
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
class Rate:
    """Collect per-document perplexity scores from XML files and plot them.

    The workflow (see ``main``) is: find every ``.xml`` file below
    ``input_dir``, read the ``perplexity`` attribute of each file's root
    element, sort the documents by that score, dump the mapping to
    ``docs.json`` and render four summary plots into ``plots_dir``.
    """

    # Score recorded for a document whose root element carries no
    # ``perplexity`` attribute.
    MISSING_PERPLEXITY = -100

    def __init__(self, args: argparse.Namespace):
        """Store the parsed options and make sure the plot directory exists.

        Args:
            args: Parsed command-line options; must provide ``input_dir``,
                ``output_dir`` and ``cpu``.
        """
        self.input_dir: str = args.input_dir
        # Stored for callers; no method of this class reads these two.
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        # Maps XML file path -> perplexity score.
        self.docs: dict = {}
        self.plots_dir: str = "./plots"
        os.makedirs(self.plots_dir, exist_ok=True)

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``.xml`` file below ``input_dir`` to ``xmls``.

        Args:
            input_dir: Directory to scan.
            xmls: Accumulator list; it is extended in place.

        Returns:
            The same ``xmls`` list that was passed in.
        """
        # The context manager closes the directory handle deterministically,
        # even if an error interrupts the scan.
        with os.scandir(input_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(".xml"):
                    xmls.append(entry.path)
                elif entry.is_dir():
                    self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``.xml`` files below ``self.input_dir``."""
        return self.find_xmls_recursively(self.input_dir, [])

    def get_doc_perplexity(self, filepath: str) -> float:
        """Read the ``perplexity`` attribute from the root element of an XML file.

        Args:
            filepath: Path of the XML document to parse.

        Returns:
            The attribute converted to ``float``, or ``MISSING_PERPLEXITY``
            (-100) when the root element has no such attribute.

        Raises:
            xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
            ValueError: If the attribute is present but not a valid number.
        """
        root = ET.parse(filepath).getroot()
        raw_value = root.get("perplexity")
        if raw_value is None:
            return self.MISSING_PERPLEXITY
        return float(raw_value)

    def parse_docs(self, xmls: List[str]) -> None:
        """Fill ``self.docs`` with ``{path: perplexity}`` for every given file.

        Args:
            xmls: Paths of the XML documents to read.
        """
        # tqdm only wraps the iterable to display a progress bar.
        self.docs = {xml: self.get_doc_perplexity(xml) for xml in tqdm(xmls)}

    def sort_docs(self) -> None:
        """Re-order ``self.docs`` by ascending perplexity."""
        self.docs = dict(sorted(self.docs.items(), key=lambda item: item[1]))

    def save_docs(self) -> None:
        """Write ``self.docs`` as JSON to ``docs.json`` inside ``self.plots_dir``."""
        target = os.path.join(self.plots_dir, "docs.json")
        with open(target, "w", encoding="utf-8") as f:
            json.dump(self.docs, f)

    def _save_and_close(self, fig, filename: str) -> None:
        """Lay out the current figure, save it under ``plots_dir`` and close it.

        Only the given figure is closed. A following ``plt.clf()`` would
        create a new empty figure and leave it open, so it is not called.
        """
        plt.tight_layout()
        plt.savefig(os.path.join(self.plots_dir, filename))
        plt.close(fig)

    def histogram_of_perplexities(self) -> None:
        """Save a 20-bin histogram of the perplexity scores.

        Does nothing when ``self.docs`` is empty.
        """
        if not self.docs:
            return
        values = list(self.docs.values())
        fig = plt.figure(figsize=(10, 6))
        plt.hist(values, bins=20, edgecolor="black")
        plt.xlabel("Perplexity")
        plt.ylabel("Number of Documents")
        plt.title("Histogram of Document Perplexities")
        self._save_and_close(fig, "histogram_of_perplexities.png")

    def cumulative_distribution_of_perplexities(self) -> None:
        """Save a cumulative 20-bin histogram of the perplexity scores.

        Does nothing when ``self.docs`` is empty.
        """
        if not self.docs:
            return
        values = list(self.docs.values())
        fig = plt.figure(figsize=(10, 6))
        plt.hist(values, bins=20, cumulative=True, edgecolor="black")
        plt.xlabel("Perplexity")
        plt.ylabel("Cumulative Number of Documents")
        plt.title("Cumulative Distribution of Document Perplexities")
        self._save_and_close(fig, "cumulative_distribution_of_perplexities.png")

    def boxplot_of_perplexities(self) -> None:
        """Save a box plot of the perplexity scores.

        Does nothing when ``self.docs`` is empty.
        """
        if not self.docs:
            return
        values = list(self.docs.values())
        fig = plt.figure(figsize=(10, 6))
        plt.boxplot(values)
        plt.ylabel("Perplexity")
        plt.title("Boxplot of Document Perplexities")
        self._save_and_close(fig, "boxplot_of_perplexities.png")

    def barplot_of_perplexities(self) -> None:
        """Save a bar plot with one bar per document, labelled by file path.

        Bars appear in the current order of ``self.docs``. Does nothing when
        ``self.docs`` is empty.
        """
        if not self.docs:
            return
        labels = list(self.docs.keys())
        values = list(self.docs.values())
        indexes = np.arange(len(labels))
        fig = plt.figure(figsize=(20, 10))
        plt.bar(indexes, values, align="center")
        plt.xticks(indexes, labels, rotation="vertical")
        plt.ylabel("Perplexity")
        plt.xlabel("Document")
        plt.title("Perplexity of documents")
        self._save_and_close(fig, "barplot_of_perplexities.png")

    def pltt(self) -> None:
        """Render all four perplexity plots into ``self.plots_dir``."""
        self.histogram_of_perplexities()
        self.cumulative_distribution_of_perplexities()
        self.boxplot_of_perplexities()
        self.barplot_of_perplexities()

    def main(self):
        """Run the full pipeline: discover, parse, sort, save, plot, then exit.

        Raises:
            SystemExit: Always, with exit status 0, once the pipeline is done.
        """
        xmls = self.find_xmls()
        self.parse_docs(xmls)
        self.sort_docs()
        self.save_docs()
        self.pltt()
        # The ``exit`` builtin is injected by ``site`` for interactive use and
        # may be missing; raising SystemExit directly has the same effect.
        raise SystemExit(0)
if __name__ == "__main__":
    # Script entry point: build the command-line interface, parse the
    # arguments and hand them to Rate, whose main() runs the pipeline.
    cli = argparse.ArgumentParser(description="Process XML files.")
    # The two directory options share the same shape, so declare them together.
    for flag, help_text in (
        ("--input_dir", "Input directory containing XML files."),
        ("--output_dir", "Output directory to save processed files."),
    ):
        cli.add_argument(flag, type=str, help=help_text)
    cli.add_argument("--cpu", action="store_true", help="Force usage of CPU.")
    Rate(cli.parse_args()).main()