# Spaces: Paused
from pathlib import Path

from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
class ImageGenerator:
    """Text-to-image generator built on a Stable Diffusion pipeline.

    Every prompt passed to :meth:`generate` is suffixed with a fixed positive
    prompt, and a default negative prompt is used unless the caller supplies
    one.
    """

    def __init__(self, model_id="stabilityai/stable-diffusion-2-1-base", device="cuda"):
        """
        Initialize the image generator with a specific model.

        Args:
            model_id (str): The model identifier for the stable diffusion model.
                Default is "stabilityai/stable-diffusion-2-1-base".
            device (str): The device to run the model on, either "cuda" or "cpu".
                Default is "cuda".
        """
        scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
        # Half precision is kept for accelerators; on CPU fall back to float32,
        # since float16 inference is not reliably supported there.
        dtype = torch.float32 if torch.device(device).type == "cpu" else torch.float16
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            scheduler=scheduler,
            torch_dtype=dtype,
        )
        self.pipe = self.pipe.to(device)
        self.positive_prompt = "simple, icon"
        self.negative_prompt = "3d, blurry, complex geometry, realistic"

    @staticmethod
    def _numbered_path(output_path, index):
        """Return the path under ``.cache`` for the ``index``-th generated image.

        ``"sheet.png"`` with index 0 becomes ``.cache/sheet_0.png``. The index is
        inserted before the file extension whatever that extension is, and
        ``.png`` is appended when ``output_path`` has no extension.
        """
        # Build the path by string concatenation (not the "/" operator) so that
        # an absolute output_path is still nested under .cache.
        target = Path(f".cache/{output_path}")
        suffix = target.suffix or ".png"
        return target.with_name(f"{target.stem}_{index}{suffix}")

    def generate(self, prompt, negative_prompt=None, output_path=None, num_images=1, num_inference_steps=50):
        """
        Generate images based on the provided prompt.

        Args:
            prompt (str): The text description to generate an image from.
                The default positive prompt is appended to it.
            negative_prompt (str, optional): Elements to avoid in the generated image.
                If None, uses the default negative prompt.
            output_path (str, optional): File name (relative to ``.cache/``) used
                to save the generated images; the image index is inserted before
                the extension (``sheet.png`` -> ``.cache/sheet_0.png``).
                If None, the images are not saved to disk.
            num_images (int, optional): Number of images to generate.
            num_inference_steps (int, optional): Number of denoising steps passed
                to the pipeline. Default is 50.

        Returns:
            list[PIL.Image.Image]: The generated images.
        """
        full_prompt = f"{prompt}, {self.positive_prompt}"
        if negative_prompt is None:
            negative_prompt = self.negative_prompt

        images = self.pipe(
            full_prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images,
        ).images

        if output_path:
            for index, image in enumerate(images):
                target = self._numbered_path(output_path, index)
                # The cache directory may not exist yet on a fresh checkout.
                target.parent.mkdir(parents=True, exist_ok=True)
                image.save(target)
        return images
# Example usage: generate four images for one prompt and report the wall-clock time.
if __name__ == "__main__":
    import time

    generator = ImageGenerator()

    # Only the generation call is timed; model loading above is excluded.
    started = time.time()
    images = generator.generate(
        prompt="magenta trapezoids layered on a transluscent silver sheet",
        output_path="sheet.png",
        num_images=4,
    )
    elapsed = time.time() - started
    print(f"Time taken: {elapsed} seconds")