Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,745 Bytes
e6af450 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import random
from PIL import Image, ImageFile, PngImagePlugin
from .interleave_t2i_dataset import InterleavedBaseIterableDataset, ParquetStandardIterableDataset
from ..data_utils import pil_img2rgb
# Pillow limits, configured once at import time for every image this module decodes.

# Pixel-count threshold used by Pillow's decompression-bomb check.
Image.MAX_IMAGE_PIXELS = 200_000_000
# Let Pillow load image files whose data stream is truncated instead of failing.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Cap on the decompressed size of a PNG text chunk: 1024 MiB, expressed in bytes.
MegaByte = 1 << 20
MaximumDecompressedSize = 1024
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
class UnifiedEditIterableDataset(InterleavedBaseIterableDataset, ParquetStandardIterableDataset):
    """Dataset that converts a row of images plus edit instructions into one sample.

    Each row is expected to carry ``row["image_list"]`` (encoded image bytes)
    and ``row["instruction_list"]``, where entry ``i - 1`` holds the candidate
    instructions that lead from image ``i - 1`` to image ``i``.
    """

    def parse_row(self, row):
        """Build an interleaved image/text sample from a random window of ``row``.

        A start image is drawn from every image except the last, and the end
        image lies one or two positions after it. The start image is added
        with ``need_loss=False, need_vae=True, need_vit=True``; the final image
        is always added with ``need_loss=True, need_vae=False, need_vit=False``.

        For a two-step window, half of the time the instructions are joined
        into a single text followed only by the final image. Otherwise every
        step contributes its own instruction and image.

        Args:
            row: Mapping with ``"image_list"`` and ``"instruction_list"``.

        Returns:
            The sample object created by ``self._init_data()`` and extended by
            ``self._add_image`` / ``self._add_text``.

        Raises:
            IndexError: If the row holds fewer than two images, because
                ``random.choice`` is then given an empty range.
        """
        image_num = len(row["image_list"])

        def load_rgb(index):
            # Decode the encoded bytes stored at ``index`` and pass them through pil_img2rgb.
            return pil_img2rgb(Image.open(io.BytesIO(row["image_list"][index])))

        # Pick the window; with exactly two images this always yields (0, 1).
        start_idx = random.choice(range(image_num - 1))
        end_candidates = range(start_idx + 1, min(start_idx + 3, image_num))
        end_idx = random.choice(end_candidates)

        data = self._init_data()
        data = self._add_image(
            data,
            load_rgb(start_idx),
            need_loss=False,
            need_vae=True,
            need_vit=True,
        )

        spans_two_steps = end_idx - start_idx > 1
        # The coin is only flipped for two-step windows (short-circuit evaluation).
        if spans_two_steps and random.random() < 0.5:
            # Concatenate multiple instructions into a single prompt.
            # NOTE(review): a window ending on the row's last image is shortened
            # by one step here; the reason is not evident from this file.
            if end_idx == image_num - 1:
                end_idx -= 1

            pieces = [
                random.choice(row["instruction_list"][idx - 1]) + ". "
                for idx in range(start_idx + 1, end_idx + 1)
            ]
            merged_instruction = "".join(pieces).rstrip()
            data = self._add_text(data, merged_instruction, need_loss=False)
            data = self._add_image(
                data,
                load_rgb(end_idx),
                need_loss=True,
                need_vae=False,
                need_vit=False,
            )
            return data

        # Step-by-step mode: one instruction followed by one image per step.
        for idx in range(start_idx + 1, end_idx + 1):
            step_instruction = random.choice(row["instruction_list"][idx - 1])
            data = self._add_text(data, step_instruction, need_loss=False)
            # Intermediate images are added with need_vae/need_vit enabled;
            # the final image of the window is added with both disabled.
            is_intermediate = idx != end_idx
            data = self._add_image(
                data,
                load_rgb(idx),
                need_loss=True,
                need_vae=is_intermediate,
                need_vit=is_intermediate,
            )
        return data
|