|
import torch |
|
import torch.nn.functional as F |
|
from transformers import BertTokenizer, BertForTokenClassification |
|
import re |
|
import string |
|
|
|
|
|
def preprocess_input_text(text):
    """
    Prepare raw text for the token-classification model.

    Inserts a space before every punctuation mark, collapses runs of
    spaces, lowercases the text, and appends a "[MASK]" token after
    each word.

    Args:
        text: The raw input sentence.

    Returns:
        tuple: (words, masked_text) where ``words`` is the list of the
        original (non-lowercased) words/punctuation tokens and
        ``masked_text`` is the lowercased text with "[MASK]" inserted
        after every token.
    """
    # re.escape is required: string.punctuation contains characters that
    # are special inside a regex character class (']', '\\', '^', '-');
    # interpolating it raw only parses by coincidence.
    text = re.sub('([' + re.escape(string.punctuation) + '])', r' \1', text)
    # Collapse repeated spaces, then strip so that leading/trailing
    # whitespace cannot produce empty-string tokens (which would
    # desynchronize `words` from the [MASK] predictions downstream).
    text = re.sub(' +', ' ', text).strip()

    words = text.split(" ")

    output = []
    for word in text.lower().split(" "):
        output.append(word)
        output.append("[MASK]")

    return words, " ".join(output)
|
|
|
|
|
def predict_using_trained_model_old(input_text, model_dir, device):
    """
    Load a fine-tuned BERT token-classification model from *model_dir*
    and label every word of *input_text* as correct (0) or incorrect (1).

    Returns:
        str: The input words interleaved with the predicted label of the
        "[MASK]" token that follows each word.
    """
    words, masked_text = preprocess_input_text(input_text)

    # Load tokenizer and model fresh on every call, then move to the
    # requested device and switch to inference mode.
    tokenizer = BertTokenizer.from_pretrained(model_dir)
    model = BertForTokenClassification.from_pretrained(model_dir, num_labels=2)
    model.to(device)
    model.eval()

    encoded = tokenizer(
        masked_text,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits

    predicted_labels = torch.argmax(logits, dim=-1).squeeze().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().cpu().numpy())

    # Walk the token stream: each [MASK] carries the label for the word
    # that precedes it; non-special tokens are echoed back as the
    # original (non-lowercased) word at the current position.
    # NOTE(review): assumes one token per word — a word split into
    # subword pieces would be emitted once per piece; confirm tokenizer
    # behavior on the target vocabulary.
    special_tokens = ("[CLS]", "[SEP]", "[PAD]")
    pieces = []
    word_position = 0
    for token, label in zip(tokens, predicted_labels):
        if token == "[MASK]":
            pieces.append(str(label))
            word_position += 1
        elif token not in special_tokens:
            pieces.append(words[word_position])

    return " ".join(pieces)
|
|
|
|
|
if __name__ == '__main__':
    # Croatian example sentence containing deliberate spelling errors.
    input_text = "Model u tekstu prepoznije riječi u kojima se nalazaju pogreške."
    model_dir = "."

    # Pick the best available backend: CUDA first, then Apple MPS,
    # falling back to CPU.
    device_name = "cpu"
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    device = torch.device(device_name)

    print(f"Using device: {device}")

    model_output_text = predict_using_trained_model_old(input_text, model_dir, device)

    print(f"Model output: {model_output_text}")
|
|