151 lines
4.6 KiB
Python
151 lines
4.6 KiB
Python
import os
|
|
import xml.etree.ElementTree as ET
|
|
import argparse
|
|
from typing import List
|
|
import json
|
|
import matplotlib.pyplot as plt
|
|
from tqdm import tqdm
|
|
import numpy as np
|
|
|
|
|
|
class Rate:
    """Collect per-document perplexity values from XML files and plot them.

    The pipeline (see ``main``) walks ``input_dir`` recursively for ``.xml``
    files, reads the ``perplexity`` attribute from each file's root element,
    sorts the documents by that value, writes the mapping to ``docs.json``
    and renders four summary plots. All outputs go to ``plots_dir``.
    """

    # Placeholder stored for documents whose root element has no
    # ``perplexity`` attribute.
    # NOTE(review): this placeholder is plotted together with real values.
    MISSING_PERPLEXITY = -100

    def __init__(self, args: argparse.Namespace):
        """Store the CLI options and make sure the plots directory exists.

        Args:
            args: Parsed arguments providing ``input_dir``, ``output_dir``
                and ``cpu``.
        """
        self.input_dir: str = args.input_dir
        # NOTE(review): output_dir and use_cpu are stored but not read by any
        # method of this class; all files are written to plots_dir.
        self.output_dir: str = args.output_dir
        self.use_cpu: bool = args.cpu
        # Maps XML file path -> perplexity value.
        self.docs: dict = {}
        self.plots_dir: str = "./plots"
        os.makedirs(self.plots_dir, exist_ok=True)

    def find_xmls_recursively(self, input_dir: str, xmls: List[str]) -> List[str]:
        """Append the path of every ``.xml`` file under ``input_dir`` to ``xmls``.

        Args:
            input_dir: Directory to scan; sub-directories are descended into.
            xmls: Accumulator list, extended in place.

        Returns:
            The same ``xmls`` list that was passed in.
        """
        # The context manager releases the directory handle even if an
        # exception interrupts the scan.
        with os.scandir(input_dir) as entries:
            for entry in entries:
                if entry.is_file() and entry.name.endswith(".xml"):
                    xmls.append(entry.path)
                elif entry.is_dir():
                    self.find_xmls_recursively(entry.path, xmls)
        return xmls

    def find_xmls(self) -> List[str]:
        """Return the paths of all ``.xml`` files found under ``self.input_dir``."""
        return self.find_xmls_recursively(self.input_dir, [])

    def get_doc_perplexity(self, filepath: str) -> float:
        """Read the ``perplexity`` attribute of the XML root element.

        Args:
            filepath: Path of the XML file to parse.

        Returns:
            The attribute converted to ``float``, or ``MISSING_PERPLEXITY``
            (-100) when the root element has no such attribute.
        """
        root = ET.parse(filepath).getroot()
        try:
            return float(root.attrib["perplexity"])
        except KeyError:
            return self.MISSING_PERPLEXITY

    def parse_docs(self, xmls: List[str]) -> None:
        """Fill ``self.docs`` with the perplexity of every file in ``xmls``.

        Args:
            xmls: Paths of the XML files to read; progress is shown via tqdm.
        """
        self.docs = {xml: self.get_doc_perplexity(xml) for xml in tqdm(xmls)}

    def sort_docs(self) -> None:
        """Re-order ``self.docs`` by ascending perplexity."""
        self.docs = dict(sorted(self.docs.items(), key=lambda item: item[1]))

    def save_docs(self) -> None:
        """Write ``self.docs`` as JSON to ``docs.json`` inside ``plots_dir``."""
        with open(os.path.join(self.plots_dir, "docs.json"), "w") as f:
            json.dump(self.docs, f)

    def _save_current_figure(self, filename: str) -> None:
        """Lay out, save and close the current figure under ``plots_dir``."""
        plt.tight_layout()
        plt.savefig(os.path.join(self.plots_dir, filename))
        # close() alone is sufficient. A clf() after it would implicitly
        # create a fresh empty figure that nothing ever closes.
        plt.close()

    def histogram_of_perplexities(self) -> None:
        """Save a 20-bin histogram of the perplexities; no-op if there are no docs."""
        if not self.docs:
            return
        values = list(self.docs.values())

        plt.figure(figsize=(10, 6))
        plt.hist(values, bins=20, edgecolor="black")
        plt.xlabel("Perplexity")
        plt.ylabel("Number of Documents")
        plt.title("Histogram of Document Perplexities")
        self._save_current_figure("histogram_of_perplexities.png")

    def cumulative_distribution_of_perplexities(self) -> None:
        """Save a cumulative 20-bin histogram; no-op if there are no docs."""
        if not self.docs:
            return
        values = list(self.docs.values())

        plt.figure(figsize=(10, 6))
        plt.hist(values, bins=20, cumulative=True, edgecolor="black")
        plt.xlabel("Perplexity")
        plt.ylabel("Cumulative Number of Documents")
        plt.title("Cumulative Distribution of Document Perplexities")
        self._save_current_figure("cumulative_distribution_of_perplexities.png")

    def boxplot_of_perplexities(self) -> None:
        """Save a boxplot of the perplexities; no-op if there are no docs."""
        if not self.docs:
            return
        values = list(self.docs.values())

        plt.figure(figsize=(10, 6))
        plt.boxplot(values)
        plt.ylabel("Perplexity")
        plt.title("Boxplot of Document Perplexities")
        self._save_current_figure("boxplot_of_perplexities.png")

    def barplot_of_perplexities(self) -> None:
        """Save one bar per document, labelled by path; no-op if there are no docs."""
        if not self.docs:
            return
        labels = list(self.docs.keys())
        values = list(self.docs.values())
        indexes = np.arange(len(labels))

        plt.figure(figsize=(20, 10))
        plt.bar(indexes, values, align="center")
        plt.xticks(indexes, labels, rotation="vertical")
        plt.ylabel("Perplexity")
        plt.xlabel("Document")
        plt.title("Perplexity of documents")
        self._save_current_figure("barplot_of_perplexities.png")

    def pltt(self) -> None:
        """Render all four perplexity plots."""
        self.histogram_of_perplexities()
        self.cumulative_distribution_of_perplexities()
        self.boxplot_of_perplexities()
        self.barplot_of_perplexities()

    def main(self):
        """Run the pipeline: find, parse, sort, save, plot, then exit.

        Raises:
            SystemExit: Always, with status 0, once all steps have finished.
        """
        xmls = self.find_xmls()
        self.parse_docs(xmls)
        self.sort_docs()
        self.save_docs()
        self.pltt()
        # Equivalent to the site-provided exit(0) but does not depend on the
        # `site` module having injected that helper.
        raise SystemExit(0)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the command-line interface, then hand the
    # parsed options to Rate and run its pipeline.
    parser = argparse.ArgumentParser(description="Process XML files.")

    # The two string-valued options share the same shape, so register them
    # from a small table instead of repeating the call.
    for flag, help_text in (
        ("--input_dir", "Input directory containing XML files."),
        ("--output_dir", "Output directory to save processed files."),
    ):
        parser.add_argument(flag, type=str, help=help_text)

    parser.add_argument("--cpu", action="store_true", help="Force usage of CPU.")

    Rate(parser.parse_args()).main()
|