Gašper Spagnolo 2023-07-31 12:33:32 +02:00
parent 816280630d
commit aba634b82d
No known key found for this signature in database
GPG Key ID: 2EA0738CC1EFEEB7
7 changed files with 12558 additions and 10 deletions

2
.gitignore vendored
View File

@ -1,3 +1,3 @@
.venv/
__pycache__/
checkpoints/

172
README.md Normal file
View File

@ -0,0 +1,172 @@
```
{
"e": 8223193,
"n": 5007028,
"a": 4617384,
"i": 4617298,
"r": 3806902,
"o": 3506777,
"s": 3286940,
"t": 2994776,
"d": 2877282,
"l": 2330763,
"u": 1676899,
"h": 1566659,
"g": 1525408,
"m": 1402339,
"v": 1304891,
"k": 1268507,
"c": 1191632,
"j": 1138996,
"b": 1102296,
"p": 1097955,
"z": 927699,
"f": 467568,
"č": 443139,
".": 438490,
"w": 399028,
"š": 290995,
"ž": 250267,
"ü": 212082,
"S": 194575,
"A": 192510,
"ä": 169917,
"1": 157447,
"0": 145954,
"B": 141047,
"L": 137980,
"D": 135820,
"P": 135278,
"G": 130277,
"ß": 125683,
"V": 116265,
"K": 94573,
"I": 90586,
"H": 83761,
"8": 82531,
"W": 79459,
"2": 79382,
"E": 77487,
"R": 70841,
"ö": 70604,
"M": 70215,
"T": 68813,
"N": 62789,
":": 59231,
"3": 58524,
")": 58035,
"(": 57753,
"F": 56410,
"5": 55914,
"Z": 53670,
"—": 52400,
"4": 50003,
"9": 49362,
"6": 47439,
"J": 45832,
"O": 43816,
"7": 41612,
"U": 39562,
"§": 24163,
"C": 23519,
"!": 14133,
",": 12041,
"Ž": 10802,
"„": 10249,
"Š": 9210,
"Č": 7467,
"y": 6853,
"'": 5763,
"x": 5667,
"%": 5265,
"X": 4657,
"©": 4404,
"q": 3761,
"“": 3749,
"»": 3394,
"«": 3270,
"°": 2976,
"/": 2151,
";": 1889,
"Ü": 1693,
"Q": 1521,
"Ä": 1293,
"?": 1107,
"=": 1099,
"-": 1063,
"®": 1045,
"■": 806,
"\"": 791,
"£": 695,
"": 694,
"|": 589,
"•": 560,
"_": 460,
"Ö": 427,
"": 306,
"í": 304,
"é": 297,
"¿": 259,
"\\": 193,
"™": 164,
"Y": 138,
"]": 122,
"á": 91,
"[": 79,
"ì": 72,
"è": 61,
"Ì": 53,
"ó": 51,
"□": 50,
"à": 46,
"¡": 46,
"~": 45,
"É": 34,
"ò": 29,
"+": 27,
"đ": 22,
"Í": 20,
"€": 18,
"}": 18,
"ï": 16,
"Ê": 15,
"ê": 15,
"î": 13,
"а": 13,
"б": 13,
"{": 13,
"È": 12,
"ÿ": 11,
"Î": 9,
"â": 9,
"”": 9,
"ú": 7,
"±": 6,
"ù": 6,
"।": 5,
"♦": 5,
"ë": 5,
"À": 4,
"Ć": 4,
"Ï": 4,
"Æ": 4,
"Đ": 3,
"Ç": 3,
"ñ": 3,
"Ó": 2,
"Ò": 2,
"ć": 2,
"►": 2,
"Û": 2,
"Á": 2,
"Ù": 1,
"✓": 1,
"Â": 1,
"Ú": 1,
"Ë": 1,
"æ": 1,
"œ": 1,
"¥": 1,
"ô": 1
}
```

4
dataset/gt.txt Normal file
View File

