Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -60,13 +60,21 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
|
|
60 |
global masked_indices, masked_tokens, original_text
|
61 |
|
62 |
tokens = tokenizer.tokenize(text)
|
|
|
|
|
63 |
# Only mask whole words, not special tokens or punctuation
|
64 |
maskable_indices = [i for i, token in enumerate(tokens)
|
65 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
66 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
67 |
|
|
|
|
|
|
|
68 |
# Calculate how many tokens to mask, but ensure at least 1 and at most 8
|
|
|
69 |
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
|
|
|
|
70 |
# Randomly select indices to mask
|
71 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
72 |
# Sort indices to ensure they're in order
|
@@ -101,15 +109,20 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
101 |
# Tokenize text to ensure reasonable cutting
|
102 |
tokens = tokenizer.tokenize(text)
|
103 |
|
|
|
|
|
|
|
|
|
104 |
# Ensure we have enough tokens
|
105 |
if len(tokens) < 5:
|
106 |
return text, "" # Return original if too short
|
107 |
|
108 |
-
# Calculate cutoff point
|
109 |
-
# But make sure we have at least 3 tokens visible and 1 token hidden
|
110 |
cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
|
111 |
cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
|
112 |
|
|
|
|
|
113 |
# Get the visible part
|
114 |
visible_tokens = tokens[:cutoff]
|
115 |
|
@@ -120,15 +133,24 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
|
|
120 |
visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
|
121 |
hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
|
122 |
|
|
|
|
|
|
|
123 |
return visible_text, hidden_text
|
124 |
|
125 |
def get_new_sample(task, mask_ratio=0.15):
|
126 |
"""Get a new text sample based on the task."""
|
127 |
-
global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state
|
|
|
|
|
|
|
128 |
|
129 |
# Select a random sample
|
130 |
current_sample = random.choice(data_samples)
|
131 |
|
|
|
|
|
|
|
132 |
if task == "mlm":
|
133 |
# Prepare MLM sample
|
134 |
masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
|
@@ -373,7 +395,7 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
373 |
)
|
374 |
|
375 |
with gr.Row():
|
376 |
-
new_button = gr.Button("New Sample")
|
377 |
reset_button = gr.Button("Reset Stats")
|
378 |
|
379 |
# Consolidated input area - only one visible at a time
|
@@ -433,8 +455,21 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
433 |
outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
|
434 |
)
|
435 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
# Update the sample text and also update the mask count
|
437 |
def new_sample_with_count(mask_ratio_pct, task):
|
|
|
438 |
ratio = float(mask_ratio_pct) / 100.0
|
439 |
sample = get_new_sample(task, ratio)
|
440 |
mask_count_text = ""
|
@@ -442,8 +477,10 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
|
|
442 |
if task == "mlm":
|
443 |
count = len(masked_tokens)
|
444 |
mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
|
|
|
445 |
else:
|
446 |
mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
|
|
|
447 |
|
448 |
return sample, mask_count_text, ""
|
449 |
|
|
|
60 |
global masked_indices, masked_tokens, original_text
|
61 |
|
62 |
tokens = tokenizer.tokenize(text)
|
63 |
+
print(f"Text length: {len(text)} characters, {len(tokens)} tokens")
|
64 |
+
|
65 |
# Only mask whole words, not special tokens or punctuation
|
66 |
maskable_indices = [i for i, token in enumerate(tokens)
|
67 |
if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
|
68 |
and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
|
69 |
|
70 |
+
print(f"Maskable indices count: {len(maskable_indices)}")
|
71 |
+
print(f"Mask ratio: {mask_ratio}")
|
72 |
+
|
73 |
# Calculate how many tokens to mask, but ensure at least 1 and at most 8
|
74 |
+
# Use the maskable_indices length with the ratio
|
75 |
num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
|
76 |
+
print(f"Number of tokens to mask: {num_to_mask}")
|
77 |
+
|
78 |
# Randomly select indices to mask
|
79 |
indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
|
80 |
# Sort indices to ensure they're in order
|
|
|
109 |
# Tokenize text to ensure reasonable cutting
|
110 |
tokens = tokenizer.tokenize(text)
|
111 |
|
112 |
+
# Print debug info
|
113 |
+
print(f"NTP preparation - Text length: {len(text)} characters, {len(tokens)} tokens")
|
114 |
+
print(f"Cut ratio: {cut_ratio}")
|
115 |
+
|
116 |
# Ensure we have enough tokens
|
117 |
if len(tokens) < 5:
|
118 |
return text, "" # Return original if too short
|
119 |
|
120 |
+
# Calculate cutoff point based on the cut ratio
|
|
|
121 |
cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
|
122 |
cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
|
123 |
|
124 |
+
print(f"Cutoff point: {cutoff} (keeping {cutoff} tokens, cutting {len(tokens) - cutoff} tokens)")
|
125 |
+
|
126 |
# Get the visible part
|
127 |
visible_tokens = tokens[:cutoff]
|
128 |
|
|
|
133 |
visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
|
134 |
hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
|
135 |
|
136 |
+
print(f"Visible text length: {len(visible_text)} chars")
|
137 |
+
print(f"Hidden text length: {len(hidden_text)} chars")
|
138 |
+
|
139 |
return visible_text, hidden_text
|
140 |
|
141 |
def get_new_sample(task, mask_ratio=0.15):
|
142 |
"""Get a new text sample based on the task."""
|
143 |
+
global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state, current_task
|
144 |
+
|
145 |
+
# Update current task
|
146 |
+
current_task = task
|
147 |
|
148 |
# Select a random sample
|
149 |
current_sample = random.choice(data_samples)
|
150 |
|
151 |
+
# Print debugging info
|
152 |
+
print(f"Getting new sample for task: {task} with mask ratio: {mask_ratio}")
|
153 |
+
|
154 |
if task == "mlm":
|
155 |
# Prepare MLM sample
|
156 |
masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
|
|
|
395 |
)
|
396 |
|
397 |
with gr.Row():
|
398 |
+
new_button = gr.Button("New Sample", variant="primary")
|
399 |
reset_button = gr.Button("Reset Stats")
|
400 |
|
401 |
# Consolidated input area - only one visible at a time
|
|
|
455 |
outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
|
456 |
)
|
457 |
|
458 |
+
# Update the sample text when mask ratio changes (without clicking new sample)
|
459 |
+
def update_on_ratio_change(mask_ratio_pct, task):
|
460 |
+
print(f"Ratio changed to {mask_ratio_pct}%")
|
461 |
+
# Don't generate a new sample here, just update the UI to show the effect of ratio change
|
462 |
+
return f"Current mask/cut ratio: {mask_ratio_pct}%. Click 'New Sample' to apply."
|
463 |
+
|
464 |
+
mask_ratio.change(
|
465 |
+
update_on_ratio_change,
|
466 |
+
inputs=[mask_ratio, task_radio],
|
467 |
+
outputs=[result]
|
468 |
+
)
|
469 |
+
|
470 |
# Update the sample text and also update the mask count
|
471 |
def new_sample_with_count(mask_ratio_pct, task):
|
472 |
+
print(f"Generating new sample with mask ratio: {mask_ratio_pct}% for task: {task}")
|
473 |
ratio = float(mask_ratio_pct) / 100.0
|
474 |
sample = get_new_sample(task, ratio)
|
475 |
mask_count_text = ""
|
|
|
477 |
if task == "mlm":
|
478 |
count = len(masked_tokens)
|
479 |
mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
|
480 |
+
print(f"Generated MLM sample with {count} masks at ratio {ratio}")
|
481 |
else:
|
482 |
mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
|
483 |
+
print(f"Generated NTP sample with cut ratio {ratio}")
|
484 |
|
485 |
return sample, mask_count_text, ""
|
486 |
|