orionweller committed on
Commit
ea15511
·
verified ·
1 Parent(s): 9b671c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -4
app.py CHANGED
@@ -60,13 +60,21 @@ def prepare_mlm_sample(text, mask_ratio=0.15):
60
  global masked_indices, masked_tokens, original_text
61
 
62
  tokens = tokenizer.tokenize(text)
 
 
63
  # Only mask whole words, not special tokens or punctuation
64
  maskable_indices = [i for i, token in enumerate(tokens)
65
  if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
66
  and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
67
 
 
 
 
68
  # Calculate how many tokens to mask, but ensure at least 1 and at most 8
 
69
  num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
 
 
70
  # Randomly select indices to mask
71
  indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
72
  # Sort indices to ensure they're in order
@@ -101,15 +109,20 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
101
  # Tokenize text to ensure reasonable cutting
102
  tokens = tokenizer.tokenize(text)
103
 
 
 
 
 
104
  # Ensure we have enough tokens
105
  if len(tokens) < 5:
106
  return text, "" # Return original if too short
107
 
108
- # Calculate cutoff point (70% of tokens if cut_ratio is 0.3)
109
- # But make sure we have at least 3 tokens visible and 1 token hidden
110
  cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
111
  cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
112
 
 
 
113
  # Get the visible part
114
  visible_tokens = tokens[:cutoff]
115
 
@@ -120,15 +133,24 @@ def prepare_ntp_sample(text, cut_ratio=0.3):
120
  visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
121
  hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
122
 
 
 
 
123
  return visible_text, hidden_text
124
 
125
  def get_new_sample(task, mask_ratio=0.15):
126
  """Get a new text sample based on the task."""
127
- global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state
 
 
 
128
 
129
  # Select a random sample
130
  current_sample = random.choice(data_samples)
131
 
 
 
 
132
  if task == "mlm":
133
  # Prepare MLM sample
134
  masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
@@ -373,7 +395,7 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
373
  )
374
 
375
  with gr.Row():
376
- new_button = gr.Button("New Sample")
377
  reset_button = gr.Button("Reset Stats")
378
 
379
  # Consolidated input area - only one visible at a time
@@ -433,8 +455,21 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
433
  outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
434
  )
435
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  # Update the sample text and also update the mask count
437
  def new_sample_with_count(mask_ratio_pct, task):
 
438
  ratio = float(mask_ratio_pct) / 100.0
439
  sample = get_new_sample(task, ratio)
440
  mask_count_text = ""
@@ -442,8 +477,10 @@ with gr.Blocks(title="MLM and NTP Testing") as demo:
442
  if task == "mlm":
443
  count = len(masked_tokens)
444
  mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
 
445
  else:
446
  mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
 
447
 
448
  return sample, mask_count_text, ""
449
 
 
60
  global masked_indices, masked_tokens, original_text
61
 
62
  tokens = tokenizer.tokenize(text)
63
+ print(f"Text length: {len(text)} characters, {len(tokens)} tokens")
64
+
65
  # Only mask whole words, not special tokens or punctuation
66
  maskable_indices = [i for i, token in enumerate(tokens)
67
  if not token.startswith("##") and not token.startswith("[") and not token.endswith("]")
68
  and token not in [".", ",", "!", "?", ";", ":", "'", "\"", "-"]]
69
 
70
+ print(f"Maskable indices count: {len(maskable_indices)}")
71
+ print(f"Mask ratio: {mask_ratio}")
72
+
73
  # Calculate how many tokens to mask, but ensure at least 1 and at most 8
74
+ # Use the maskable_indices length with the ratio
75
  num_to_mask = max(1, min(8, int(len(maskable_indices) * mask_ratio)))
76
+ print(f"Number of tokens to mask: {num_to_mask}")
77
+
78
  # Randomly select indices to mask
79
  indices_to_mask = random.sample(maskable_indices, min(num_to_mask, len(maskable_indices)))