@ -0,0 +1,4 @@
0000 Soll der Landtag als dermaliger verfassungsmäßiger Vertreter des Landes den Faden der Verhandlung wieder aufnehmen oder darf er ohne sich dem Vorwürfe träger Sorglosigkeit auszusetzen zuwarten bis die Staatsverwaltung etwa bei den veränderten Verhältnissen auch diese karge Dotation dem Lande entziehe unter dem Vorgeben daß dieselbe nur eine int Gnadenwege den v o r b c st a n d c n e n Stünden gewährte Aushilfe gewesen sei und mit diesen zugleich aufgehört habe
0001 5. Ant. Graf v. Auersperg
0002 Denn haben wir kein Anlehen so müssen wir die 36 % jetzt zahlen und bis auf 55 ° 0 somit mit allen andern LandesUmlagcn bis auf 70 ° 0 die Steuerzuschläge hinauf treiben
0003 Ich stelle in dieser Beziehung die Unterstütznngsfrage

12263
dataset/hason_out.json Normal file

File diff suppressed because it is too large Load Diff

19
dl.py
View File

@ -6,11 +6,14 @@ from nlpaug.augmenter.char import OcrAug
from nlpaug.augmenter.word import RandomWordAug
from sklearn.model_selection import train_test_split
import pandas as pd
import json
class TextDataset(Dataset):
def __init__(self, path, max_length, buffer_size):
self.data = pd.read_csv(path, delimiter="\t", header=None)
with open(path, "r") as f:
data = json.load(f)
self.data = pd.DataFrame(data)
self.max_length = max_length
self.buffer_size = buffer_size
@ -32,11 +35,11 @@ class TextDataset(Dataset):
return len(self.data)
def __getitem__(self, idx):
index, sentence = self.data.iloc[idx]
aug_sentence = self.aug_char.augment(sentence)
aug_sentence = self.aug_delete.augment(aug_sentence)
aug_sentence = aug_sentence[0]
return sentence, aug_sentence
sentence, corrected = self.data.iloc[idx]
#aug_sentence = self.aug_char.augment(sentence)
#aug_sentence = self.aug_delete.augment(aug_sentence)
#aug_sentence = aug_sentence[0]
return corrected, sentence
def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
@ -60,10 +63,10 @@ def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
def test():
train_dataloader, test_dataloader = load_dataset(
"../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
"./dataset/hason_out.json",
100,
100,
32,
6,
test_ratio=0.2,
)
for batch in train_dataloader:

2
ft.py
View File

@ -121,7 +121,7 @@ class FT:
def train(self):
train_dataloader, test_dataloader = load_dataset(
"../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
"./dataset/hason_out.json",
100,
100,
32,

106
run.py Normal file
View File

@ -0,0 +1,106 @@
#! /usr/bin/env python3
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW
import torch
from tqdm import tqdm
import os
from torch.utils.data import Dataset
import pandas as pd
class DL(Dataset):
def __init__(self, path, max_length, buffer_size):
self.data = pd.read_csv(path, delimiter="\t", header=None)
self.max_length = max_length
self.buffer_size = buffer_size
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
ix, text = self.data.iloc[idx]
return text
class FT:
def __init__(self):
# Enable cudnn optimizations
torch.backends.cudnn.benchmark = True
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load tokenizer and model
self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
self.model.to(self.device)
# set up optimizer
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
self.load_checkpoint()
try:
from torch.cuda.amp import GradScaler, autocast
self.scaler = GradScaler()
except ImportError:
class autocast:
def __enter__(self):
pass
def __exit__(self, *args):
pass
self.scaler = None # We won't use a scaler if we don't have Amp
def test_model(self, dataloader):
self.model.eval()
predictions = [] # List to store the generated text outputs
print("Testing model...")
for batch in tqdm(dataloader):
with torch.no_grad():
inputs = self.tokenizer(
batch,
return_tensors="pt",
padding=True,
truncation=True,
max_length=512,
)
inputs.to(self.device)
outputs = self.model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=200, # Set the maximum length of the generated output
num_beams=4, # Number of beams for beam search (optional)
early_stopping=True, # Stop generation when all beams are finished (optional)
)
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
predictions.append(generated_text)
return predictions
def load_checkpoint(self):
checkpoint_path = "./checkpoints/ft_0.pt"
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
self.current_epoch = checkpoint["epoch"]
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
print(f"Loaded checkpoint from epoch {self.current_epoch}")
def validate(self):
dataloader = torch.utils.data.DataLoader(
DL(
path="./dataset/gt.txt",
max_length=300,
buffer_size=100,
),
batch_size=1,
)
outputs = self.test_model(dataloader)
print(outputs)
if __name__ == "__main__":
validator = FT()
validator.validate()