Building a text-to-image model with SDXL

In this example, we build a Truss that serves a text-to-image model: Stable Diffusion XL 1.0, using the base model together with the optional refiner.
Set up imports and torch settings

import base64
import time
from io import BytesIO
from typing import Any

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

# TF32 matmuls trade a little precision for speed on Ampere and newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
Define the Model class and load function

In the load function of the Truss, we implement the logic involved in downloading and setting up the model. For this model, we use the DiffusionPipeline class in diffusers to instantiate our SDXL pipeline, and configure a number of relevant parameters. See the diffusers docs for details on all of these parameters.
class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        # Use the fp16-safe VAE to avoid numerical issues when decoding in half precision.
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        )
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        self.pipe.unet.to(memory_format=torch.channels_last)
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        # The refiner shares the base pipeline's second text encoder and VAE.
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipe.text_encoder_2,
            vae=self.pipe.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner.to("cuda")
        self.refiner.enable_xformers_memory_efficient_attention()

    def convert_to_b64(self, image: Image.Image) -> str:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_b64
Run model inference

The predict function contains the actual inference logic. The steps here are:

- Setting up the generation parameters, some of which, like the scheduler, are somewhat complicated
- Running the base model
- If use_refiner is set to True, running the refiner model on the output

    def predict(self, model_input: Any) -> Any:
        prompt = model_input.pop("prompt")
        negative_prompt = model_input.pop("negative_prompt", None)
        use_refiner = model_input.pop("use_refiner", True)
        num_inference_steps = model_input.pop("num_inference_steps", 30)
        denoising_frac = model_input.pop("denoising_frac", 0.8)
        end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
        guidance_scale = model_input.pop("guidance_scale", 7.5)
        seed = model_input.pop("seed", None)
        scheduler = model_input.pop(
            "scheduler", None
        )  # Default: EulerDiscreteScheduler (works pretty well)
        if scheduler == "DPM++ 2M":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )
        elif scheduler == "DPM++ 2M Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config, use_karras_sigmas=True
            )
        elif scheduler == "DPM++ 2M SDE Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True,
            )
        generator = None
        if seed is not None:
            torch.manual_seed(seed)
            generator = [torch.Generator(device="cuda").manual_seed(seed)]
        if not use_refiner:
            # Without the refiner, the base model handles the full denoising schedule.
            denoising_frac = 1.0
        start_time = time.time()
        # Note: end_cfg is supported by the pinned diffusers fork in config.yaml,
        # not by stock diffusers.
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            end_cfg=end_cfg_frac,
            num_inference_steps=num_inference_steps,
            denoising_end=denoising_frac,
            guidance_scale=guidance_scale,
            output_type="latent" if use_refiner else "pil",
        ).images[0]
        scheduler = self.pipe.scheduler
        if use_refiner:
            # Hand the base model's latents to the refiner for the final steps.
            self.refiner.scheduler = scheduler
            image = self.refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                generator=generator,
                end_cfg=end_cfg_frac,
                num_inference_steps=num_inference_steps,
                denoising_start=denoising_frac,
                guidance_scale=guidance_scale,
                image=image[None, :],
            ).images[0]
        b64_results = self.convert_to_b64(image)
        end_time = time.time() - start_time
        print(f"Time: {end_time:.2f} seconds")
        return {"status": "success", "data": b64_results, "time": end_time}
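Before deploying, you can exercise the class directly with a sketch like the following. It assumes a CUDA-capable machine with the dependencies from config.yaml installed and the class above saved as model.py; the file names here are illustrative, not part of the Truss:

# Hypothetical local smoke test; not part of the Truss itself.
import base64

from model import Model  # assumes the class above lives in model.py

model = Model()
model.load()  # downloads and initializes the base and refiner pipelines

result = model.predict(
    {
        "prompt": "A tree in a field under the night sky",
        "use_refiner": True,
        "seed": 42,
    }
)

# The image comes back as a base64-encoded JPEG; write it to disk to inspect it.
with open("output.jpg", "wb") as f:
    f.write(base64.b64decode(result["data"]))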
Setting up the config.yaml

Running SDXL requires a handful of Python libraries, including a pinned fork of diffusers (which supports the end_cfg parameter used above), and an A10G GPU.

environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "A tree in a field under the night sky", "use_refiner": true}
model_name: Stable Diffusion XL
python_version: py39
requirements:
  - transformers==4.34.0
  - accelerate==0.23.0
  - safetensors==0.4.0
  - git+https://github.com/basetenlabs/diffusers.git@9a353290b1497023d4745a719ec02c50f680499a
  - invisible-watermark>=0.2.0
  - xformers==0.0.22
resources:
  accelerator: A10G
  cpu: 3500m
  memory: 20Gi
  use_gpu: true
secrets: {}
Running diffusers requires ffmpeg and a couple of other system packages.

system_packages:
  - ffmpeg
  - libsm6
  - libxext6
SDXL's weights total several gigabytes, so we use model_cache to cache them when the model image is built, which reduces cold start times.

model_cache:
  - repo_id: madebyollin/sdxl-vae-fp16-fix
    allow_patterns:
      - config.json
      - diffusion_pytorch_model.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_base_1.0.safetensors
  - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
    allow_patterns:
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_refiner_1.0.safetensors
Deploy the model

Deploy the model like you would any other Truss:

$ truss push

You can then invoke the model with:

$ truss predict -d '{"prompt": "A tree in a field under the night sky", "use_refiner": true}'
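The response holds the generated image as a base64-encoded JPEG under the data key. A small helper like the sketch below (the show.py name is an assumption for illustration) turns the JSON printed by truss predict into an image file:

# show.py -- hypothetical helper; pipe the output of `truss predict` into it:
#   truss predict -d '{"prompt": "..."}' | python show.py
import base64
import json
import sys

response = json.load(sys.stdin)

# Decode the base64 payload and write it out as a JPEG.
with open("output.jpg", "wb") as f:
    f.write(base64.b64decode(response["data"]))
print("Saved output.jpg")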
For reference, here is the complete model.py:

import base64
import time
from io import BytesIO
from typing import Any

import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image

torch.backends.cuda.matmul.allow_tf32 = True


class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
        )
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        self.pipe.unet.to(memory_format=torch.channels_last)
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipe.text_encoder_2,
            vae=self.pipe.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner.to("cuda")
        self.refiner.enable_xformers_memory_efficient_attention()

    def convert_to_b64(self, image: Image.Image) -> str:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_b64

    def predict(self, model_input: Any) -> Any:
        prompt = model_input.pop("prompt")
        negative_prompt = model_input.pop("negative_prompt", None)
        use_refiner = model_input.pop("use_refiner", True)
        num_inference_steps = model_input.pop("num_inference_steps", 30)
        denoising_frac = model_input.pop("denoising_frac", 0.8)
        end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
        guidance_scale = model_input.pop("guidance_scale", 7.5)
        seed = model_input.pop("seed", None)
        scheduler = model_input.pop(
            "scheduler", None
        )  # Default: EulerDiscreteScheduler (works pretty well)
        if scheduler == "DPM++ 2M":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config
            )
        elif scheduler == "DPM++ 2M Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config, use_karras_sigmas=True
            )
        elif scheduler == "DPM++ 2M SDE Karras":
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
                self.pipe.scheduler.config,
                algorithm_type="sde-dpmsolver++",
                use_karras_sigmas=True,
            )
        generator = None
        if seed is not None:
            torch.manual_seed(seed)
            generator = [torch.Generator(device="cuda").manual_seed(seed)]
        if not use_refiner:
            denoising_frac = 1.0
        start_time = time.time()
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=generator,
            end_cfg=end_cfg_frac,
            num_inference_steps=num_inference_steps,
            denoising_end=denoising_frac,
            guidance_scale=guidance_scale,
            output_type="latent" if use_refiner else "pil",
        ).images[0]
        scheduler = self.pipe.scheduler
        if use_refiner:
            self.refiner.scheduler = scheduler
            image = self.refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                generator=generator,
                end_cfg=end_cfg_frac,
                num_inference_steps=num_inference_steps,
                denoising_start=denoising_frac,
                guidance_scale=guidance_scale,
                image=image[None, :],
            ).images[0]
        b64_results = self.convert_to_b64(image)
        end_time = time.time() - start_time
        print(f"Time: {end_time:.2f} seconds")
        return {"status": "success", "data": b64_results, "time": end_time}