Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2025 Bytedance Ltd. and/or its affiliates. | |
# SPDX-License-Identifier: Apache-2.0 | |
import math | |
import random | |
from PIL import Image | |
import torch | |
from torch.nn.attention.flex_attention import or_masks, and_masks | |
def create_sparse_mask(document_lens, split_lens, attn_modes, device): | |
def causal_mask(b, h, q_idx, kv_idx): | |
return q_idx >= kv_idx | |
def full_and_noise_mask(b, h, q_idx, kv_idx): | |
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) | |
def remove_noise_mask(b, h, q_idx, kv_idx): | |
return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))) | |
def sample_mask(b, h, q_idx, kv_idx): | |
return document_id[q_idx] == document_id[kv_idx] | |
full_and_noise_tmp = [] | |
noise_tmp = [] | |
for i, (length, model) in enumerate(zip(split_lens, attn_modes)): | |
value = i if model in ['full', 'noise'] else -1 | |
full_and_noise_tmp.extend([value] * length) | |
value_noise = i if model == 'noise' else -1 | |
noise_tmp.extend([value_noise] * length) | |
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) | |
noise_seq_id = torch.Tensor(noise_tmp).to(device) | |
document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) | |
return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) | |
def patchify(image, patch_size):
    """Cut a ``(C, H, W)`` tensor into flattened, non-overlapping square patches.

    Patches are emitted in row-major order over the patch grid. Each patch is
    flattened in ``(patch_row, patch_col, channel)`` order.

    Args:
        image: tensor with exactly three dims ``(C, H, W)``.
        patch_size: side length ``p`` of each square patch; must divide H and W.

    Returns:
        Tensor of shape ``((H // p) * (W // p), p * p * C)``.
    """
    channels, height, width = image.shape
    p = patch_size
    assert height % p == 0 and width % p == 0
    grid_h, grid_w = height // p, width // p
    # (C, gh, p, gw, q) -> (gh, gw, p, q, C)
    tiles = image.reshape(channels, grid_h, p, grid_w, p).permute(1, 3, 2, 4, 0)
    return tiles.reshape(-1, p * p * channels)
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major position ids for an image's patch grid.

    Each patch ``(row, col)`` gets id ``row * max_num_patches_per_side + col``.
    The grid is therefore laid out as if it were ``max_num_patches_per_side``
    patches wide, regardless of the image's actual width.

    Args:
        img_h, img_w: image size in pixels (floor-divided by ``patch_size``).
        patch_size: patch side length in pixels.
        max_num_patches_per_side: row stride used for the ids.

    Returns:
        1-D integer tensor of length ``(img_h // patch_size) * (img_w // patch_size)``.
    """
    rows = torch.arange(0, img_h // patch_size)
    cols = torch.arange(0, img_w // patch_size)
    row_offsets = rows.unsqueeze(1) * max_num_patches_per_side
    return (row_offsets + cols).flatten()
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return position ids that rescale the patch grid onto a fixed-size grid.

    Each axis's patch index is mapped to a fraction in ``[0, 1)``. It is then
    bucketed into one of ``max_num_patches_per_side`` equal bins with
    ``torch.bucketize(..., right=True)``. The flattened id is
    ``row_bin * max_num_patches_per_side + col_bin``.

    Args:
        img_h, img_w: image size in pixels (floor-divided by ``patch_size``).
        patch_size: patch side length in pixels.
        max_num_patches_per_side: number of bins per axis.

    Returns:
        1-D integer tensor of length ``(img_h // patch_size) * (img_w // patch_size)``.
    """
    step = 1 / max_num_patches_per_side
    # Inner bin edges step, 2*step, ... (float arange, upper bound exclusive).
    boundaries = torch.arange(step, 1.0, step)

    def _bucket(num_patches):
        # The 1 - 1e-6 bound keeps the arange to num_patches samples.
        fractional = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractional, boundaries, right=True)

    row_bins = _bucket(img_h // patch_size)
    col_bins = _bucket(img_w // patch_size)
    return (row_bins[:, None] * max_num_patches_per_side + col_bins).flatten()
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """Build a dense additive attention mask for a single packed sample.

    The sample is a concatenation of splits. Every query row may attend to all
    keys in earlier splits. Within its own split, a row attends causally
    (``"causal"``) or bidirectionally (``"full"`` / ``"noise"``). Afterwards,
    the key columns of each ``"noise"`` split are hidden from every row outside
    that split, so only the split itself can see its own positions.

    Args:
        split_lens: length of each consecutive split in the sample.
        attn_modes: attention mode per split: ``"causal"``, ``"full"`` or
            ``"noise"``. It is paired with ``split_lens`` via ``zip``, so
            unpaired trailing entries are ignored. Rows of an unpaired
            trailing split stay fully masked.
        device: device on which the mask is allocated.

    Returns:
        Float tensor of shape ``(sum(split_lens), sum(split_lens))`` holding
        ``0`` where attention is allowed and ``-inf`` where it is blocked.

    Raises:
        ValueError: if an entry of ``attn_modes`` is not a supported mode.
    """
    sample_len = sum(split_lens)
    allowed = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)

    csum = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        # Explicit check rather than `assert`, so it survives `python -O`.
        if attn_mode not in ('causal', 'full', 'noise'):
            raise ValueError(f"unsupported attn_mode: {attn_mode!r}")
        end = csum + s
        if attn_mode == "causal":
            allowed[csum:end, csum:end] = torch.ones((s, s), dtype=torch.bool, device=device).tril()
        else:
            allowed[csum:end, csum:end] = True
        # Every split sees the whole prefix that precedes it.
        allowed[csum:end, :csum] = True
        csum = end

    # Second pass: noise columns are visible only from inside their own split.
    csum = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        end = csum + s
        if attn_mode == "noise":
            allowed[:, csum:end] = False
            allowed[csum:end, csum:end] = True
        csum = end

    return torch.zeros_like(allowed, dtype=torch.float).masked_fill_(~allowed, float("-inf"))
def split_integer_exp_decay(S, ng_sample_decay=1.0):
    """Randomly split the integer ``S`` into positive parts.

    The number of parts ``N`` is drawn first:
    * uniformly from ``1..S`` when ``ng_sample_decay == 1.0``;
    * otherwise with weight proportional to ``ng_sample_decay ** (N - 1)``.

    ``N - 1`` distinct cut points are then sampled from ``1..S-1``.
    NOTE(review): assumes ``S >= 1``; ``S == 0`` raises inside ``random``, or
    divides by zero in the weighted branch.

    Args:
        S: total to split.
        ng_sample_decay: geometric decay of the weights over ``N``.

    Returns:
        ``(parts, cumsum)``. ``parts`` are the positive part sizes summing to
        ``S``. ``cumsum`` is the boundary list ``[0, ..., S]`` with
        ``len(parts) + 1`` entries.
    """
    if ng_sample_decay != 1.0:
        norm = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
        weights = [norm * math.pow(ng_sample_decay, k) for k in range(S)]
        num_parts = random.choices(list(range(1, S + 1)), weights, k=1)[0]
    else:
        num_parts = random.randint(1, S)

    cut_points = random.sample(range(1, S), num_parts - 1)
    cumsum = [0, *sorted(cut_points), S]
    parts = [hi - lo for lo, hi in zip(cumsum, cumsum[1:])]
    return parts, cumsum
def pil_img2rgb(image):
    """Convert a PIL image to RGB, compositing any transparency onto white.

    Transparent images are converted to RGBA and pasted onto an opaque white
    canvas, with their alpha channel as the paste mask. An image counts as
    transparent if it has an alpha band (e.g. modes "RGBA", "LA", "PA") or
    declares a ``transparency`` entry in ``image.info``. As a result,
    transparent pixels become white rather than whatever colour is stored
    under them. All other images are converted with ``convert("RGB")``.

    Args:
        image: a ``PIL.Image.Image``.

    Returns:
        A new ``PIL.Image.Image`` in mode "RGB".
    """
    # Checking the band names, not just mode == "RGBA", also catches "LA"/"PA".
    # A plain convert("RGB") would silently drop their alpha.
    has_alpha = "A" in image.getbands()
    if has_alpha or image.info.get("transparency") is not None:
        rgba = image.convert("RGBA")
        canvas = Image.new(mode="RGB", size=rgba.size, color=(255, 255, 255))
        canvas.paste(rgba, mask=rgba.getchannel("A"))
        return canvas
    return image.convert("RGB")
def add_special_tokens(tokenizer):
    """Ensure the chat and vision marker tokens exist, then look up their ids.

    Any of ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and
    ``<|vision_end|>`` missing from ``tokenizer.special_tokens_map`` is passed
    to ``tokenizer.add_tokens``. Only that map is checked, not the full
    vocabulary. NOTE(review): presumably a Hugging Face tokenizer; confirm
    against callers.

    Returns:
        ``(tokenizer, new_token_ids, num_new_tokens)``:
        * ``new_token_ids`` maps ``bos_token_id``, ``eos_token_id``,
          ``start_of_image`` and ``end_of_image`` to the ids returned by
          ``convert_tokens_to_ids``;
        * ``num_new_tokens`` is whatever ``add_tokens`` returned.
    """
    existing = []
    for entry in tokenizer.special_tokens_map.values():
        # Values are either a single token string or a list of them.
        if isinstance(entry, str):
            existing.append(entry)
        elif isinstance(entry, list):
            existing.extend(entry)

    markers = ('<|im_start|>', '<|im_end|>', '<|vision_start|>', '<|vision_end|>')
    missing = [tok for tok in markers if tok not in existing]
    num_new_tokens = tokenizer.add_tokens(missing)

    new_token_ids = dict(
        bos_token_id=tokenizer.convert_tokens_to_ids('<|im_start|>'),
        eos_token_id=tokenizer.convert_tokens_to_ids('<|im_end|>'),
        start_of_image=tokenizer.convert_tokens_to_ids('<|vision_start|>'),
        end_of_image=tokenizer.convert_tokens_to_ids('<|vision_end|>'),
    )
    return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction='square'):
    """Return the loss weight for a sample that contributes ``x`` tokens.

    * ``'token'``  -> ``1``
    * ``'sample'`` -> ``1 / x``
    * ``'square'`` -> ``1 / sqrt(x)``

    ``x == 0`` returns ``x`` unchanged, before ``loss_reduction`` is checked.

    Raises:
        NotImplementedError: for any other ``loss_reduction`` when ``x != 0``.
    """
    if x == 0:
        return x
    for name, weight in (
        ('token', lambda n: 1),
        ('sample', lambda n: 1 / n),
        ('square', lambda n: 1 / (n ** 0.5)),
    ):
        if loss_reduction == name:
            return weight(x)
    raise NotImplementedError(loss_reduction)