# Copyright 2025 Bytedance Ltd. and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 import math import random from PIL import Image import torch from torch.nn.attention.flex_attention import or_masks, and_masks def create_sparse_mask(document_lens, split_lens, attn_modes, device): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx def full_and_noise_mask(b, h, q_idx, kv_idx): return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) def remove_noise_mask(b, h, q_idx, kv_idx): return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx]))) def sample_mask(b, h, q_idx, kv_idx): return document_id[q_idx] == document_id[kv_idx] full_and_noise_tmp = [] noise_tmp = [] for i, (length, model) in enumerate(zip(split_lens, attn_modes)): value = i if model in ['full', 'noise'] else -1 full_and_noise_tmp.extend([value] * length) value_noise = i if model == 'noise' else -1 noise_tmp.extend([value_noise] * length) full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device) noise_seq_id = torch.Tensor(noise_tmp).to(device) document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask) def patchify(image, patch_size): p = patch_size c, h, w = image.shape assert h % p == 0 and w % p == 0 image = image.reshape(c, h // p, p, w // p, p) image = torch.einsum("chpwq->hwpqc", image) image = image.reshape(-1, p**2 * c) return image def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side): num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size coords_h = torch.arange(0, num_patches_h) coords_w = torch.arange(0, num_patches_w) pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() return pos_ids def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side): num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side) fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten() return pos_ids def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"): """ nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within a sample, where each sample contains multiple splits with different attn modes. nested_attn_modes: whether to use full attn in each split. """ sample_len = sum(split_lens) attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device) csum = 0 for s, attn_mode in zip(split_lens, attn_modes): assert attn_mode in ['causal', 'full', 'noise'] if attn_mode == "causal": attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril() attention_mask[csum:csum + s, :csum] = 1 else: attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s)) attention_mask[csum:csum + s, :csum] = 1 csum += s csum = 0 for s, attn_mode in zip(split_lens, attn_modes): if attn_mode == "noise": attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s)) attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s)) csum += s attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_( ~attention_mask, float("-inf") ) return attention_mask def split_integer_exp_decay(S, ng_sample_decay=1.0): if ng_sample_decay == 1.0: N = random.randint(1, S) else: base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S)) p = [base * math.pow(ng_sample_decay, i) for i in range(S)] N = random.choices(list(range(1, S + 1)), p, k=1)[0] cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S] result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)] return result, cumsum def pil_img2rgb(image): if image.mode == "RGBA" or image.info.get("transparency", None) is not None: image = image.convert("RGBA") white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255)) white.paste(image, mask=image.split()[3]) image = white else: image = image.convert("RGB") return image def add_special_tokens(tokenizer): all_special_tokens = [] for k, v in tokenizer.special_tokens_map.items(): if isinstance(v, str): all_special_tokens.append(v) elif isinstance(v, list): all_special_tokens += v new_tokens = [] if '<|im_start|>' not in all_special_tokens: new_tokens.append('<|im_start|>') if '<|im_end|>' not in all_special_tokens: new_tokens.append('<|im_end|>') if '<|vision_start|>' not in all_special_tokens: new_tokens.append('<|vision_start|>') if '<|vision_end|>' not in all_special_tokens: new_tokens.append('<|vision_end|>') num_new_tokens = tokenizer.add_tokens(new_tokens) bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>') eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>') start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>') end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>') new_token_ids = dict( bos_token_id=bos_token_id, eos_token_id=eos_token_id, start_of_image=start_of_image, end_of_image=end_of_image, ) return tokenizer, new_token_ids, num_new_tokens def len2weight(x, loss_reduction='square'): if x == 0: return x if loss_reduction == 'token': return 1 if loss_reduction == 'sample': return 1 / x if loss_reduction == 'square': return 1 / (x ** 0.5) raise NotImplementedError(loss_reduction)