80
  # Sort indices to ensure they're in order
 
109
  # Tokenize text to ensure reasonable cutting
110
  tokens = tokenizer.tokenize(text)
111
 
112
+ # Print debug info
113
+ print(f"NTP preparation - Text length: {len(text)} characters, {len(tokens)} tokens")
114
+ print(f"Cut ratio: {cut_ratio}")
115
+
116
  # Ensure we have enough tokens
117
  if len(tokens) < 5:
118
  return text, "" # Return original if too short
119
 
120
+ # Calculate cutoff point based on the cut ratio
 
121
  cutoff = max(3, int(len(tokens) * (1 - cut_ratio)))
122
  cutoff = min(cutoff, len(tokens) - 1) # Ensure there's at least 1 token to predict
123
 
124
+ print(f"Cutoff point: {cutoff} (keeping {cutoff} tokens, cutting {len(tokens) - cutoff} tokens)")
125
+
126
  # Get the visible part
127
  visible_tokens = tokens[:cutoff]
128
 
 
133
  visible_text = tokenizer.convert_tokens_to_string(visible_tokens)
134
  hidden_text = tokenizer.convert_tokens_to_string(hidden_tokens)
135
 
136
+ print(f"Visible text length: {len(visible_text)} chars")
137
+ print(f"Hidden text length: {len(hidden_text)} chars")
138
+
139
  return visible_text, hidden_text
140
 
141
  def get_new_sample(task, mask_ratio=0.15):
142
  """Get a new text sample based on the task."""
143
+ global current_sample, masked_text, masked_indices, masked_tokens, original_text, ntp_state, current_task
144
+
145
+ # Update current task
146
+ current_task = task
147
 
148
  # Select a random sample
149
  current_sample = random.choice(data_samples)
150
 
151
+ # Print debugging info
152
+ print(f"Getting new sample for task: {task} with mask ratio: {mask_ratio}")
153
+
154
  if task == "mlm":
155
  # Prepare MLM sample
156
  masked_text, masked_indices, masked_tokens = prepare_mlm_sample(current_sample, mask_ratio)
 
395
  )
396
 
397
  with gr.Row():
398
+ new_button = gr.Button("New Sample", variant="primary")
399
  reset_button = gr.Button("Reset Stats")
400
 
401
  # Consolidated input area - only one visible at a time
 
455
  outputs=[mlm_instructions, ntp_instructions, answer_input, mask_count]
456
  )
457
 
458
+ # Update the sample text when mask ratio changes (without clicking new sample)
459
+ def update_on_ratio_change(mask_ratio_pct, task):
460
+ print(f"Ratio changed to {mask_ratio_pct}%")
461
+ # Don't generate a new sample here, just update the UI to show the effect of ratio change
462
+ return f"Current mask/cut ratio: {mask_ratio_pct}%. Click 'New Sample' to apply."
463
+
464
+ mask_ratio.change(
465
+ update_on_ratio_change,
466
+ inputs=[mask_ratio, task_radio],
467
+ outputs=[result]
468
+ )
469
+
470
  # Update the sample text and also update the mask count
471
  def new_sample_with_count(mask_ratio_pct, task):
472
+ print(f"Generating new sample with mask ratio: {mask_ratio_pct}% for task: {task}")
473
  ratio = float(mask_ratio_pct) / 100.0
474
  sample = get_new_sample(task, ratio)
475
  mask_count_text = ""
 
477
  if task == "mlm":
478
  count = len(masked_tokens)
479
  mask_count_text = f"**Number of [MASK] tokens to guess: {count}**"
480
+ print(f"Generated MLM sample with {count} masks at ratio {ratio}")
481
  else:
482
  mask_count_text = "**Next Token Prediction mode - guess one token at a time**"
483
+ print(f"Generated NTP sample with cut ratio {ratio}")
484
 
485
  return sample, mask_count_text, ""
486