diff --git a/demo_gradio.py b/demo_gradio.py index 687cf56..971e0fa 100644 --- a/demo_gradio.py +++ b/demo_gradio.py @@ -100,7 +100,7 @@ os.makedirs(outputs_folder, exist_ok=True) @torch.no_grad() -def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache): +def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf): total_latent_sections = (total_second_length * 30) / (latent_window_size * 4) total_latent_sections = int(max(round(total_latent_sections), 1)) @@ -295,7 +295,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4') - save_bcthw_as_mp4(history_pixels, output_filename, fps=30) + save_bcthw_as_mp4(history_pixels, output_filename, fps=30, crf=mp4_crf) print(f'Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}') @@ -315,7 +315,7 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind return -def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache): +def process(input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf): global stream assert input_image is not None, 'No input image!' @@ -323,7 +323,7 @@ def process(input_image, prompt, n_prompt, seed, total_second_length, latent_win stream = AsyncStream() - async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache) + async_run(worker, input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf) output_filename = None @@ -385,13 +385,15 @@ with block: gpu_memory_preservation = gr.Slider(label="GPU Inference Preserved Memory (GB) (larger means slower)", minimum=6, maximum=128, value=6, step=0.1, info="Set this number to a larger value if you encounter OOM. Larger value causes slower speed.") + mp4_crf = gr.Slider(label="MP4 Compression", minimum=0, maximum=100, value=16, step=1, info="Lower means better quality. 0 is uncompressed. Change to 16 if you get black outputs. ") + with gr.Column(): preview_image = gr.Image(label="Next Latents", height=200, visible=False) result_video = gr.Video(label="Finished Frames", autoplay=True, show_share_button=False, height=512, loop=True) gr.Markdown('Note that the ending actions will be generated before the starting actions due to the inverted sampling. If the starting action is not in the video, you just need to wait, and it will be generated later.') progress_desc = gr.Markdown('', elem_classes='no-generating-animation') progress_bar = gr.HTML('', elem_classes='no-generating-animation') - ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache] + ips = [input_image, prompt, n_prompt, seed, total_second_length, latent_window_size, steps, cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf] start_button.click(fn=process, inputs=ips, outputs=[result_video, preview_image, progress_desc, progress_bar, start_button, end_button]) end_button.click(fn=end_process) diff --git a/diffusers_helper/utils.py b/diffusers_helper/utils.py index 0b2accc..8cd7a0c 100644 --- a/diffusers_helper/utils.py +++ b/diffusers_helper/utils.py @@ -263,7 +263,7 @@ def soft_append_bcthw(history, current, overlap=0): return output.to(history) -def save_bcthw_as_mp4(x, output_filename, fps=10): +def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0): b, c, t, h, w = x.shape per_row = b @@ -276,7 +276,7 @@ def save_bcthw_as_mp4(x, output_filename, fps=10): x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5 x = x.detach().cpu().to(torch.uint8) x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row) - torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': '0'}) + torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))}) return x