Bagel-7B-Demo / data /video_utils.py
KingNish's picture
Upload 110 files
e6af450 verified
# Copyright (c) 2023 OpenGVLab
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
#
# This modified file is released under the same license.
import io
import os
import random
import re
import numpy as np
import decord
from PIL import Image
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    """Choose which frame indices to read from a video of ``vlen`` frames.

    Args:
        num_frames: Number of indices to return for the ``'rand'`` and
            ``'middle'`` strategies. When the video is shorter than this, the
            result is padded by repeating the last chosen index.
        vlen: Total number of frames in the video.
        sample: Sampling strategy.
            ``'rand'``: split ``[0, vlen)`` into ``min(num_frames, vlen)``
            contiguous intervals and pick one random frame (inclusive) from
            each.
            ``'middle'``: same intervals; pick the middle frame of each, or
            ``interval_start + fix_start`` when ``fix_start`` is given.
            ``'fps<X>'`` (e.g. ``'fps0.5'``): sample at ``X`` frames per second,
            one frame centred in each ``1 / X``-second window.
        fix_start: Offset from each interval start; only used with ``'middle'``.
        input_fps: Frame rate of the source video (used by ``'fps<X>'``).
        max_num_frames: If positive, truncate the ``'fps<X>'`` result to this
            many indices.

    Returns:
        A list of frame indices in non-decreasing order.

    Raises:
        ValueError: if ``sample`` is not a recognised strategy.
    """
    if sample in ['rand', 'middle']:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # Split the video into `acc_samples` intervals and sample from each one.
        # Because acc_samples <= vlen, consecutive boundaries differ by >= 1, so
        # every interval [start, end] (end inclusive) holds at least one frame.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = list(zip(intervals[:-1], intervals[1:] - 1))
        if sample == 'rand':
            # randint is inclusive on both ends, so the last frame of each
            # interval is eligible and width-1 intervals need no fallback.
            frame_indices = [random.randint(start, end) for start, end in ranges]
        elif fix_start is not None:
            frame_indices = [start + fix_start for start, _ in ranges]
        else:  # sample == 'middle'
            frame_indices = [(start + end) // 2 for start, end in ranges]

        if len(frame_indices) < num_frames:  # pad with the last frame
            frame_indices = frame_indices + [frame_indices[-1]] * (num_frames - len(frame_indices))
    elif 'fps' in sample:  # e.g. fps0.5: sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices
def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, clip=None, min_num_frames=4):
    """Decode a randomly sized subset of frames from a video file with decord.

    The number of frames is drawn with ``np.random.randint(min_num_frames,
    num_frames + 1)``. ``get_frame_indices`` then chooses which frames to read.

    Args:
        video_path: Path to the video file, passed to ``decord.VideoReader``.
        num_frames: Upper bound (inclusive) on the number of frames drawn.
        sample: Sampling strategy forwarded to ``get_frame_indices``.
        fix_start: Forwarded to ``get_frame_indices``.
        clip: Optional ``(start, end)`` pair in seconds. Sampling is then
            restricted to that window, clamped to the frames the video has.
        min_num_frames: Lower bound (inclusive) on the number of frames drawn.

    Returns:
        List of ``PIL.Image.Image`` frames.
    """
    video_reader = decord.VideoReader(video_path, num_threads=1)
    total_frames = len(video_reader)
    fps = video_reader.get_avg_fps()

    vlen = total_frames
    start_index = 0
    if clip:
        start, end = clip
        start_index = int(start * fps)
        # Clamp to the frames that exist after start_index. Otherwise a clip
        # whose end runs past the video would yield out-of-range indices.
        vlen = min(int((end - start) * fps), total_frames - start_index)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps
    )
    if clip:
        # Indices were computed relative to the clip start; shift them back.
        frame_indices = [f + start_index for f in frame_indices]
    frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
    return [Image.fromarray(frame) for frame in frames]
def extract_frame_number(filename):
    """Return the integer ``N`` from a file name ending in ``_N.jpg``.

    Args:
        filename: A file name (not a full path), e.g. ``"frame_0012.jpg"``.

    Returns:
        The parsed frame number, or ``-1`` if the name does not end in
        ``_<digits>.jpg``.
    """
    # The dot is escaped so that only a literal ".jpg" suffix matches.
    match = re.search(r'_(\d+)\.jpg$', filename)
    return int(match.group(1)) if match else -1
def sort_frames(frame_paths):
    """Return ``frame_paths`` ordered by the frame number in each file name.

    Only the base name of each path is inspected, via ``extract_frame_number``.
    Names it cannot parse get the key ``-1`` and therefore sort first. The sort
    is stable, so ties keep their input order.
    """
    def frame_key(path):
        return extract_frame_number(os.path.basename(path))

    return sorted(frame_paths, key=frame_key)
def read_frames_folder(video_path, num_frames, sample='rand', fix_start=None, min_num_frames=4):
    """Load frames stored as individual image files in a folder.

    The folder's entries are ordered with ``sort_frames``. A target count is
    drawn with ``np.random.randint(min_num_frames, num_frames + 1)``. If the
    folder holds more entries than that, ``get_frame_indices`` picks which ones
    to keep; otherwise every entry is loaded.

    Args:
        video_path: Directory containing the frame images.
        num_frames: Upper bound (inclusive) on the number of frames drawn.
        sample: Sampling strategy forwarded to ``get_frame_indices``.
        fix_start: Forwarded to ``get_frame_indices``.
        min_num_frames: Lower bound (inclusive) on the number of frames drawn.

    Returns:
        List of RGB ``PIL.Image.Image`` frames.
    """
    image_list = sort_frames(list(os.listdir(video_path)))
    vlen = len(image_list)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)

    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        # Select file names before decoding, so unselected images are never opened.
        image_list = [image_list[i] for i in frame_indices]

    frames = []
    for image in image_list:
        # The context manager closes the file; convert() returns an
        # independent, fully loaded copy.
        with Image.open(os.path.join(video_path, image)) as img:
            frames.append(img.convert('RGB'))
    return frames
