main
parent
816280630d
commit
aba634b82d
|
@ -1,3 +1,3 @@
|
||||||
.venv/
|
.venv/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
checkpoints/
|
||||||
|
|
|
@ -0,0 +1,172 @@
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"e": 8223193,
|
||||||
|
"n": 5007028,
|
||||||
|
"a": 4617384,
|
||||||
|
"i": 4617298,
|
||||||
|
"r": 3806902,
|
||||||
|
"o": 3506777,
|
||||||
|
"s": 3286940,
|
||||||
|
"t": 2994776,
|
||||||
|
"d": 2877282,
|
||||||
|
"l": 2330763,
|
||||||
|
"u": 1676899,
|
||||||
|
"h": 1566659,
|
||||||
|
"g": 1525408,
|
||||||
|
"m": 1402339,
|
||||||
|
"v": 1304891,
|
||||||
|
"k": 1268507,
|
||||||
|
"c": 1191632,
|
||||||
|
"j": 1138996,
|
||||||
|
"b": 1102296,
|
||||||
|
"p": 1097955,
|
||||||
|
"z": 927699,
|
||||||
|
"f": 467568,
|
||||||
|
"č": 443139,
|
||||||
|
".": 438490,
|
||||||
|
"w": 399028,
|
||||||
|
"š": 290995,
|
||||||
|
"ž": 250267,
|
||||||
|
"ü": 212082,
|
||||||
|
"S": 194575,
|
||||||
|
"A": 192510,
|
||||||
|
"ä": 169917,
|
||||||
|
"1": 157447,
|
||||||
|
"0": 145954,
|
||||||
|
"B": 141047,
|
||||||
|
"L": 137980,
|
||||||
|
"D": 135820,
|
||||||
|
"P": 135278,
|
||||||
|
"G": 130277,
|
||||||
|
"ß": 125683,
|
||||||
|
"V": 116265,
|
||||||
|
"K": 94573,
|
||||||
|
"I": 90586,
|
||||||
|
"H": 83761,
|
||||||
|
"8": 82531,
|
||||||
|
"W": 79459,
|
||||||
|
"2": 79382,
|
||||||
|
"E": 77487,
|
||||||
|
"R": 70841,
|
||||||
|
"ö": 70604,
|
||||||
|
"M": 70215,
|
||||||
|
"T": 68813,
|
||||||
|
"N": 62789,
|
||||||
|
":": 59231,
|
||||||
|
"3": 58524,
|
||||||
|
")": 58035,
|
||||||
|
"(": 57753,
|
||||||
|
"F": 56410,
|
||||||
|
"5": 55914,
|
||||||
|
"Z": 53670,
|
||||||
|
"—": 52400,
|
||||||
|
"4": 50003,
|
||||||
|
"9": 49362,
|
||||||
|
"6": 47439,
|
||||||
|
"J": 45832,
|
||||||
|
"O": 43816,
|
||||||
|
"7": 41612,
|
||||||
|
"U": 39562,
|
||||||
|
"§": 24163,
|
||||||
|
"C": 23519,
|
||||||
|
"!": 14133,
|
||||||
|
",": 12041,
|
||||||
|
"Ž": 10802,
|
||||||
|
"„": 10249,
|
||||||
|
"Š": 9210,
|
||||||
|
"Č": 7467,
|
||||||
|
"y": 6853,
|
||||||
|
"'": 5763,
|
||||||
|
"x": 5667,
|
||||||
|
"%": 5265,
|
||||||
|
"X": 4657,
|
||||||
|
"©": 4404,
|
||||||
|
"q": 3761,
|
||||||
|
"“": 3749,
|
||||||
|
"»": 3394,
|
||||||
|
"«": 3270,
|
||||||
|
"°": 2976,
|
||||||
|
"/": 2151,
|
||||||
|
";": 1889,
|
||||||
|
"Ü": 1693,
|
||||||
|
"Q": 1521,
|
||||||
|
"Ä": 1293,
|
||||||
|
"?": 1107,
|
||||||
|
"=": 1099,
|
||||||
|
"-": 1063,
|
||||||
|
"®": 1045,
|
||||||
|
"■": 806,
|
||||||
|
"\"": 791,
|
||||||
|
"£": 695,
|
||||||
|
"’": 694,
|
||||||
|
"|": 589,
|
||||||
|
"•": 560,
|
||||||
|
"_": 460,
|
||||||
|
"Ö": 427,
|
||||||
|
"‘": 306,
|
||||||
|
"í": 304,
|
||||||
|
"é": 297,
|
||||||
|
"¿": 259,
|
||||||
|
"\\": 193,
|
||||||
|
"™": 164,
|
||||||
|
"Y": 138,
|
||||||
|
"]": 122,
|
||||||
|
"á": 91,
|
||||||
|
"[": 79,
|
||||||
|
"ì": 72,
|
||||||
|
"è": 61,
|
||||||
|
"Ì": 53,
|
||||||
|
"ó": 51,
|
||||||
|
"□": 50,
|
||||||
|
"à": 46,
|
||||||
|
"¡": 46,
|
||||||
|
"~": 45,
|
||||||
|
"É": 34,
|
||||||
|
"ò": 29,
|
||||||
|
"+": 27,
|
||||||
|
"đ": 22,
|
||||||
|
"Í": 20,
|
||||||
|
"€": 18,
|
||||||
|
"}": 18,
|
||||||
|
"ï": 16,
|
||||||
|
"Ê": 15,
|
||||||
|
"ê": 15,
|
||||||
|
"î": 13,
|
||||||
|
"а": 13,
|
||||||
|
"б": 13,
|
||||||
|
"{": 13,
|
||||||
|
"È": 12,
|
||||||
|
"ÿ": 11,
|
||||||
|
"Î": 9,
|
||||||
|
"â": 9,
|
||||||
|
"”": 9,
|
||||||
|
"ú": 7,
|
||||||
|
"±": 6,
|
||||||
|
"ù": 6,
|
||||||
|
"।": 5,
|
||||||
|
"♦": 5,
|
||||||
|
"ë": 5,
|
||||||
|
"À": 4,
|
||||||
|
"Ć": 4,
|
||||||
|
"Ï": 4,
|
||||||
|
"Æ": 4,
|
||||||
|
"Đ": 3,
|
||||||
|
"Ç": 3,
|
||||||
|
"ñ": 3,
|
||||||
|
"Ó": 2,
|
||||||
|
"Ò": 2,
|
||||||
|
"ć": 2,
|
||||||
|
"►": 2,
|
||||||
|
"Û": 2,
|
||||||
|
"Á": 2,
|
||||||
|
"Ù": 1,
|
||||||
|
"✓": 1,
|
||||||
|
"Â": 1,
|
||||||
|
"Ú": 1,
|
||||||
|
"Ë": 1,
|
||||||
|
"æ": 1,
|
||||||
|
"œ": 1,
|
||||||
|
"¥": 1,
|
||||||
|
"ô": 1
|
||||||
|
}
|
||||||
|
```
|
|
@ -0,0 +1,4 @@
|
||||||
|
0000 Soll der Landtag als dermaliger verfassungsmäßiger Vertreter des Landes den Faden der Verhandlung wieder aufnehmen oder darf er ohne sich dem Vorwürfe träger Sorglosigkeit auszusetzen zuwarten bis die Staatsverwaltung etwa bei den veränderten Verhältnissen auch diese karge Dotation dem Lande entziehe unter dem Vorgeben daß dieselbe nur eine int Gnadenwege den v o r b c st a n d c n e n Stünden gewährte Aushilfe gewesen sei und mit diesen zugleich aufgehört habe
|
||||||
|
0001 5. Ant. Graf v. Auersperg
|
||||||
|
0002 Denn haben wir kein Anlehen so müssen wir die 36 % jetzt zahlen und bis auf 55 ° 0 somit mit allen andern LandesUmlagcn bis auf 70 ° 0 die Steuerzuschläge hinauf treiben
|
||||||
|
0003 Ich stelle in dieser Beziehung die Unterstütznngsfrage
|
File diff suppressed because it is too large
Load Diff
19
dl.py
19
dl.py
|
@ -6,11 +6,14 @@ from nlpaug.augmenter.char import OcrAug
|
||||||
from nlpaug.augmenter.word import RandomWordAug
|
from nlpaug.augmenter.word import RandomWordAug
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
class TextDataset(Dataset):
|
||||||
def __init__(self, path, max_length, buffer_size):
|
def __init__(self, path, max_length, buffer_size):
|
||||||
self.data = pd.read_csv(path, delimiter="\t", header=None)
|
with open(path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
self.data = pd.DataFrame(data)
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
|
|
||||||
|
@ -32,11 +35,11 @@ class TextDataset(Dataset):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
index, sentence = self.data.iloc[idx]
|
sentence, corrected = self.data.iloc[idx]
|
||||||
aug_sentence = self.aug_char.augment(sentence)
|
#aug_sentence = self.aug_char.augment(sentence)
|
||||||
aug_sentence = self.aug_delete.augment(aug_sentence)
|
#aug_sentence = self.aug_delete.augment(aug_sentence)
|
||||||
aug_sentence = aug_sentence[0]
|
#aug_sentence = aug_sentence[0]
|
||||||
return sentence, aug_sentence
|
return corrected, sentence
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
|
def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
|
||||||
|
@ -60,10 +63,10 @@ def load_dataset(path, max_length, buffer_size, batch_size, test_ratio=0.2):
|
||||||
|
|
||||||
def test():
|
def test():
|
||||||
train_dataloader, test_dataloader = load_dataset(
|
train_dataloader, test_dataloader = load_dataset(
|
||||||
"../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
|
"./dataset/hason_out.json",
|
||||||
100,
|
100,
|
||||||
100,
|
100,
|
||||||
32,
|
6,
|
||||||
test_ratio=0.2,
|
test_ratio=0.2,
|
||||||
)
|
)
|
||||||
for batch in train_dataloader:
|
for batch in train_dataloader:
|
||||||
|
|
2
ft.py
2
ft.py
|
@ -121,7 +121,7 @@ class FT:
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
train_dataloader, test_dataloader = load_dataset(
|
train_dataloader, test_dataloader = load_dataset(
|
||||||
"../datasets/deu_mixed-typical_2011_1M/deu_mixed-typical_2011_1M-sentences.txt",
|
"./dataset/hason_out.json",
|
||||||
100,
|
100,
|
||||||
100,
|
100,
|
||||||
32,
|
32,
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
#! /usr/bin/env python3
|
||||||
|
from transformers import BartForConditionalGeneration, BartTokenizer, AdamW
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
import os
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
class DL(Dataset):
    """Dataset over a tab-separated text file with two columns: id and text.

    Each item is the raw text (second column) of one line; the id column is
    read but not returned.

    Args:
        path: Path to the tab-separated file (no header row).
        max_length: Stored on the instance; not used by this class itself.
        buffer_size: Stored on the instance; not used by this class itself.
    """

    def __init__(self, path, max_length, buffer_size):
        # Function-scope import so the block does not depend on a file-level
        # ``import csv`` being present.
        import csv

        self.data = pd.read_csv(
            path,
            delimiter="\t",
            header=None,
            # Keep ids such as "0000" and numeric-looking sentences as text.
            dtype=str,
            # The text is free-form prose, not quoted CSV: with the default
            # quoting a field that starts with '"' is treated as a quoted
            # field and can swallow the following lines.
            quoting=csv.QUOTE_NONE,
            # Sentences such as "NA" or "null" must stay strings rather than
            # being converted to NaN.
            keep_default_na=False,
        )
        self.max_length = max_length
        self.buffer_size = buffer_size

    def __len__(self):
        """Return the number of lines (rows) in the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the text column of row ``idx``.

        Raises:
            ValueError: If the row does not have exactly two columns.
        """
        _line_id, text = self.data.iloc[idx]
        return text
|
||||||
|
|
||||||
|
|
||||||
|
class FT:
    """Load a BART model (plus optional checkpoint) and run generation on text.

    On construction the ``facebook/bart-base`` tokenizer and model are loaded,
    moved to CUDA when available, and ``./checkpoints/ft_0.pt`` is restored if
    that file exists.

    Attributes set by ``__init__``: ``device``, ``tokenizer``, ``model``,
    ``optimizer``, ``current_epoch`` and ``scaler``.
    """

    def __init__(self):
        # Enable cudnn optimizations
        torch.backends.cudnn.benchmark = True

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # load tokenizer and model
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
        self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        self.model.to(self.device)

        # set up optimizer (its state is restored from the checkpoint, if any)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
        self.load_checkpoint()

        try:
            from torch.cuda.amp import GradScaler
        except ImportError:
            # We won't use a scaler if we don't have Amp
            self.scaler = None
        else:
            # Only enable scaling when actually running on CUDA; this avoids
            # the "CUDA is not available" warning on CPU-only hosts.
            self.scaler = GradScaler(enabled=self.device.type == "cuda")

    def test_model(self, dataloader):
        """Generate text for every item yielded by ``dataloader``.

        Args:
            dataloader: Iterable of batches, each a list of input strings.

        Returns:
            A flat list with one generated string per input item, in order.
        """
        self.model.eval()
        predictions = []  # List to store the generated text outputs
        print("Testing model...")
        with torch.no_grad():
            for batch in tqdm(dataloader):
                inputs = self.tokenizer(
                    batch,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                ).to(self.device)
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=200,  # Set the maximum length of the generated output
                    num_beams=4,  # Number of beams for beam search (optional)
                    early_stopping=True,  # Stop generation when all beams are finished (optional)
                )
                # Decode every sequence in the batch, not only the first one,
                # so no prediction is dropped when batch_size > 1.
                predictions.extend(
                    self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                )

        return predictions

    def load_checkpoint(self):
        """Restore model/optimizer state from ``./checkpoints/ft_0.pt`` if present.

        ``self.current_epoch`` is always defined afterwards: it is the epoch
        stored in the checkpoint, or 0 when no checkpoint file exists.
        """
        checkpoint_path = "./checkpoints/ft_0.pt"
        self.current_epoch = 0
        if not os.path.exists(checkpoint_path):
            print(f"No checkpoint found at {checkpoint_path}; using pretrained weights")
            return

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.current_epoch = checkpoint["epoch"]
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"Loaded checkpoint from epoch {self.current_epoch}")

    def validate(self):
        """Run generation over ``./dataset/gt.txt`` and print the outputs."""
        dataloader = torch.utils.data.DataLoader(
            DL(
                path="./dataset/gt.txt",
                max_length=300,
                buffer_size=100,
            ),
            batch_size=1,
        )

        outputs = self.test_model(dataloader)
        print(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: build the validator and run it over the ground-truth set.
if __name__ == "__main__":
    FT().validate()
|
Loading…
Reference in New Issue