{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "89c94977", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import snapshot_download\n", "data_folder = snapshot_download(\"fxtentacle/tevr-token-entropy-predictor-de\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "a48a49d6", "metadata": {}, "outputs": [], "source": [ "from transformers import T5ForConditionalGeneration\n", "model = T5ForConditionalGeneration.from_pretrained(data_folder)\n", "model.to('cuda')\n", "model.eval()\n", "None" ] }, { "cell_type": "code", "execution_count": 3, "id": "eed8bfc3", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "def text_to_cross_entropy(text):\n", " ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')\n", " tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')\n", " logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()\n", " cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()\n", " return cross_entropy" ] }, { "cell_type": "code", "execution_count": 4, "id": "8ec8cf8d", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import os\n", "sys.path.append(data_folder)\n", "from text_tokenizer import HajoTextTokenizer" ] }, { "cell_type": "code", "execution_count": 5, "id": "37165805", "metadata": {}, "outputs": [], "source": [ "tokenizer_file = 'text-tokenizer-de-4m.txt'\n", "text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)" ] }, { "cell_type": "code", "execution_count": 6, "id": "73e55343", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['die', ' ', 'k', 'at', 'ze', ' ', 'ist', ' ', 'n', 'ied', 'lich']\n", "[3.3762913048267365, 3.3762913048267365, 3.3762913048267365, 0.29695791006088257, 4.193424224853516, 2.3430762887001038, 2.3430762887001038, 2.8417416363954544, 2.8417416363954544, 1.1227068901062012, 2.017452405144771, 2.017452405144771, 2.017452405144771, 0.0016304069431498647, 2.580254554748535, 2.3091587026913962, 2.3091587026913962, 2.3091587026913962, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508]\n" ] } ], "source": [ "text = \"die katze ist niedlich\"\n", "cross_entropy = text_to_cross_entropy(text)\n", "\n", "tokens = text_tokenizer.encode(text)\n", "tokens = [text_tokenizer.all_tokens[t] for t in tokens]\n", "print(tokens)\n", "token_sums = []\n", "token_sums2 = []\n", "for t in tokens:\n", " ce = sum(cross_entropy[len(token_sums):len(token_sums)+len(t)])\n", " for r in range(len(t)): token_sums.append(ce / len(t))\n", " token_sums2.append(ce)\n", "print(token_sums)" ] }, { "cell_type": "code", "execution_count": 7, "id": "e61e00aa", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e61e00aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>",
       "<tr><td>(1)</td><td>d</td><td>i</td><td>e</td><td> </td><td>k</td><td>a</td><td>t</td><td>z</td><td>e</td><td> </td><td>i</td><td>s</td><td>t</td><td> </td><td>n</td><td>i</td><td>e</td><td>d</td><td>l</td><td>i</td><td>c</td><td>h</td></tr>",
       "<tr><td>(2)</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=0.0</td></tr>",
       "<tr><td>(3)</td><td>8.9</td><td>1.0</td><td>0.2</td><td>0.3</td><td>4.2</td><td>1.6</td><td>3.1</td><td>5.4</td><td>0.3</td><td>1.1</td><td>3.0</td><td>3.0</td><td>0.0</td><td>0.0</td><td>2.6</td><td>0.6</td><td>4.4</td><td>1.9</td><td>4.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>σ²=5.0</td></tr>",
       "<tr><td>(4)</td><td colspan=\"3\">die</td><td colspan=\"1\"> </td><td colspan=\"1\">k</td><td colspan=\"2\">at</td><td colspan=\"2\">ze</td><td colspan=\"1\"> </td><td colspan=\"3\">ist</td><td colspan=\"1\"> </td><td colspan=\"1\">n</td><td colspan=\"3\">ied</td><td colspan=\"4\">lich</td></tr>",
       "<tr><td>(5)</td><td colspan=\"3\">10.1</td><td colspan=\"1\">0.3</td><td colspan=\"1\">4.2</td><td colspan=\"2\">4.7</td><td colspan=\"2\">5.7</td><td colspan=\"1\">1.1</td><td colspan=\"3\">6.1</td><td colspan=\"1\">0.0</td><td colspan=\"1\">2.6</td><td colspan=\"3\">6.9</td><td colspan=\"4\">4.1</td></tr>",
       "<tr><td>(6)</td><td>3.4</td><td>3.4</td><td>3.4</td><td>0.3</td><td>4.2</td><td>2.3</td><td>2.3</td><td>2.8</td><td>2.8</td><td>1.1</td><td>2.0</td><td>2.0</td><td>2.0</td><td>0.0</td><td>2.6</td><td>2.3</td><td>2.3</td><td>2.3</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=1.1</td></tr>",
       "</table>"
" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "html = ''\n", "html += ''+''.join([f'' for c in list(text)])+''\n", "html += ''+''.join([''.format(v) for v in cross_entropy])+''.format(np.var([1.0 for v in cross_entropy]))+''\n", "html += ''+''.join([''.format(v) for v in cross_entropy])+''.format(np.var(cross_entropy))+''\n", "html += ''+''.join([f'' for t in tokens])+''\n", "html += ''+''.join([f'' for i,t in enumerate(tokens)])+''\n", "html += ''+''.join([''.format(v) for v in token_sums])+''.format(np.var(token_sums))+''\n", "html += '
(1){c}
(2)1.0σ²={:3.1f}
(3){:3.1f}σ²={:3.1f}
(4){t}
(5){\"{:3.1f}\".format(token_sums2[i])}
(6){:3.1f}σ²={:3.1f}
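  {
   "cell_type": "markdown",
   "id": "ad0c0de3",
   "metadata": {},
   "source": [
    "Reading the table above: row (1) shows the characters, row (2) a uniform baseline that assigns every character the same amount of information (variance 0.0), row (3) the per-character cross-entropy predicted by the T5 model (variance 5.0), row (4) the TEVR tokens, row (5) the total entropy of each token, and row (6) each token's entropy spread evenly over its characters. Pooling over tokens reduces the per-character variance from 5.0 to 1.1. The cell below (added for illustration; not part of the original notebook) recomputes the two variances directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad0c0de4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# TEVR's token-level averaging flattens the per-character entropy profile,\n",
    "# which is exactly the variance drop between rows (3) and (6).\n",
    "print('variance per character   : {:3.1f}'.format(np.var(cross_entropy)))\n",
    "print('variance after averaging : {:3.1f}'.format(np.var(token_sums)))"
   ]
  },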
'\n", "\n", "import IPython\n", "IPython.display.HTML(html)" ] }, { "cell_type": "code", "execution_count": 8, "id": "dcafdcab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ ", , , chen, sche, lich, isch, icht, iche, eine, rden, tion, urde, haft, eich, rung, chte, ssen, chaf, nder, tlic, tung, eite, iert, sich, ngen, erde, scha, nden, unge, lung, mmen, eren, ende, inde, erun, sten, iese, igen, erte, iner, tsch, keit, der, die, ter, und, ein, ist, den, ten, ber, ver, sch, ung, ste, ent, ach, nte, auf, ben, eit, des, ers, aus, das, von, ren, gen, nen, lle, hre, mit, iel, uch, lte, ann, lie, men, dem, and, ind, als, sta, elt, ges, tte, ern, wir, ell, war, ere, rch, abe, len, ige, ied, ger, nnt, wei, ele, och, sse, end, all, ahr, bei, sie, ede, ion, ieg, ege, auc, che, rie, eis, vor, her, ang, für, ass, uss, tel, er, in, ge, en, st, ie, an, te, be, re, zu, ar, es, ra, al, or, ch, et, ei, un, le, rt, se, is, ha, we, at, me, ne, ur, he, au, ro, ti, li, ri, eh, im, ma, tr, ig, el, um, la, am, de, so, ol, tz, il, on, it, sc, sp, ko, na, pr, ni, si, fe, wi, ns, ke, ut, da, gr, eu, mi, hr, ze, hi, ta, ss, ng, sa, us, ba, ck, em, kt, ka, ve, fr, bi, wa, ah, gt, di, ab, fo, to, rk, as, ag, gi, hn, s, t, n, m, r, l, f, e, a, b, d, h, k, g, o, i, u, w, p, z, ä, ü, v, ö, j, c, y, x, q, á, í, ō, ó, š, é, č, ?\n" ] } ], "source": [ "from text_tokenizer import HajoTextTokenizer\n", "text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)\n", "tt = text_tokenizer.all_tokens\n", "print(', '.join(tt))" ] }, { "cell_type": "code", "execution_count": null, "id": "b87b7fd0", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 5 }