class FrameSampler:
    """Callable that reads frames from either a video file or a frame folder.

    A path ending in ``'/'`` is read with ``read_frames_folder``. Any other path
    is read with ``read_frames_decord``. The stored settings are forwarded to
    the chosen reader.

    NOTE(review): with the defaults (``max_num_frames=-1``,
    ``min_num_frames=8``), the readers call ``np.random.randint(8, 0)``. That
    appears to be invalid (low >= high), so confirm that callers always pass
    ``max_num_frames``.
    """

    def __init__(self, max_num_frames=-1, min_num_frames=8, sample='rand'):
        # Forwarded as the reader's `sample` argument.
        self.sample = sample
        # Inclusive bounds for the number of frames drawn per call.
        self.min_num_frames = min_num_frames
        self.max_num_frames = max_num_frames

    def __call__(self, file_name):
        if file_name.endswith('/'):
            reader = read_frames_folder
        else:
            reader = read_frames_decord
        return reader(
            file_name,
            num_frames=self.max_num_frames,
            min_num_frames=self.min_num_frames,
            sample=self.sample,
        )
def decode_video_byte(video_bytes):
    """Open an in-memory encoded video as a ``decord.VideoReader``.

    Args:
        video_bytes: Raw bytes of an encoded video file.

    Returns:
        A ``decord.VideoReader`` reading from an ``io.BytesIO`` wrapper around
        ``video_bytes``.
    """
    return decord.VideoReader(io.BytesIO(video_bytes))
def sample_mp4_frames(mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False):
    """Sample frames from a video, either by count or by target frame rate.

    Args:
        mp4_p: Path to the video (``str`` or ``os.PathLike``) or an already
            opened ``decord.VideoReader``.
        n_frames: If given, the number of frames to sample. Takes precedence
            over ``fps``.
        fps: Target sampling rate in frames per second. Only used when
            ``n_frames`` is None.
        return_frame_indices: If True, also return the sampled indices.
        random_sample: With ``n_frames``, draw ``n_frames`` distinct random
            indices (then sort them) instead of evenly spaced ones.

    Returns:
        ``(frames, video_duration)``, or
        ``(frames, video_duration, frame_indices)`` when
        ``return_frame_indices`` is True. ``frames`` is a list of RGB PIL
        images and ``video_duration`` is ``len(video) / avg_fps``.

    Raises:
        TypeError: if ``mp4_p`` is neither a path nor a ``decord.VideoReader``.
        ValueError: if both ``n_frames`` and ``fps`` are None.
    """
    if n_frames is None and fps is None:
        raise ValueError("either n_frames or fps must be provided")

    if isinstance(mp4_p, (str, os.PathLike)):
        vr = decord.VideoReader(os.fspath(mp4_p), num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    else:
        # Without this branch `vr` would be unbound below.
        raise TypeError(
            f"mp4_p must be a path or decord VideoReader, got {type(mp4_p).__name__}"
        )

    num_total = len(vr)
    video_fps = vr.get_avg_fps()  # average frame rate of the video
    video_duration = num_total / video_fps
    if n_frames is not None:
        if random_sample:
            frame_indices = sorted(random.sample(range(num_total), n_frames))
        else:
            frame_indices = np.linspace(0, num_total - 1, n_frames, dtype=int).tolist()
    else:
        # arange's stop is exclusive, so use num_total (not num_total - 1).
        # Then the final frame is eligible and a 1-frame video yields [0].
        frame_indices = [int(i) for i in np.arange(0, num_total, video_fps / fps)]
    frames = vr.get_batch(frame_indices).asnumpy()  # decoded batch as a numpy array
    frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
    if not return_frame_indices:
        return frames, video_duration
    return frames, video_duration, frame_indices
def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
    """Decode the frames at ``frame_indices`` from a video.

    Args:
        mp4_p: Path to the video (``str`` or ``os.PathLike``) or an already
            opened ``decord.VideoReader``.
        frame_indices: Indices of the frames to decode, passed to
            ``VideoReader.get_batch``.

    Returns:
        List of RGB PIL images, one per decoded frame returned by ``get_batch``.

    Raises:
        TypeError: if ``mp4_p`` is neither a path nor a ``decord.VideoReader``.
    """
    if isinstance(mp4_p, (str, os.PathLike)):
        vr = decord.VideoReader(os.fspath(mp4_p), num_threads=1)
    elif isinstance(mp4_p, decord.video_reader.VideoReader):
        vr = mp4_p
    else:
        # Without this branch `vr` would be unbound below.
        raise TypeError(
            f"mp4_p must be a path or decord VideoReader, got {type(mp4_p).__name__}"
        )
    # Sample the frames in frame_indices and convert the batch to a numpy array.
    frames = vr.get_batch(frame_indices).asnumpy()
    return [Image.fromarray(frame).convert("RGB") for frame in frames]