Fix the parsing script

master
Gašper Spagnolo 2023-08-03 13:36:16 +02:00
parent 72163a9016
commit cb3c07b428
No known key found for this signature in database
GPG Key ID: 2EA0738CC1EFEEB7
4 changed files with 14 additions and 2 deletions

View File

@ -7,6 +7,7 @@ from tqdm import tqdm
import argparse import argparse
from typing import List from typing import List
import warnings import warnings
import re
class XMLProcessor: class XMLProcessor:
@ -17,6 +18,7 @@ class XMLProcessor:
self.model: XLMRobertaForMaskedLM = None self.model: XLMRobertaForMaskedLM = None
self.tokenizer: XLMRobertaTokenizer = None self.tokenizer: XLMRobertaTokenizer = None
self.device: torch.device = None self.device: torch.device = None
self.documents: dict = {}
def prepare_model(self) -> None: def prepare_model(self) -> None:
model_name: str = "xlm-roberta-base" model_name: str = "xlm-roberta-base"
@ -29,6 +31,14 @@ class XMLProcessor:
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name) self.tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
def contains_number(self, string: str) -> bool:
    """Return True if *string* contains at least one digit character.

    The instance is not consulted; only ``string`` is inspected.

    Args:
        string: Text to scan for digits.

    Returns:
        True when any digit occurs anywhere in ``string``, otherwise False.
    """
    # A single ``\d`` is enough for an existence check: ``re.search`` stops
    # at the first match, so a ``+`` quantifier cannot change the outcome.
    return re.search(r"\d", string) is not None
def compute_perplexity(self, text: str) -> float: def compute_perplexity(self, text: str) -> float:
max_length = 512 # XLM-Roberta's maximum sequence length max_length = 512 # XLM-Roberta's maximum sequence length
@ -53,7 +63,7 @@ class XMLProcessor:
inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()} inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
outputs = self.model(**inputs, labels=inputs["input_ids"]) outputs = self.model(**inputs, labels=inputs["input_ids"])
total_loss += torch.exp(outputs.loss).item() total_loss += outputs.loss.item()
total_count += 1 total_count += 1
# Process the remaining text # Process the remaining text
@ -101,6 +111,8 @@ class XMLProcessor:
num_sentences = 0 num_sentences = 0
for s in seg.findall(".//default:s", ns): for s in seg.findall(".//default:s", ns):
sentence = " ".join([w.text for w in s.findall(".//default:w", ns)]) sentence = " ".join([w.text for w in s.findall(".//default:w", ns)])
if len(sentence) < 20 and self.contains_number(sentence):
continue
perplexity = self.compute_perplexity(sentence) perplexity = self.compute_perplexity(sentence)
s.set("perplexity", str(perplexity)) s.set("perplexity", str(perplexity))
segment_perplexity += perplexity segment_perplexity += perplexity

Binary file not shown.

Before

Width:  |  Height:  |  Size: 543 KiB

After

Width:  |  Height:  |  Size: 529 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

After

Width:  |  Height:  |  Size: 18 KiB

File diff suppressed because one or more lines are too long