diff --git a/Dockerfile.sycl b/Dockerfile.sycl new file mode 100644 index 00000000..1b855d6e --- /dev/null +++ b/Dockerfile.sycl @@ -0,0 +1,19 @@ +ARG SYCL_VERSION=2025.1.0-0 + +FROM intel/oneapi-basekit:${SYCL_VERSION}-devel-ubuntu24.04 AS build + +RUN apt-get update && apt-get install -y cmake + +WORKDIR /sd.cpp + +COPY . . + +RUN mkdir build && cd build && \ + cmake .. -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DSD_SYCL=ON -DCMAKE_BUILD_TYPE=Release && \ + cmake --build . --config Release -j$(nproc) + +FROM intel/oneapi-basekit:${SYCL_VERSION}-devel-ubuntu24.04 AS runtime + +COPY --from=build /sd.cpp/build/bin/sd /sd + +ENTRYPOINT [ "/sd" ] diff --git a/README.md b/README.md index a4585be0..451388aa 100644 --- a/README.md +++ b/README.md @@ -60,14 +60,6 @@ API and command-line option may change frequently.*** - Windows - Android (via Termux, [Local Diffusion](https://github.com/rmatif/Local-Diffusion)) -### TODO - -- [ ] More sampling methods -- [ ] Make inference faster - - The current implementation of ggml_conv_2d is slow and has high memory usage -- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) -- [ ] Implement Inpainting support - ## Usage For most users, you can download the built executable program from the latest [release](https://github.com/leejet/stable-diffusion.cpp/releases/latest). @@ -307,9 +299,6 @@ arguments: --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model --embd-dir [EMBEDDING_PATH] path to embeddings - --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings - --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir - --normalize-input normalize PHOTOMAKER input id images --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now --upscale-repeats Run the ESRGAN upscaler this many times (default 1) --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K) @@ -321,6 +310,9 @@ arguments: -i, --end-img [IMAGE] path to the end image, required by flf2v --control-image [IMAGE] path to image condition, control net -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) + --control-video [PATH] path to control video frames; it must be a directory path. + The video frames inside should be stored as images in lexicographical (character) order. + For example, if the control video path is `frames`, the directory should contain images such as 00.png, 01.png, etc. --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
-o, --output OUTPUT path to write result image to (default: ./output.png) -p, --prompt [PROMPT] the prompt to render @@ -334,9 +326,9 @@ arguments: --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9]) --skip-layer-start START SLG enabling point: (default: 0.01) --skip-layer-end END SLG disabling point: (default: 0.2) - --scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete) + --scheduler {discrete, karras, exponential, ays, gits, smoothstep} Denoiser sigma scheduler (default: discrete) --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd} - sampling method (default: "euler_a") + sampling method (default: "euler" for Flux/SD3/Wan, "euler_a" otherwise) --steps STEPS number of sample steps (default: 20) --high-noise-cfg-scale SCALE (high noise) unconditional guidance scale: (default: 7.0) --high-noise-img-cfg-scale SCALE (high noise) image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale) @@ -347,13 +339,12 @@ arguments: --high-noise-skip-layers LAYERS (high noise) Layers to skip for SLG steps: (default: [7,8,9]) --high-noise-skip-layer-start (high noise) SLG enabling point: (default: 0.01) --high-noise-skip-layer-end END (high noise) SLG disabling point: (default: 0.2) - --high-noise-scheduler {discrete, karras, exponential, ays, gits} Denoiser sigma scheduler (default: discrete) + --high-noise-scheduler {discrete, karras, exponential, ays, gits, smoothstep} Denoiser sigma scheduler (default: discrete) --high-noise-sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd} (high noise) sampling method (default: "euler_a") --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto) SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END]) --strength STRENGTH strength for noising/unnoising (default: 0.75) - --style-ratio STYLE-RATIO strength for keeping input identity (default: 20) --control-strength STRENGTH strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image -H, --height H image height, in pixel space (default: 512) @@ -364,6 +355,9 @@ arguments: --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x --vae-tiling process vae in tiles to reduce memory usage + --vae-tile-size [X]x[Y] tile size for vae tiling (default: 32x32) + --vae-relative-tile-size [X]x[Y] relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size) + --vae-tile-overlap OVERLAP tile overlap for vae tiling, in fraction of tile size (default: 0.5) --vae-on-cpu keep vae in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram) --diffusion-fa use flash attention in the diffusion model (for low vram) @@ -384,6 +378,12 @@ arguments: --moe-boundary BOUNDARY timestep boundary for Wan2.2 MoE model. 
(default: 0.875) only enabled if `--high-noise-steps` is set to -1 --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto) + --vace-strength STRENGTH Wan VACE strength (default: 1.0) + --photo-maker [PATH] path to PHOTOMAKER model + --pm-id-images-dir [DIR] path to PHOTOMAKER input id images dir + --pm-id-embed-path [PATH] path to PHOTOMAKER v2 id embed + --pm-style-strength STRENGTH strength for keeping PHOTOMAKER input identity (default: 20) + --normalize-input normalize PHOTOMAKER input id images -v, --verbose print extra info ``` @@ -393,9 +393,9 @@ arguments: ./bin/sd -m ../models/sd-v1-4.ckpt -p "a lovely cat" # ./bin/sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" # ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v -# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v -# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v -# ./bin/sd -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v +# ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v --clip-on-cpu +# ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --clip-on-cpu +# ./bin/sd -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v --clip-on-cpu ``` Using formats of different precisions will yield results of varying quality.
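The new VAE tiling flags compose with any of the commands above. A minimal sketch of how they fit together (model path and tile numbers here are illustrative placeholders, not recommendations):

```
# decode the VAE in 64x64 tiles that overlap by 25%, trading a little speed for lower peak memory
./bin/sd -m ../models/sd-v1-4.ckpt -p "a lovely cat" --vae-tiling --vae-tile-size 64x64 --vae-tile-overlap 0.25
```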
diff --git a/assets/wan/Wan2.1_1.3B_vace_r2v.mp4 b/assets/wan/Wan2.1_1.3B_vace_r2v.mp4 new file mode 100644 index 00000000..05f6cfa2 Binary files /dev/null and b/assets/wan/Wan2.1_1.3B_vace_r2v.mp4 differ diff --git a/assets/wan/Wan2.1_1.3B_vace_t2v.mp4 b/assets/wan/Wan2.1_1.3B_vace_t2v.mp4 new file mode 100644 index 00000000..73862e84 Binary files /dev/null and b/assets/wan/Wan2.1_1.3B_vace_t2v.mp4 differ diff --git a/assets/wan/Wan2.1_1.3B_vace_v2v.mp4 b/assets/wan/Wan2.1_1.3B_vace_v2v.mp4 new file mode 100644 index 00000000..2cc4c0a9 Binary files /dev/null and b/assets/wan/Wan2.1_1.3B_vace_v2v.mp4 differ diff --git a/assets/wan/Wan2.1_14B_vace_r2v.mp4 b/assets/wan/Wan2.1_14B_vace_r2v.mp4 new file mode 100644 index 00000000..686371fb Binary files /dev/null and b/assets/wan/Wan2.1_14B_vace_r2v.mp4 differ diff --git a/assets/wan/Wan2.1_14B_vace_t2v.mp4 b/assets/wan/Wan2.1_14B_vace_t2v.mp4 new file mode 100644 index 00000000..cebe8f97 Binary files /dev/null and b/assets/wan/Wan2.1_14B_vace_t2v.mp4 differ diff --git a/assets/wan/Wan2.1_14B_vace_v2v.mp4 b/assets/wan/Wan2.1_14B_vace_v2v.mp4 new file mode 100644 index 00000000..95f30d45 Binary files /dev/null and b/assets/wan/Wan2.1_14B_vace_v2v.mp4 differ diff --git a/clip.hpp b/clip.hpp index f92c9c2f..bde8a78a 100644 --- a/clip.hpp +++ b/clip.hpp @@ -548,9 +548,15 @@ class CLIPEmbeddings : public GGMLBlock { int64_t embed_dim; int64_t vocab_size; int64_t num_positions; + bool force_clip_f32; void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { - enum ggml_type token_wtype = GGML_TYPE_F32; + enum ggml_type token_wtype = GGML_TYPE_F32; + if (!force_clip_f32) { + auto tensor_type = tensor_types.find(prefix + "token_embedding.weight"); + if (tensor_type != tensor_types.end()) + token_wtype = tensor_type->second; + } enum ggml_type position_wtype = GGML_TYPE_F32; params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size); @@ -560,10 +566,12 @@ public: CLIPEmbeddings(int64_t embed_dim, int64_t vocab_size = 49408, - int64_t num_positions = 77) + int64_t num_positions = 77, + bool force_clip_f32 = false) : embed_dim(embed_dim), vocab_size(vocab_size), - num_positions(num_positions) { + num_positions(num_positions), + force_clip_f32(force_clip_f32) { } struct ggml_tensor* get_token_embed_weight() { @@ -678,12 +686,11 @@ class CLIPTextModel : public GGMLBlock { int32_t n_head = 12; int32_t n_layer = 12; // num_hidden_layers int32_t projection_dim = 1280; // only for OPEN_CLIP_VIT_BIGG_14 - int32_t clip_skip = -1; bool with_final_ln = true; CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, bool with_final_ln = true, - int clip_skip_value = -1) + bool force_clip_f32 = false) : version(version), with_final_ln(with_final_ln) { if (version == OPEN_CLIP_VIT_H_14) { hidden_size = 1024; @@ -696,20 +703,12 @@ n_head = 20; n_layer = 32; } - set_clip_skip(clip_skip_value); - blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token)); + blocks["embeddings"] = std::shared_ptr<GGMLBlock>(new CLIPEmbeddings(hidden_size, vocab_size, n_token, force_clip_f32)); blocks["encoder"] = std::shared_ptr<GGMLBlock>(new CLIPEncoder(n_layer, hidden_size, n_head, intermediate_size)); blocks["final_layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size)); } - void set_clip_skip(int skip) { - if (skip <= 0) { - skip = -1; - } - clip_skip = skip; - } - struct ggml_tensor*
get_token_embed_weight() { auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]); return embeddings->get_token_embed_weight(); } @@ -720,7 +719,8 @@ class CLIPTextModel : public GGMLBlock { struct ggml_tensor* input_ids, struct ggml_tensor* tkn_embeddings, size_t max_token_idx = 0, - bool return_pooled = false) { + bool return_pooled = false, + int clip_skip = -1) { // input_ids: [N, n_token] auto embeddings = std::dynamic_pointer_cast<CLIPEmbeddings>(blocks["embeddings"]); auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]); @@ -889,8 +889,8 @@ struct CLIPTextModelRunner : public GGMLRunner { const std::string prefix, CLIPVersion version = OPENAI_CLIP_VIT_L_14, bool with_final_ln = true, - int clip_skip_value = -1) - : GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, clip_skip_value) { + bool force_clip_f32 = false) + : GGMLRunner(backend, offload_params_to_cpu), model(version, with_final_ln, force_clip_f32) { model.init(params_ctx, tensor_types, prefix); } @@ -898,10 +898,6 @@ struct CLIPTextModelRunner : public GGMLRunner { return "clip"; } - void set_clip_skip(int clip_skip) { - model.set_clip_skip(clip_skip); - } - void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) { model.get_param_tensors(tensors, prefix); } @@ -911,7 +907,8 @@ struct CLIPTextModelRunner : public GGMLRunner { struct ggml_tensor* input_ids, struct ggml_tensor* embeddings, size_t max_token_idx = 0, - bool return_pooled = false) { + bool return_pooled = false, + int clip_skip = -1) { size_t N = input_ids->ne[1]; size_t n_token = input_ids->ne[0]; if (input_ids->ne[0] > model.n_token) { @@ -919,14 +916,15 @@ input_ids = ggml_reshape_2d(ctx, input_ids, model.n_token, input_ids->ne[0] / model.n_token); } - return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled); + return model.forward(ctx, backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); } struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, int num_custom_embeddings = 0, void* custom_embeddings_data = NULL, size_t max_token_idx = 0, - bool return_pooled = false) { + bool return_pooled = false, + int clip_skip = -1) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); @@ -945,7 +943,7 @@ struct CLIPTextModelRunner : public GGMLRunner { embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1); } - struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled); + struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, embeddings, max_token_idx, return_pooled, clip_skip); ggml_build_forward_expand(gf, hidden_states); @@ -958,10 +956,11 @@ void* custom_embeddings_data, size_t max_token_idx, bool return_pooled, + int clip_skip, ggml_tensor** output, ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled); + return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip); }; GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } diff --git a/conditioner.hpp b/conditioner.hpp index cfd2b4ca..b1dc7698 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -61,30 +61,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public
Conditioner { const String2GGMLType& tensor_types, const std::string& embd_dir, SDVersion version = VERSION_SD1, - PMVersion pv = PM_VERSION_1, - int clip_skip = -1) + PMVersion pv = PM_VERSION_1) : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) { + bool force_clip_f32 = embd_dir.size() > 0; if (sd_version_is_sd1(version)) { - text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14); + text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, true, force_clip_f32); } else if (sd_version_is_sd2(version)) { - text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14); + text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32); } else if (sd_version_is_sdxl(version)) { - text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); - text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); - } - set_clip_skip(clip_skip); - } - - void set_clip_skip(int clip_skip) { - if (clip_skip <= 0) { - clip_skip = 1; - if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { - clip_skip = 2; - } - } - text_model->set_clip_skip(clip_skip); - if (sd_version_is_sdxl(version)) { - text_model2->set_clip_skip(clip_skip); + text_model = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false, force_clip_f32); + text_model2 = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false, force_clip_f32); } } @@ -129,7 +115,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { return true; } struct ggml_init_params params; - params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB + params.mem_size = 100 * 1024 * 1024; // max for custom embeddings 100 MB params.mem_buffer = NULL; params.no_alloc = false; struct ggml_context* embd_ctx = ggml_init(params); @@ -412,7 +398,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { int height, int adm_in_channels = -1, bool zero_out_masked = false) { - set_clip_skip(clip_skip); int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size] struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, hidden_size] or [n_token, hidden_size + hidden_size2] @@ -421,6 +406,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { struct ggml_tensor* pooled = NULL; std::vector<float> hidden_states_vec; + if (clip_skip <= 0) { + clip_skip = (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) ?
2 : 1; + } + size_t chunk_len = 77; size_t chunk_count = tokens.size() / chunk_len; for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { @@ -455,6 +444,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { token_embed_custom.data(), max_token_idx, false, + clip_skip, &chunk_hidden_states1, work_ctx); if (sd_version_is_sdxl(version)) { @@ -464,6 +454,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { token_embed_custom.data(), max_token_idx, false, + clip_skip, &chunk_hidden_states2, work_ctx); // concat chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states1, chunk_hidden_states2, 0); @@ -475,6 +466,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { token_embed_custom.data(), max_token_idx, true, + clip_skip, &pooled, work_ctx); } @@ -669,21 +661,11 @@ struct SD3CLIPEmbedder : public Conditioner { SD3CLIPEmbedder(ggml_backend_t backend, bool offload_params_to_cpu, - const String2GGMLType& tensor_types = {}, - int clip_skip = -1) + const String2GGMLType& tensor_types = {}) : clip_g_tokenizer(0) { clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer"); - set_clip_skip(clip_skip); - } - - void set_clip_skip(int clip_skip) { - if (clip_skip <= 0) { - clip_skip = 2; - } - clip_l->set_clip_skip(clip_skip); - clip_g->set_clip_skip(clip_skip); } void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) { @@ -780,7 +762,6 @@ struct SD3CLIPEmbedder : public Conditioner { std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights, int clip_skip, bool zero_out_masked = false) { - set_clip_skip(clip_skip); auto& clip_l_tokens = token_and_weights[0].first; auto& clip_l_weights = token_and_weights[0].second; auto& clip_g_tokens = token_and_weights[1].first; @@ -788,6 +769,10 @@ auto& t5_tokens = token_and_weights[2].first; auto& t5_weights = token_and_weights[2].second; + if (clip_skip <= 0) { + clip_skip = 2; + } + int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = NULL; // [N, n_token*2, 4096] struct ggml_tensor* chunk_hidden_states = NULL; // [n_token*2, 4096] @@ -818,6 +803,7 @@ NULL, max_token_idx, false, + clip_skip, &chunk_hidden_states_l, work_ctx); { @@ -845,6 +831,7 @@ NULL, max_token_idx, true, + clip_skip, &pooled_l, work_ctx); } @@ -866,6 +853,7 @@ NULL, max_token_idx, false, + clip_skip, &chunk_hidden_states_g, work_ctx); @@ -894,6 +882,7 @@ NULL, max_token_idx, true, + clip_skip, &pooled_g, work_ctx); } @@ -1017,18 +1006,9 @@ struct FluxCLIPEmbedder : public Conditioner { FluxCLIPEmbedder(ggml_backend_t backend, bool offload_params_to_cpu, - const String2GGMLType& tensor_types = {}, - int clip_skip = -1) { + const String2GGMLType& tensor_types = {}) { clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true); t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer"); - set_clip_skip(clip_skip); - } - - void set_clip_skip(int
clip_skip) { - if (clip_skip <= 0) { - clip_skip = 2; - } - clip_l->set_clip_skip(clip_skip); } void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) { @@ -1109,12 +1089,15 @@ struct FluxCLIPEmbedder : public Conditioner { std::vector<std::pair<std::vector<int>, std::vector<float>>> token_and_weights, int clip_skip, bool zero_out_masked = false) { - set_clip_skip(clip_skip); auto& clip_l_tokens = token_and_weights[0].first; auto& clip_l_weights = token_and_weights[0].second; auto& t5_tokens = token_and_weights[1].first; auto& t5_weights = token_and_weights[1].second; + if (clip_skip <= 0) { + clip_skip = 2; + } + int64_t t0 = ggml_time_ms(); struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] @@ -1143,6 +1126,7 @@ NULL, max_token_idx, true, + clip_skip, &pooled, work_ctx); } @@ -1241,7 +1225,6 @@ struct T5CLIPEmbedder : public Conditioner { T5CLIPEmbedder(ggml_backend_t backend, bool offload_params_to_cpu, const String2GGMLType& tensor_types = {}, - int clip_skip = -1, bool use_mask = false, int mask_pad = 1, bool is_umt5 = false) @@ -1249,9 +1232,6 @@ t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer", is_umt5); } - void set_clip_skip(int clip_skip) { - } - void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) { t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); } diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 138ca67a..92d3da5a 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -6,23 +6,29 @@ #include "unet.hpp" #include "wan.hpp" +struct DiffusionParams { + struct ggml_tensor* x = NULL; + struct ggml_tensor* timesteps = NULL; + struct ggml_tensor* context = NULL; + struct ggml_tensor* c_concat = NULL; + struct ggml_tensor* y = NULL; + struct ggml_tensor* guidance = NULL; + std::vector<struct ggml_tensor*> ref_latents = {}; + bool increase_ref_index = false; + int num_video_frames = -1; + std::vector<struct ggml_tensor*> controls = {}; + float control_strength = 0.f; + struct ggml_tensor* vace_context = NULL; + float vace_strength = 1.f; + std::vector<int> skip_layers = {}; +}; + struct DiffusionModel { virtual std::string get_desc() = 0; virtual void compute(int n_threads, - struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - struct ggml_tensor* c_concat, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - std::vector<struct ggml_tensor*> ref_latents = {}, - bool increase_ref_index = false, - int num_video_frames = -1, - std::vector<struct ggml_tensor*> controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector<int> skip_layers = std::vector<int>()) = 0; + DiffusionParams diffusion_params, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; @@ -71,22 +77,18 @@ struct UNetModel : public DiffusionModel { } void compute(int n_threads, - struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - struct ggml_tensor* c_concat, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - std::vector<struct ggml_tensor*> ref_latents = {}, - bool increase_ref_index = false, - int num_video_frames = -1, - std::vector<struct ggml_tensor*> controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector<int> skip_layers = std::vector<int>()) { - (void)skip_layers; // SLG
doesn't work with UNet models - return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); + DiffusionParams diffusion_params, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + return unet.compute(n_threads, + diffusion_params.x, + diffusion_params.timesteps, + diffusion_params.context, + diffusion_params.c_concat, + diffusion_params.y, + diffusion_params.num_video_frames, + diffusion_params.controls, + diffusion_params.control_strength, output, output_ctx); } }; @@ -129,21 +131,17 @@ struct MMDiTModel : public DiffusionModel { } void compute(int n_threads, - struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - struct ggml_tensor* c_concat, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - std::vector<struct ggml_tensor*> ref_latents = {}, - bool increase_ref_index = false, - int num_video_frames = -1, - std::vector<struct ggml_tensor*> controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector<int> skip_layers = std::vector<int>()) { - return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); + DiffusionParams diffusion_params, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + return mmdit.compute(n_threads, + diffusion_params.x, + diffusion_params.timesteps, + diffusion_params.context, + diffusion_params.y, + output, + output_ctx, + diffusion_params.skip_layers); } }; @@ -188,21 +186,21 @@ struct FluxModel : public DiffusionModel { } void compute(int n_threads, - struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - struct ggml_tensor* c_concat, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - std::vector<struct ggml_tensor*> ref_latents = {}, - bool increase_ref_index = false, - int num_video_frames = -1, - std::vector<struct ggml_tensor*> controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector<int> skip_layers = std::vector<int>()) { - return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, increase_ref_index, output, output_ctx, skip_layers); + DiffusionParams diffusion_params, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + return flux.compute(n_threads, + diffusion_params.x, + diffusion_params.timesteps, + diffusion_params.context, + diffusion_params.c_concat, + diffusion_params.y, + diffusion_params.guidance, + diffusion_params.ref_latents, + diffusion_params.increase_ref_index, + output, + output_ctx, + diffusion_params.skip_layers); } }; @@ -248,21 +246,20 @@ struct WanModel : public DiffusionModel { } void compute(int n_threads, - struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - struct ggml_tensor* c_concat, - struct ggml_tensor* y, - struct ggml_tensor* guidance, - std::vector<struct ggml_tensor*> ref_latents = {}, - bool increase_ref_index = false, - int num_video_frames = -1, - std::vector<struct ggml_tensor*> controls = {}, - float control_strength = 0.f, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL, - std::vector<int> skip_layers = std::vector<int>()) { - return wan.compute(n_threads, x, timesteps, context, y, c_concat, NULL, output, output_ctx); + DiffusionParams diffusion_params, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { + return wan.compute(n_threads, + diffusion_params.x, + diffusion_params.timesteps, +
diffusion_params.context, + diffusion_params.y, + diffusion_params.c_concat, + NULL, + diffusion_params.vace_context, + diffusion_params.vace_strength, + output, + output_ctx); } }; diff --git a/docs/chroma.md b/docs/chroma.md index 198b0453..5aac6444 100644 --- a/docs/chroma.md +++ b/docs/chroma.md @@ -24,7 +24,7 @@ You can download the preconverted gguf weights from [silveroxides/Chroma-GGUF](h For example: ``` - .\bin\Release\sd.exe --diffusion-model ..\models\chroma-unlocked-v40-q8_0.gguf --vae ..\models\ae.sft --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma.cpp'" --cfg-scale 4.0 --sampling-method euler -v --chroma-disable-dit-mask + .\bin\Release\sd.exe --diffusion-model ..\models\chroma-unlocked-v40-q8_0.gguf --vae ..\models\ae.sft --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma.cpp'" --cfg-scale 4.0 --sampling-method euler -v --chroma-disable-dit-mask --clip-on-cpu ``` ![](../assets/flux/chroma_v40.png) diff --git a/docs/flux.md b/docs/flux.md index dafad9b0..c1e8e6d2 100644 --- a/docs/flux.md +++ b/docs/flux.md @@ -28,7 +28,7 @@ Using fp16 will lead to overflow, but ggml's support for bf16 is not yet fully d For example: ``` - .\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v + .\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --clip-on-cpu ``` Using formats of different precisions will yield results of varying quality. @@ -44,7 +44,7 @@ Using formats of different precisions will yield results of varying quality. ``` - .\bin\Release\sd.exe --diffusion-model ..\models\flux1-schnell-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --steps 4 + .\bin\Release\sd.exe --diffusion-model ..\models\flux1-schnell-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --steps 4 --clip-on-cpu ``` | q8_0 | @@ -60,7 +60,7 @@ Since many flux LoRA training libraries have used various LoRA naming formats, i - LoRA model from https://huggingface.co/XLabs-AI/flux-lora-collection/tree/main (using comfy converted version!!!) 
``` -.\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ...\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --lora-model-dir ../models +.\bin\Release\sd.exe --diffusion-model ..\models\flux1-dev-q8_0.gguf --vae ...\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v --lora-model-dir ../models --clip-on-cpu ``` ![output](../assets/flux/flux1-dev-q8_0%20with%20lora.png) diff --git a/docs/kontext.md b/docs/kontext.md index 69873503..58898066 100644 --- a/docs/kontext.md +++ b/docs/kontext.md @@ -27,7 +27,7 @@ You can download the preconverted gguf weights from [FLUX.1-Kontext-dev-GGUF](ht For example: ``` - .\bin\Release\sd.exe -r .\flux1-dev-q8_0.png --diffusion-model ..\models\flux1-kontext-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "change 'flux.cpp' to 'kontext.cpp'" --cfg-scale 1.0 --sampling-method euler -v + .\bin\Release\sd.exe -r .\flux1-dev-q8_0.png --diffusion-model ..\models\flux1-kontext-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "change 'flux.cpp' to 'kontext.cpp'" --cfg-scale 1.0 --sampling-method euler -v --clip-on-cpu ``` diff --git a/docs/photo_maker.md b/docs/photo_maker.md index 8305a33b..dae2c9b2 100644 --- a/docs/photo_maker.md +++ b/docs/photo_maker.md @@ -6,16 +6,15 @@ You can use [PhotoMaker](https://github.com/TencentARC/PhotoMaker) to personaliz Download PhotoMaker model file (in safetensor format) [here](https://huggingface.co/bssrdf/PhotoMaker). The official release of the model file (in .bin format) does not work with ```stablediffusion.cpp```. -- Specify the PhotoMaker model path using the `--stacked-id-embd-dir PATH` parameter. -- Specify the input images path using the `--input-id-images-dir PATH` parameter. - - input images **must** have the same width and height for preprocessing (to be improved) +- Specify the PhotoMaker model path using the `--photo-maker PATH` parameter. +- Specify the input images path using the `--pm-id-images-dir PATH` parameter. In prompt, make sure you have a class word followed by the trigger word ```"img"``` (hard-coded for now). The class word could be one of ```"man, woman, girl, boy"```. If input ID images contain asian faces, add ```Asian``` before the class word. Another PhotoMaker specific parameter: -- ```--style-ratio (0-100)%```: default is 20 and 10-20 typically gets good results. Lower ratio means more faithfully following input ID (not necessarily better quality). +- ```--pm-style-strength (0-100)%```: default is 20; values of 10-20 typically give good results. A lower value follows the input ID more faithfully (not necessarily better quality).
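As a minimal sketch of how these PhotoMaker flags fit together (the model and directory paths are placeholders; a complete, tested command appears below):

```
# a low style strength (10) keeps the output close to the input ID images
bin/sd -m ../models/model.safetensors --photo-maker ../models/photomaker-v1.safetensors --pm-id-images-dir ./id_images -p "a woman img, portrait photo" --pm-style-strength 10
```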
Other parameters recommended for running Photomaker: @@ -28,7 +27,7 @@ If on low memory GPUs (<= 8GB), recommend running with ```--vae-on-cpu``` option Example: ```bash -bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png +bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --photo-maker ../models/photomaker-v1.safetensors --pm-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --pm-style-strength 10 --vae-on-cpu --steps 50 ``` ## PhotoMaker Version 2 diff --git a/docs/sd3.md b/docs/sd3.md index 777511d4..2c1f8ff3 100644 --- a/docs/sd3.md +++ b/docs/sd3.md @@ -14,7 +14,7 @@ For example: ``` -.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v +.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v --clip-on-cpu ``` ![](../assets/sd3.5_large.png) \ No newline at end of file diff --git a/docs/wan.md b/docs/wan.md index a3e5d699..5bde71c7 100644 --- a/docs/wan.md +++ b/docs/wan.md @@ -18,6 +18,12 @@ - Wan2.1 FLF2V 14B 720P - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models - gguf: https://huggingface.co/city96/Wan2.1-FLF2V-14B-720P-gguf/tree/main + - Wan2.1 VACE 1.3B + - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models + - gguf: https://huggingface.co/calcuis/wan-1.3b-gguf/tree/main + - Wan2.1 VACE 14B + - safetensors: https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/tree/main/split_files/diffusion_models + - gguf: https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/tree/main - Wan2.2 - Wan2.2 TI2V 5B - safetensors: https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/tree/main/split_files/diffusion_models @@ -137,3 +143,62 @@ ``` + +### Wan2.1 VACE 1.3B + +#### T2V + +``` +.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1-vace-1.3b-q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 
--sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --video-frames 1 --offload-to-cpu +``` + + + +#### R2V + +``` +.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1-vace-1.3b-q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa -i ..\assets\cat_with_sd_cpp_42.png --video-frames 33 --offload-to-cpu +``` + + + + +#### V2V + +``` +mkdir post+depth +ffmpeg -i ..\..\ComfyUI\input\post+depth.mp4 -qscale:v 1 -vf fps=8 post+depth\frame_%04d.jpg +.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\wan2.1-vace-1.3b-q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "The girl is dancing in a sea of flowers, slowly moving her hands. There is a close-up shot of her upper body. The character is surrounded by other transparent glass flowers in the style of Nicoletta Ceccoli, creating a beautiful, surreal, and emotionally expressive movie scene with a white, transparent feel and a dreamy atmosphere." --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa -i ..\..\ComfyUI\input\dance_girl.jpg --control-video ./post+depth --video-frames 33 --offload-to-cpu +``` + + + +### Wan2.1 VACE 14B + +#### T2V + +``` +.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.1_14B_VACE-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa --video-frames 33 --offload-to-cpu +``` + + + + +#### R2V + +``` +.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.1_14B_VACE-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "a lovely cat" --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 832 -H 480 --diffusion-fa -i ..\assets\cat_with_sd_cpp_42.png --video-frames 33 --offload-to-cpu +``` + + + + + +#### V2V + +``` +.\bin\Release\sd.exe -M vid_gen --diffusion-model ..\..\ComfyUI\models\diffusion_models\Wan2.1_14B_VACE-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\wan_2.1_vae.safetensors --t5xxl ..\..\ComfyUI\models\text_encoders\umt5-xxl-encoder-Q8_0.gguf -p "The girl is dancing in a sea of flowers, slowly moving her hands. There is a close-up shot of her upper body.
The character is surrounded by other transparent glass flowers in the style of Nicoletta Ceccoli, creating a beautiful, surreal, and emotionally expressive movie scene with a white, transparent feel and a dreamy atmosphere." --cfg-scale 6.0 --sampling-method euler -v -n "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部, 畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" -W 480 -H 832 --diffusion-fa -i ..\..\ComfyUI\input\dance_girl.jpg --control-video ./post+depth --video-frames 33 --offload-to-cpu +``` + + diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 8dbf964e..0ba3acb7 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -35,6 +35,8 @@ #define SAFE_STR(s) ((s) ? (s) : "") #define BOOL_STR(b) ((b) ? "true" : "false") +namespace fs = std::filesystem; + const char* modes_str[] = { "img_gen", "vid_gen", @@ -64,8 +66,6 @@ struct SDParams { std::string esrgan_path; std::string control_net_path; std::string embedding_dir; - std::string stacked_id_embed_dir; - std::string input_id_images_path; sd_type_t wtype = SD_TYPE_COUNT; std::string tensor_type_rules; std::string lora_model_dir; @@ -75,15 +75,15 @@ struct SDParams { std::string mask_image_path; std::string control_image_path; std::vector<std::string> ref_image_paths; + std::string control_video_path; bool increase_ref_index = false; std::string prompt; std::string negative_prompt; - float style_ratio = 20.f; - int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; - int batch_count = 1; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; std::vector<int> skip_layers = {7, 8, 9}; sd_sample_params_t sample_params; @@ -91,17 +91,16 @@ struct SDParams { std::vector<int> high_noise_skip_layers = {7, 8, 9}; sd_sample_params_t high_noise_sample_params; - float moe_boundary = 0.875f; - - int video_frames = 1; - int fps = 16; + float moe_boundary = 0.875f; + int video_frames = 1; + int fps = 16; + float vace_strength = 1.f; float strength = 0.75f; float control_strength = 0.9f; rng_type_t rng_type = CUDA_RNG; int64_t seed = 42; bool verbose = false; - bool vae_tiling = false; bool offload_params_to_cpu = false; bool control_net_cpu = false; bool normalize_input = false; @@ -114,11 +113,19 @@ struct SDParams { bool color = false; int upscale_repeats = 1; + // Photo Maker + std::string photo_maker_path; + std::string pm_id_images_dir; + std::string pm_id_embed_path; + float pm_style_strength = 20.f; + bool chroma_use_dit_mask = true; bool chroma_use_t5_mask = false; int chroma_t5_mask_pad = 1; float flow_shift = INFINITY; + sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; + SDParams() { sd_sample_params_init(&sample_params); sd_sample_params_init(&high_noise_sample_params); @@ -145,9 +152,10 @@ void print_params(SDParams params) { printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); printf(" control_net_path: %s\n", params.control_net_path.c_str()); printf(" embedding_dir: %s\n", params.embedding_dir.c_str()); - printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str()); - printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); - printf(" style ratio: %.2f\n", params.style_ratio); + printf(" photo_maker_path: %s\n", params.photo_maker_path.c_str()); + printf(" pm_id_images_dir: %s\n", params.pm_id_images_dir.c_str()); + printf(" pm_id_embed_path: %s\n", params.pm_id_embed_path.c_str()); + printf(" pm_style_strength: %.2f\n",
params.pm_style_strength); printf(" normalize input image: %s\n", params.normalize_input ? "true" : "false"); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_image_path: %s\n", params.init_image_path.c_str()); @@ -158,6 +166,7 @@ void print_params(SDParams params) { for (auto& path : params.ref_image_paths) { printf(" %s\n", path.c_str()); }; + printf(" control_video_path: %s\n", params.control_video_path.c_str()); printf(" increase_ref_index: %s\n", params.increase_ref_index ? "true" : "false"); printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false"); printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false"); @@ -178,14 +187,15 @@ void print_params(SDParams params) { printf(" flow_shift: %.2f\n", params.flow_shift); printf(" strength(img2img): %.2f\n", params.strength); printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); - printf(" seed: %ld\n", params.seed); + printf(" seed: %zd\n", params.seed); printf(" batch_count: %d\n", params.batch_count); - printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); + printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false"); printf(" upscale_repeats: %d\n", params.upscale_repeats); printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); printf(" video_frames: %d\n", params.video_frames); + printf(" vace_strength: %.2f\n", params.vace_strength); printf(" fps: %d\n", params.fps); free(sample_params_str); free(high_noise_sample_params_str); @@ -211,9 +221,6 @@ void print_usage(int argc, const char* argv[]) { printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); - printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n"); - printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); - printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. 
Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); @@ -225,6 +232,9 @@ void print_usage(int argc, const char* argv[]) { printf(" -i, --end-img [IMAGE] path to the end image, required by flf2v\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n"); + printf(" --control-video [PATH] path to control video frames; it must be a directory path.\n"); + printf(" The video frames inside should be stored as images in lexicographical (character) order.\n"); + printf(" For example, if the control video path is `frames`, the directory should contain images such as 00.png, 01.png, etc.\n"); printf(" --increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).\n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); @@ -240,7 +250,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n"); printf(" --scheduler {discrete, karras, exponential, ays, gits, smoothstep} Denoiser sigma scheduler (default: discrete)\n"); printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n"); - printf(" sampling method (default: \"euler_a\")\n"); + printf(" sampling method (default: \"euler\" for Flux/SD3/Wan, \"euler_a\" otherwise)\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --high-noise-cfg-scale SCALE (high noise) unconditional guidance scale: (default: 7.0)\n"); printf(" --high-noise-img-cfg-scale SCALE (high noise) image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n"); @@ -257,7 +267,6 @@ void print_usage(int argc, const char* argv[]) { printf(" --high-noise-steps STEPS (high noise) number of sample steps (default: -1 = auto)\n"); printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); - printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n"); printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); @@ -268,6 +277,9 @@ void print_usage(int argc, const char* argv[]) { printf(" --clip-skip N ignore last_dot_pos layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); + printf(" --vae-tile-size [X]x[Y] tile size for vae tiling (default: 32x32)\n"); + printf(" --vae-relative-tile-size [X]x[Y] relative tile size for vae tiling, in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)\n"); + printf(" --vae-tile-overlap OVERLAP tile overlap for vae tiling, in fraction of tile size (default: 0.5)\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); printf("
--clip-on-cpu keep clip in cpu (for low vram)\n"); printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n"); @@ -288,6 +300,12 @@ void print_usage(int argc, const char* argv[]) { printf(" --moe-boundary BOUNDARY timestep boundary for Wan2.2 MoE model. (default: 0.875)\n"); printf(" only enabled if `--high-noise-steps` is set to -1\n"); printf(" --flow-shift SHIFT shift value for Flow models like SD3.x or WAN (default: auto)\n"); + printf(" --vace-strength STRENGTH Wan VACE strength (default: 1.0)\n"); + printf(" --photo-maker [PATH] path to PHOTOMAKER model\n"); + printf(" --pm-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); + printf(" --pm-id-embed-path [PATH] path to PHOTOMAKER v2 id embed\n"); + printf(" --pm-style-strength STRENGTH strength for keeping PHOTOMAKER input identity (default: 20)\n"); + printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" -v, --verbose print extra info\n"); } @@ -474,18 +492,19 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--taesd", "", &params.taesd_path}, {"", "--control-net", "", &params.control_net_path}, {"", "--embd-dir", "", &params.embedding_dir}, - {"", "--stacked-id-embd-dir", "", &params.stacked_id_embed_dir}, {"", "--lora-model-dir", "", &params.lora_model_dir}, {"-i", "--init-img", "", &params.init_image_path}, {"", "--end-img", "", &params.end_image_path}, {"", "--tensor-type-rules", "", &params.tensor_type_rules}, - {"", "--input-id-images-dir", "", &params.input_id_images_path}, + {"", "--photo-maker", "", &params.photo_maker_path}, + {"", "--pm-id-images-dir", "", &params.pm_id_images_dir}, + {"", "--pm-id-embed-path", "", &params.pm_id_embed_path}, {"", "--mask", "", &params.mask_image_path}, {"", "--control-image", "", &params.control_image_path}, + {"", "--control-video", "", &params.control_video_path}, {"-o", "--output", "", &params.output_path}, {"-p", "--prompt", "", &params.prompt}, {"-n", "--negative-prompt", "", &params.negative_prompt}, - {"", "--upscale-model", "", &params.esrgan_path}, }; @@ -519,14 +538,16 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--high-noise-skip-layer-end", "", &params.high_noise_sample_params.guidance.slg.layer_end}, {"", "--high-noise-eta", "", &params.high_noise_sample_params.eta}, {"", "--strength", "", &params.strength}, - {"", "--style-ratio", "", &params.style_ratio}, + {"", "--pm-style-strength", "", &params.pm_style_strength}, {"", "--control-strength", "", &params.control_strength}, {"", "--moe-boundary", "", &params.moe_boundary}, {"", "--flow-shift", "", &params.flow_shift}, + {"", "--vace-strength", "", &params.vace_strength}, + {"", "--vae-tile-overlap", "", &params.vae_tiling_params.target_overlap}, }; options.bool_options = { - {"", "--vae-tiling", "", true, &params.vae_tiling}, + {"", "--vae-tiling", "", true, &params.vae_tiling_params.enabled}, {"", "--offload-to-cpu", "", true, &params.offload_params_to_cpu}, {"", "--control-net-cpu", "", true, &params.control_net_cpu}, {"", "--normalize-input", "", true, &params.normalize_input}, @@ -726,6 +747,52 @@ void parse_args(int argc, const char** argv, SDParams& params) { return 1; }; + auto on_tile_size_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string tile_size_str = argv[index]; + size_t x_pos = tile_size_str.find('x'); + try { + if (x_pos != std::string::npos) { + std::string tile_x_str = tile_size_str.substr(0, x_pos); + std::string tile_y_str = tile_size_str.substr(x_pos + 1); + params.vae_tiling_params.tile_size_x = std::stoi(tile_x_str); + params.vae_tiling_params.tile_size_y = std::stoi(tile_y_str); + } else {
params.vae_tiling_params.tile_size_x = params.vae_tiling_params.tile_size_y = std::stoi(tile_size_str); + } + } catch (const std::invalid_argument& e) { + return -1; + } catch (const std::out_of_range& e) { + return -1; + } + return 1; + }; + + auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string rel_size_str = argv[index]; + size_t x_pos = rel_size_str.find('x'); + try { + if (x_pos != std::string::npos) { + std::string rel_x_str = rel_size_str.substr(0, x_pos); + std::string rel_y_str = rel_size_str.substr(x_pos + 1); + params.vae_tiling_params.rel_size_x = std::stof(rel_x_str); + params.vae_tiling_params.rel_size_y = std::stof(rel_y_str); + } else { + params.vae_tiling_params.rel_size_x = params.vae_tiling_params.rel_size_y = std::stof(rel_size_str); + } + } catch (const std::invalid_argument& e) { + return -1; + } catch (const std::out_of_range& e) { + return -1; + } + return 1; + }; + options.manual_options = { {"-M", "--mode", "", on_mode_arg}, {"", "--type", "", on_type_arg}, @@ -739,6 +806,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--high-noise-skip-layers", "", on_high_noise_skip_layers_arg}, {"-r", "--ref-image", "", on_ref_image_arg}, {"-h", "--help", "", on_help_arg}, + {"", "--vae-tile-size", "", on_tile_size_arg}, + {"", "--vae-relative-tile-size", "", on_relative_tile_size_arg}, }; if (!parse_options(argc, argv, options)) { @@ -1012,14 +1081,58 @@ uint8_t* load_image(const char* image_path, int& width, int& height, int expecte STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_COLORSPACE_SRGB, nullptr); - - // Save resized result + width = resized_width; + height = resized_height; free(image_buffer); image_buffer = resized_image_buffer; } return image_buffer; } +bool load_images_from_dir(const std::string dir, + std::vector<sd_image_t>& images, + int expected_width = 0, + int expected_height = 0, + int max_image_num = 0, + bool verbose = false) { + if (!fs::exists(dir) || !fs::is_directory(dir)) { + fprintf(stderr, "'%s' is not a valid directory\n", dir.c_str()); + return false; + } + + for (const auto& entry : fs::directory_iterator(dir)) { + if (!entry.is_regular_file()) + continue; + + std::string path = entry.path().string(); + std::string ext = entry.path().extension().string(); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + + if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") { + if (verbose) { + printf("load image %zu from '%s'\n", images.size(), path.c_str()); + } + int width = 0; + int height = 0; + uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return false; + } + + images.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + + if (max_image_num > 0 && images.size() >= max_image_num) { + break; + } + } + } + return true; +} + int main(int argc, const char* argv[]) { SDParams params; parse_args(argc, argv, params); @@ -1059,17 +1172,29 @@ int main(int argc, const char* argv[]) { sd_image_t control_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL}; sd_image_t mask_image = {(uint32_t)params.width, (uint32_t)params.height, 1, NULL}; std::vector<sd_image_t> ref_images; + std::vector<sd_image_t> pmid_images; + std::vector<sd_image_t> control_frames; auto release_all_resources = [&]() { free(init_image.data); free(end_image.data); free(control_image.data);
 int main(int argc, const char* argv[]) { SDParams params; parse_args(argc, argv, params); @@ -1059,17 +1172,29 @@ int main(int argc, const char* argv[]) { sd_image_t control_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL}; sd_image_t mask_image = {(uint32_t)params.width, (uint32_t)params.height, 1, NULL}; std::vector<sd_image_t> ref_images; + std::vector<sd_image_t> pmid_images; + std::vector<sd_image_t> control_frames; auto release_all_resources = [&]() { free(init_image.data); free(end_image.data); free(control_image.data); free(mask_image.data); - for (auto ref_image : ref_images) { - free(ref_image.data); - ref_image.data = NULL; + for (auto image : ref_images) { + free(image.data); + image.data = NULL; } ref_images.clear(); + for (auto image : pmid_images) { + free(image.data); + image.data = NULL; + } + pmid_images.clear(); + for (auto image : control_frames) { + free(image.data); + image.data = NULL; + } + control_frames.clear(); }; if (params.init_image_path.size() > 0) { @@ -1128,14 +1253,12 @@ int main(int argc, const char* argv[]) { return 1; } if (params.canny_preprocess) { // apply preprocessor - control_image.data = preprocess_canny(control_image.data, - control_image.width, - control_image.height, - 0.08f, - 0.08f, - 0.8f, - 1.0f, - false); + preprocess_canny(control_image, + 0.08f, + 0.08f, + 0.8f, + 1.0f, + false); } } @@ -1157,6 +1280,30 @@ int main(int argc, const char* argv[]) { } } + if (!params.control_video_path.empty()) { + if (!load_images_from_dir(params.control_video_path, + control_frames, + params.width, + params.height, + params.video_frames, + params.verbose)) { + release_all_resources(); + return 1; + } + } + + if (!params.pm_id_images_dir.empty()) { + if (!load_images_from_dir(params.pm_id_images_dir, + pmid_images, + 0, + 0, + 0, + params.verbose)) { + release_all_resources(); + return 1; + } + } + if (params.mode == VID_GEN) { vae_decode_only = false; } @@ -1174,9 +1321,8 @@ int main(int argc, const char* argv[]) { params.control_net_path.c_str(), params.lora_model_dir.c_str(), params.embedding_dir.c_str(), - params.stacked_id_embed_dir.c_str(), + params.photo_maker_path.c_str(), vae_decode_only, - params.vae_tiling, true, params.n_threads, params.wtype, @@ -1202,6 +1348,10 @@ int main(int argc, const char* argv[]) { return 1; } + if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { + params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); + } + sd_image_t* results; int num_results = 1; if (params.mode == IMG_GEN) { @@ -1222,9 +1372,14 @@ int main(int argc, const char* argv[]) { params.batch_count, control_image, params.control_strength, - params.style_ratio, params.normalize_input, - params.input_id_images_path.c_str(), + { + pmid_images.data(), + (int)pmid_images.size(), + params.pm_id_embed_path.c_str(), + params.pm_style_strength, + }, // pm_params + params.vae_tiling_params, }; results = generate_image(sd_ctx, &img_gen_params); @@ -1236,6 +1391,8 @@ int main(int argc, const char* argv[]) { params.clip_skip, init_image, end_image, + control_frames.data(), + (int)control_frames.size(), params.width, params.height, params.sample_params, @@ -1244,6 +1401,7 @@ int main(int argc, const char* argv[]) { params.strength, params.seed, params.video_frames, + params.vace_strength, }; results = generate_video(sd_ctx, &vid_gen_params, &num_results); @@ -1286,7 +1444,6 @@ int main(int argc, const char* argv[]) { // create directory if not exists { - namespace fs = std::filesystem; const fs::path out_path = params.output_path; if (const fs::path out_dir = out_path.parent_path(); !out_dir.empty()) { std::error_code ec; diff --git a/ggml_extend.hpp b/ggml_extend.hpp index b88344e1..a5f61ea4 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -185,17 +185,17 @@ __STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int return *(ggml_fp16_t*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); } -static struct ggml_tensor* get_tensor_from_graph(struct ggml_cgraph* gf, const char* 
name) { - struct ggml_tensor* res = NULL; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - struct ggml_tensor* node = ggml_graph_node(gf, i); - // printf("%d, %s \n", i, ggml_get_name(node)); - if (strcmp(ggml_get_name(node), name) == 0) { - res = node; - break; - } +__STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int ic, bool scale = true) { + float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic); + if (scale) { + value /= 255.f; } - return res; + return value; +} + +__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) { + float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic); + return value; } __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only = false, const char* mark = "") { @@ -235,6 +235,52 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_ } } +__STATIC_INLINE__ void ggml_tensor_iter( + ggml_tensor* tensor, + const std::function<void(ggml_tensor*, int64_t, int64_t, int64_t, int64_t)>& fn) { + int64_t n0 = tensor->ne[0]; + int64_t n1 = tensor->ne[1]; + int64_t n2 = tensor->ne[2]; + int64_t n3 = tensor->ne[3]; + + for (int64_t i3 = 0; i3 < n3; i3++) { + for (int64_t i2 = 0; i2 < n2; i2++) { + for (int64_t i1 = 0; i1 < n1; i1++) { + for (int64_t i0 = 0; i0 < n0; i0++) { + fn(tensor, i0, i1, i2, i3); + } + } + } + } +} + +__STATIC_INLINE__ void ggml_tensor_iter( + ggml_tensor* tensor, + const std::function<void(ggml_tensor*, int64_t)>& fn) { + for (int64_t i = 0; i < ggml_nelements(tensor); i++) { + fn(tensor, i); + } +} + +__STATIC_INLINE__ void ggml_tensor_diff( + ggml_tensor* a, + ggml_tensor* b, + float gap = 0.1f) { + GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b)); + ggml_tensor_iter(a, [&](ggml_tensor* a, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float a_value = ggml_tensor_get_f32(a, i0, i1, i2, i3); + float b_value = ggml_tensor_get_f32(b, i0, i1, i2, i3); + if (fabsf(a_value - b_value) > gap) { + LOG_WARN("[%ld, %ld, %ld, %ld] %f %f", i3, i2, i1, i0, a_value, b_value); + } + }); +} + __STATIC_INLINE__ ggml_tensor* load_tensor_from_file(ggml_context* ctx, const std::string& file_path) { std::ifstream file(file_path, std::ios::binary); if (!file.is_open()) { @@ -366,42 +412,18 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input, int idx return image_data; } -__STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, - struct ggml_tensor* output, +__STATIC_INLINE__ void sd_image_to_tensor(sd_image_t image, + ggml_tensor* tensor, bool scale = true) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; - GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - for (int k = 0; k < channels; k++) { - float value = *(image_data + iy * width * channels + ix * channels + k); - if (scale) { - value /= 255.f; - } - ggml_tensor_set_f32(output, value, ix, iy, k); - } - } - } -} - -__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data, - struct ggml_tensor* output, - bool scale = true) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; - GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - float value = *(image_data + iy * 
width * channels + ix); - if (scale) { - value /= 255.f; - } - ggml_tensor_set_f32(output, value, ix, iy); - } - } + GGML_ASSERT(image.width == tensor->ne[0]); + GGML_ASSERT(image.height == tensor->ne[1]); + GGML_ASSERT(image.channel == tensor->ne[2]); + GGML_ASSERT(1 == tensor->ne[3]); + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = sd_image_get_f32(image, i0, i1, i2, scale); + ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3); + }); } __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, @@ -424,28 +446,6 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, } } -__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data, - struct ggml_tensor* output, - int idx, - float* mean = NULL, - float* std = NULL) { - int64_t width = output->ne[0]; - int64_t height = output->ne[1]; - int64_t channels = output->ne[2]; - GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { - for (int k = 0; k < channels; k++) { - int value = *(image_data + iy * width * channels + ix * channels + k); - float pixel_val = value / 255.0f; - if (mean != NULL && std != NULL) - pixel_val = (pixel_val - mean[k]) / std[k]; - ggml_tensor_set_f32(output, pixel_val, ix, iy, k, idx); - } - } - } -} - __STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data, struct ggml_tensor* output, bool scale = true) { @@ -494,7 +494,10 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, struct ggml_tensor* output, int x, int y, - int overlap) { + int overlap_x, + int overlap_y, + int x_skip = 0, + int y_skip = 0) { int64_t width = input->ne[0]; int64_t height = input->ne[1]; int64_t channels = input->ne[2]; @@ -503,17 +506,17 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, int64_t img_height = output->ne[1]; GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { + for (int iy = y_skip; iy < height; iy++) { + for (int ix = x_skip; ix < width; ix++) { for (int k = 0; k < channels; k++) { float new_value = ggml_tensor_get_f32(input, ix, iy, k); - if (overlap > 0) { // blend colors in overlapped area + if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k); - const float x_f_0 = (x > 0) ? ix / float(overlap) : 1; - const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1; - const float y_f_0 = (y > 0) ? iy / float(overlap) : 1; - const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1; + const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1; + const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1; + const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1; + const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? 
(height - iy) / float(overlap_y) : 1; const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f); const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f); @@ -745,22 +748,102 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process; +__STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim, + float& tile_overlap_factor_dim, + int small_dim, + int tile_size, + const float tile_overlap_factor) { + int tile_overlap = (tile_size * tile_overlap_factor); + int non_tile_overlap = tile_size - tile_overlap; + + num_tiles_dim = (small_dim - tile_overlap) / non_tile_overlap; + int overshoot_dim = ((num_tiles_dim + 1) * non_tile_overlap + tile_overlap) % small_dim; + + if ((overshoot_dim != non_tile_overlap) && (overshoot_dim <= num_tiles_dim * (tile_size / 2 - tile_overlap))) { + // if tiles don't fit perfectly using the desired overlap + // and there is enough room to squeeze an extra tile without overlap becoming >0.5 + num_tiles_dim++; + } + + tile_overlap_factor_dim = (float)(tile_size * num_tiles_dim - small_dim) / (float)(tile_size * (num_tiles_dim - 1)); + if (num_tiles_dim <= 2) { + if (small_dim <= tile_size) { + num_tiles_dim = 1; + tile_overlap_factor_dim = 0; + } else { + num_tiles_dim = 2; + tile_overlap_factor_dim = (2 * tile_size - small_dim) / (float)tile_size; + } + } +} + // Tiling -__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { +__STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input, + ggml_tensor* output, + const int scale, + const int p_tile_size_x, + const int p_tile_size_y, + const float tile_overlap_factor, + on_tile_process on_processing) { output = ggml_set_f32(output, 0); int input_width = (int)input->ne[0]; int input_height = (int)input->ne[1]; int output_width = (int)output->ne[0]; int output_height = (int)output->ne[1]; + + GGML_ASSERT(((input_width / output_width) == (input_height / output_height)) && + ((output_width / input_width) == (output_height / input_height))); + GGML_ASSERT(((input_width / output_width) == scale) || + ((output_width / input_width) == scale)); + + int small_width = output_width; + int small_height = output_height; + + bool decode = output_width > input_width; + if (decode) { + small_width = input_width; + small_height = input_height; + } + + int num_tiles_x; + float tile_overlap_factor_x; + sd_tiling_calc_tiles(num_tiles_x, tile_overlap_factor_x, small_width, p_tile_size_x, tile_overlap_factor); + + int num_tiles_y; + float tile_overlap_factor_y; + sd_tiling_calc_tiles(num_tiles_y, tile_overlap_factor_y, small_height, p_tile_size_y, tile_overlap_factor); + + LOG_DEBUG("num tiles : %d, %d ", num_tiles_x, num_tiles_y); + LOG_DEBUG("optimal overlap : %f, %f (targeting %f)", tile_overlap_factor_x, tile_overlap_factor_y, tile_overlap_factor); + GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2 - int tile_overlap = (int32_t)(tile_size * tile_overlap_factor); - int non_tile_overlap = tile_size - tile_overlap; + int tile_overlap_x = (int32_t)(p_tile_size_x * tile_overlap_factor_x); + int non_tile_overlap_x = p_tile_size_x - tile_overlap_x; + + int tile_overlap_y = (int32_t)(p_tile_size_y * tile_overlap_factor_y); + int non_tile_overlap_y = p_tile_size_y - tile_overlap_y; + + int tile_size_x = p_tile_size_x < small_width ? 
p_tile_size_x : small_width; + int tile_size_y = p_tile_size_y < small_height ? p_tile_size_y : small_height; + + int input_tile_size_x = tile_size_x; + int input_tile_size_y = tile_size_y; + int output_tile_size_x = tile_size_x; + int output_tile_size_y = tile_size_y; + + if (decode) { + output_tile_size_x *= scale; + output_tile_size_y *= scale; + } else { + input_tile_size_x *= scale; + input_tile_size_y *= scale; + } struct ggml_init_params params = {}; - params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk - params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk + params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk + params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk params.mem_size += 3 * ggml_tensor_overhead(); params.mem_buffer = NULL; params.no_alloc = false; @@ -775,29 +858,50 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const } // tiling - ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1); - ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1); - on_processing(input_tile, NULL, true); - int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap); + ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1); + ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1); + int num_tiles = num_tiles_x * num_tiles_y; LOG_INFO("processing %i tiles", num_tiles); - pretty_progress(1, num_tiles, 0.0f); + pretty_progress(0, num_tiles, 0.0f); int tile_count = 1; bool last_y = false, last_x = false; float last_time = 0.0f; - for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) { - if (y + tile_size >= input_height) { - y = input_height - tile_size; + for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) { + int dy = 0; + if (y + tile_size_y >= small_height) { + int _y = y; + y = small_height - tile_size_y; + dy = _y - y; + if (decode) { + dy *= scale; + } last_y = true; } - for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) { - if (x + tile_size >= input_width) { - x = input_width - tile_size; + for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) { + int dx = 0; + if (x + tile_size_x >= small_width) { + int _x = x; + x = small_width - tile_size_x; + dx = _x - x; + if (decode) { + dx *= scale; + } last_x = true; } + + int x_in = decode ? x : scale * x; + int y_in = decode ? y : scale * y; + int x_out = decode ? x * scale : x; + int y_out = decode ? y * scale : y; + + int overlap_x_out = decode ? tile_overlap_x * scale : tile_overlap_x; + int overlap_y_out = decode ? 
tile_overlap_y * scale : tile_overlap_y; + int64_t t1 = ggml_time_ms(); - ggml_split_tensor_2d(input, input_tile, x, y); + ggml_split_tensor_2d(input, input_tile, x_in, y_in); on_processing(input_tile, output_tile, false); - ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale); + ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy); + int64_t t2 = ggml_time_ms(); last_time = (t2 - t1) / 1000.0f; pretty_progress(tile_count, num_tiles, last_time); @@ -811,6 +915,15 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const ggml_free(tiles_ctx); } +__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, + ggml_tensor* output, + const int scale, + const int tile_size, + const float tile_overlap_factor, + on_tile_process on_processing) { + sd_tiling_non_square(input, output, scale, tile_size, tile_size, tile_overlap_factor, on_processing); +} +
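For intuition about `sd_tiling_calc_tiles()` above: decoding a 1024x1024 image works on a 128x128 latent, so with the default 32x32 tiles and a 0.5 target overlap the arithmetic lands exactly on the target (worked example, not part of the patch):

    // tile_overlap     = 32 * 0.5                  = 16
    // non_tile_overlap = 32 - 16                   = 16
    // num_tiles_dim    = (128 - 16) / 16           = 7
    // overshoot        = (8 * 16 + 16) % 128       = 16, == non_tile_overlap -> no extra tile
    // overlap_factor   = (32 * 7 - 128) / (32 * 6) = 0.5
    int   num_tiles = 0;
    float overlap   = 0.0f;
    sd_tiling_calc_tiles(num_tiles, overlap, /*small_dim*/ 128, /*tile_size*/ 32, /*tile_overlap_factor*/ 0.5f);
    // num_tiles == 7 and overlap == 0.5f, i.e. 7x7 = 49 tiles for the whole image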
 __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, struct ggml_tensor* a) { const float eps = 1e-6f; // default eps parameter @@ -1523,6 +1636,7 @@ struct GGMLRunner { ggml_backend_tensor_copy(t, offload_t); std::swap(t->buffer, offload_t->buffer); std::swap(t->data, offload_t->data); + std::swap(t->extra, offload_t->extra); t = ggml_get_next_tensor(params_ctx, t); offload_t = ggml_get_next_tensor(offload_ctx, offload_t); @@ -1553,8 +1667,10 @@ struct GGMLRunner { while (t != NULL && offload_t != NULL) { t->buffer = offload_t->buffer; t->data = offload_t->data; + t->extra = offload_t->extra; offload_t->buffer = NULL; offload_t->data = NULL; + offload_t->extra = NULL; t = ggml_get_next_tensor(params_ctx, t); offload_t = ggml_get_next_tensor(offload_ctx, offload_t); diff --git a/lora.hpp b/lora.hpp index b7a27306..222f61b1 100644 --- a/lora.hpp +++ b/lora.hpp @@ -1,6 +1,7 @@ #ifndef __LORA_HPP__ #define __LORA_HPP__ +#include <mutex> #include "ggml_extend.hpp" #define LORA_GRAPH_BASE_SIZE 10240 @@ -115,7 +116,7 @@ struct LoraModel : public GGMLRunner { return "lora"; } - bool load_from_file(bool filter_tensor = false) { + bool load_from_file(bool filter_tensor = false, int n_threads = 0) { LOG_INFO("loading LoRA from '%s'", file_path.c_str()); if (load_failed) { @@ -123,41 +124,53 @@ struct LoraModel : public GGMLRunner { return false; } + std::unordered_map<std::string, TensorStorage> tensors_to_create; + std::mutex lora_mutex; bool dry_run = true; auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { - const std::string& name = tensor_storage.name; + if (dry_run) { + const std::string& name = tensor_storage.name; - if (filter_tensor && !contains(name, "lora")) { - // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); - return true; - } - // LOG_INFO("lora_tensor %s", name.c_str()); - for (int i = 0; i < LORA_TYPE_COUNT; i++) { - if (name.find(type_fingerprints[i]) != std::string::npos) { - type = (lora_t)i; - break; + if (filter_tensor && !contains(name, "lora")) { + return true; } - } - if (dry_run) { - struct ggml_tensor* real = ggml_new_tensor(params_ctx, - tensor_storage.type, - tensor_storage.n_dims, - tensor_storage.ne); - lora_tensors[name] = real; + { + std::lock_guard<std::mutex> lock(lora_mutex); + for (int i = 0; i < LORA_TYPE_COUNT; i++) { + if (name.find(type_fingerprints[i]) != std::string::npos) { + type = (lora_t)i; + break; + } + } + tensors_to_create[name] = tensor_storage; + } } else { - auto real = lora_tensors[name]; - *dst_tensor = real; + const std::string& name = tensor_storage.name; + auto iter = lora_tensors.find(name); + if (iter != lora_tensors.end()) { + *dst_tensor = iter->second; + } } - return true; }; - model_loader.load_tensors(on_new_tensor_cb); + model_loader.load_tensors(on_new_tensor_cb, n_threads); + + for (const auto& pair : tensors_to_create) { + const auto& name = pair.first; + const auto& ts = pair.second; + struct ggml_tensor* real = ggml_new_tensor(params_ctx, + ts.type, + ts.n_dims, + ts.ne); + lora_tensors[name] = real; + } + alloc_params_buffer(); - // exit(0); + dry_run = false; - model_loader.load_tensors(on_new_tensor_cb); + model_loader.load_tensors(on_new_tensor_cb, n_threads); LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str()); diff --git a/model.cpp b/model.cpp index 4e42018c..168b675d 100644 --- a/model.cpp +++ b/model.cpp @@ -1,8 +1,13 @@ #include <stdarg.h> +#include <atomic> +#include <chrono> #include <fstream> +#include <mutex> +#include <thread> #include <regex> #include <set> #include <string> +#include <unordered_map> #include <vector> @@ -107,7 +112,7 @@ const char* unused_tensors[] = { }; bool is_unused_tensor(std::string name) { - for (int i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) { + for (size_t i = 0; i < sizeof(unused_tensors) / sizeof(const char*); i++) { if (starts_with(name, unused_tensors[i])) { return true; } @@ -1944,292 +1949,344 @@ std::string ModelLoader::load_umt5_tokenizer_json() { return json_str; } -std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& vec) { - std::vector<TensorStorage> res; - std::unordered_map<std::string, size_t> name_to_index_map; +bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads_p) { + int64_t process_time_ms = 0; + std::atomic<int64_t> read_time_ms(0); + std::atomic<int64_t> memcpy_time_ms(0); + std::atomic<int64_t> copy_to_backend_time_ms(0); + std::atomic<int64_t> convert_time_ms(0); - for (size_t i = 0; i < vec.size(); ++i) { - const std::string& current_name = vec[i].name; - auto it = name_to_index_map.find(current_name); + int num_threads_to_use = n_threads_p > 0 ? 
n_threads_p : (int)std::thread::hardware_concurrency(); - if (it != name_to_index_map.end()) { - res[it->second] = vec[i]; - } else { - name_to_index_map[current_name] = i; - res.push_back(vec[i]); + int64_t start_time = ggml_time_ms(); + std::vector<TensorStorage> processed_tensor_storages; + + { + struct IndexedStorage { + size_t index; + TensorStorage ts; + }; + + std::mutex vec_mutex; + std::vector<IndexedStorage> all_results; + + int n_threads = std::min(num_threads_to_use, (int)tensor_storages.size()); + if (n_threads < 1) { + n_threads = 1; } - } + std::vector<std::thread> workers; - // vec.resize(name_to_index_map.size()); + for (int i = 0; i < n_threads; ++i) { + workers.emplace_back([&, thread_id = i]() { + std::vector<IndexedStorage> local_results; + std::vector<TensorStorage> temp_storages; - return res; -} + for (size_t j = thread_id; j < tensor_storages.size(); j += n_threads) { + const auto& tensor_storage = tensor_storages[j]; + if (is_unused_tensor(tensor_storage.name)) { + continue; + } -bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) { - int64_t process_time_ms = 0; - int64_t read_time_ms = 0; - int64_t memcpy_time_ms = 0; - int64_t copy_to_backend_time_ms = 0; - int64_t convert_time_ms = 0; - - int64_t prev_time_ms = 0; - int64_t curr_time_ms = 0; - int64_t start_time = ggml_time_ms(); - prev_time_ms = start_time; - std::vector<TensorStorage> processed_tensor_storages; - for (auto& tensor_storage : tensor_storages) { - // LOG_DEBUG("%s", name.c_str()); + temp_storages.clear(); + preprocess_tensor(tensor_storage, temp_storages); - if (is_unused_tensor(tensor_storage.name)) { - continue; + for (const auto& ts : temp_storages) { + local_results.push_back({j, ts}); + } + } + + if (!local_results.empty()) { + std::lock_guard<std::mutex> lock(vec_mutex); + all_results.insert(all_results.end(), + local_results.begin(), local_results.end()); + } + }); + } + for (auto& w : workers) { + w.join(); } - preprocess_tensor(tensor_storage, processed_tensor_storages); + std::unordered_map<std::string, IndexedStorage> latest_map; + for (auto& entry : all_results) { + auto it = latest_map.find(entry.ts.name); + if (it == latest_map.end() || it->second.index <= entry.index) { + latest_map[entry.ts.name] = entry; // keep the occurrence that comes last in file order + } + } + + processed_tensor_storages.reserve(latest_map.size()); + for (auto& [name, entry] : latest_map) { + processed_tensor_storages.push_back(entry.ts); + } } - std::vector<TensorStorage> dedup = remove_duplicates(processed_tensor_storages); - processed_tensor_storages = dedup; - curr_time_ms = ggml_time_ms(); - process_time_ms = curr_time_ms - prev_time_ms; - prev_time_ms = curr_time_ms; - bool success = true; + process_time_ms = ggml_time_ms() - start_time; + + bool success = true; + size_t total_tensors_processed = 0; + const size_t total_tensors_to_process = processed_tensor_storages.size(); + const int64_t t_start = ggml_time_ms(); + int last_n_threads = 1; + for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) { std::string file_path = file_paths_[file_index]; LOG_DEBUG("loading tensors from %s", file_path.c_str()); - std::ifstream file(file_path, std::ios::binary); - if (!file.is_open()) { - LOG_ERROR("failed to open '%s'", file_path.c_str()); - return false; + std::vector<const TensorStorage*> file_tensors; + for (const auto& ts : processed_tensor_storages) { + if (ts.file_index == file_index) { + file_tensors.push_back(&ts); + } + } + if (file_tensors.empty()) { + continue; } bool is_zip = false; - for (auto& tensor_storage : tensor_storages) { - if (tensor_storage.file_index != file_index) { - continue; - } - if (tensor_storage.index_in_zip >= 0) { + for (auto const& ts : file_tensors) { + if (ts->index_in_zip >= 0) { is_zip = true; break; } } - struct zip_t* zip = NULL; - if (is_zip) { - zip = 
zip_open(file_path.c_str(), 0, 'r'); - if (zip == NULL) { - LOG_ERROR("failed to open zip '%s'", file_path.c_str()); - return false; - } + int n_threads = is_zip ? 1 : std::min(num_threads_to_use, (int)file_tensors.size()); + if (n_threads < 1) { + n_threads = 1; } + last_n_threads = n_threads; - std::vector<uint8_t> read_buffer; - std::vector<uint8_t> convert_buffer; - - auto read_data = [&](const TensorStorage& tensor_storage, char* buf, size_t n) { - if (zip != NULL) { - zip_entry_openbyindex(zip, tensor_storage.index_in_zip); - size_t entry_size = zip_entry_size(zip); - if (entry_size != n) { - read_buffer.resize(entry_size); - prev_time_ms = ggml_time_ms(); - zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size); - curr_time_ms = ggml_time_ms(); - read_time_ms += curr_time_ms - prev_time_ms; - prev_time_ms = curr_time_ms; - memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n); - curr_time_ms = ggml_time_ms(); - memcpy_time_ms += curr_time_ms - prev_time_ms; + std::atomic<size_t> tensor_idx(0); + std::atomic<bool> failed(false); + std::vector<std::thread> workers; + + for (int i = 0; i < n_threads; ++i) { + workers.emplace_back([&, file_path, is_zip]() { + std::ifstream file; + struct zip_t* zip = NULL; + if (is_zip) { + zip = zip_open(file_path.c_str(), 0, 'r'); + if (zip == NULL) { + LOG_ERROR("failed to open zip '%s'", file_path.c_str()); + failed = true; + return; + } } else { - prev_time_ms = ggml_time_ms(); - zip_entry_noallocread(zip, (void*)buf, n); - curr_time_ms = ggml_time_ms(); - read_time_ms += curr_time_ms - prev_time_ms; - } - zip_entry_close(zip); - } else { - prev_time_ms = ggml_time_ms(); - file.seekg(tensor_storage.offset); - file.read(buf, n); - curr_time_ms = ggml_time_ms(); - read_time_ms += curr_time_ms - prev_time_ms; - if (!file) { - LOG_ERROR("read tensor data failed: '%s'", file_path.c_str()); - return false; + file.open(file_path, std::ios::binary); + if (!file.is_open()) { + LOG_ERROR("failed to open '%s'", file_path.c_str()); + failed = true; + return; + } } - } - return true; - }; - int tensor_count = 0; - int64_t t0 = ggml_time_ms(); - int64_t t1 = t0; - bool partial = true; - int tensor_max = (int)processed_tensor_storages.size(); - pretty_progress(0, tensor_max, 0.0f); - for (auto& tensor_storage : processed_tensor_storages) { - if (tensor_storage.file_index != file_index) { - ++tensor_count; - continue; - } - ggml_tensor* dst_tensor = NULL; - success = on_new_tensor_cb(tensor_storage, &dst_tensor); - if (!success) { - LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str()); - break; - } + std::vector<uint8_t> read_buffer; + std::vector<uint8_t> convert_buffer; - if (dst_tensor == NULL) { - ++tensor_count; - continue; - } + while (true) { + int64_t t0, t1; + size_t idx = tensor_idx.fetch_add(1); + if (idx >= file_tensors.size() || failed) { + break; + } - size_t nbytes_to_read = tensor_storage.nbytes_to_read(); + const TensorStorage& tensor_storage = *file_tensors[idx]; + ggml_tensor* dst_tensor = NULL; - if (dst_tensor->buffer == NULL || ggml_backend_buffer_is_host(dst_tensor->buffer)) { - // for the CPU and Metal backend, we can copy directly into the tensor - if (tensor_storage.type == dst_tensor->type) { - GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); - if (tensor_storage.is_f64 || tensor_storage.is_i64) { - read_buffer.resize(tensor_storage.nbytes_to_read()); - read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read); - } else { - read_data(tensor_storage, (char*)dst_tensor->data, nbytes_to_read); - } + t0 = ggml_time_ms(); - 
prev_time_ms = ggml_time_ms(); - if (tensor_storage.is_bf16) { - // inplace op - bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements()); - } else if (tensor_storage.is_f8_e4m3) { - // inplace op - f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); - } else if (tensor_storage.is_f8_e5m2) { - // inplace op - f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); - } else if (tensor_storage.is_f64) { - f64_to_f32_vec((double*)read_buffer.data(), (float*)dst_tensor->data, tensor_storage.nelements()); - } else if (tensor_storage.is_i64) { - i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements()); + if (!on_new_tensor_cb(tensor_storage, &dst_tensor)) { + LOG_WARN("process tensor failed: '%s'", tensor_storage.name.c_str()); + failed = true; + break; } - curr_time_ms = ggml_time_ms(); - convert_time_ms += curr_time_ms - prev_time_ms; - } else { - read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read())); - read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read); - - prev_time_ms = ggml_time_ms(); - if (tensor_storage.is_bf16) { - // inplace op - bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_f8_e4m3) { - // inplace op - f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_f8_e5m2) { - // inplace op - f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_f64) { - // inplace op - f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_i64) { - // inplace op - i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements()); + + if (dst_tensor == NULL) { + t1 = ggml_time_ms(); + read_time_ms.fetch_add(t1 - t0); + continue; } - convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, - dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]); - curr_time_ms = ggml_time_ms(); - convert_time_ms += curr_time_ms - prev_time_ms; - } - } else { - read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read())); - read_data(tensor_storage, (char*)read_buffer.data(), nbytes_to_read); - - prev_time_ms = ggml_time_ms(); - if (tensor_storage.is_bf16) { - // inplace op - bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_f8_e4m3) { - // inplace op - f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_f8_e5m2) { - // inplace op - f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_f64) { - // inplace op - f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); - } else if (tensor_storage.is_i64) { - // inplace op - i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements()); - } + size_t nbytes_to_read = tensor_storage.nbytes_to_read(); + + auto read_data = [&](char* buf, size_t n) 
{ + if (zip != NULL) { + zip_entry_openbyindex(zip, tensor_storage.index_in_zip); + size_t entry_size = zip_entry_size(zip); + if (entry_size != n) { + int64_t t_memcpy_start; + read_buffer.resize(entry_size); + zip_entry_noallocread(zip, (void*)read_buffer.data(), entry_size); + t_memcpy_start = ggml_time_ms(); + memcpy((void*)buf, (void*)(read_buffer.data() + tensor_storage.offset), n); + memcpy_time_ms.fetch_add(ggml_time_ms() - t_memcpy_start); + } else { + zip_entry_noallocread(zip, (void*)buf, n); + } + zip_entry_close(zip); + } else { + file.seekg(tensor_storage.offset); + file.read(buf, n); + if (!file) { + LOG_ERROR("read tensor data failed: '%s'", file_path.c_str()); + failed = true; + } + } + }; + + if (dst_tensor->buffer == NULL || ggml_backend_buffer_is_host(dst_tensor->buffer)) { + if (tensor_storage.type == dst_tensor->type) { + GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); + if (tensor_storage.is_f64 || tensor_storage.is_i64) { + read_buffer.resize(tensor_storage.nbytes_to_read()); + read_data((char*)read_buffer.data(), nbytes_to_read); + } else { + read_data((char*)dst_tensor->data, nbytes_to_read); + } + t1 = ggml_time_ms(); + read_time_ms.fetch_add(t1 - t0); + + t0 = ggml_time_ms(); + if (tensor_storage.is_bf16) { + // inplace op + bf16_to_f32_vec((uint16_t*)dst_tensor->data, (float*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e4m3) { + // inplace op + f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f64) { + f64_to_f32_vec((double*)read_buffer.data(), (float*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_i64) { + i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)dst_tensor->data, tensor_storage.nelements()); + } + t1 = ggml_time_ms(); + convert_time_ms.fetch_add(t1 - t0); + } else { + read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read())); + read_data((char*)read_buffer.data(), nbytes_to_read); + t1 = ggml_time_ms(); + read_time_ms.fetch_add(t1 - t0); + + t0 = ggml_time_ms(); + if (tensor_storage.is_bf16) { + // inplace op + bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e4m3) { + // inplace op + f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f64) { + // inplace op + f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_i64) { + // inplace op + i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements()); + } + convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]); + t1 = ggml_time_ms(); + convert_time_ms.fetch_add(t1 - t0); + } + } else { + read_buffer.resize(std::max(tensor_storage.nbytes(), tensor_storage.nbytes_to_read())); + read_data((char*)read_buffer.data(), nbytes_to_read); + t1 = ggml_time_ms(); 
+ read_time_ms.fetch_add(t1 - t0); + + t0 = ggml_time_ms(); + if (tensor_storage.is_bf16) { + // inplace op + bf16_to_f32_vec((uint16_t*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e4m3) { + // inplace op + f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f64) { + // inplace op + f64_to_f32_vec((double*)read_buffer.data(), (float*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_i64) { + // inplace op + i64_to_i32_vec((int64_t*)read_buffer.data(), (int32_t*)read_buffer.data(), tensor_storage.nelements()); + } - if (tensor_storage.type == dst_tensor->type) { - // copy to device memory - curr_time_ms = ggml_time_ms(); - convert_time_ms += curr_time_ms - prev_time_ms; - prev_time_ms = curr_time_ms; - ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor)); - curr_time_ms = ggml_time_ms(); - copy_to_backend_time_ms += curr_time_ms - prev_time_ms; - } else { - // convert first, then copy to device memory - convert_buffer.resize(ggml_nbytes(dst_tensor)); - convert_tensor((void*)read_buffer.data(), tensor_storage.type, - (void*)convert_buffer.data(), dst_tensor->type, - (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]); - curr_time_ms = ggml_time_ms(); - convert_time_ms += curr_time_ms - prev_time_ms; - prev_time_ms = curr_time_ms; - ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor)); - curr_time_ms = ggml_time_ms(); - copy_to_backend_time_ms += curr_time_ms - prev_time_ms; + if (tensor_storage.type == dst_tensor->type) { + // copy to device memory + t1 = ggml_time_ms(); + convert_time_ms.fetch_add(t1 - t0); + t0 = ggml_time_ms(); + ggml_backend_tensor_set(dst_tensor, read_buffer.data(), 0, ggml_nbytes(dst_tensor)); + t1 = ggml_time_ms(); + copy_to_backend_time_ms.fetch_add(t1 - t0); + } else { + // convert first, then copy to device memory + + convert_buffer.resize(ggml_nbytes(dst_tensor)); + convert_tensor((void*)read_buffer.data(), tensor_storage.type, (void*)convert_buffer.data(), dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]); + t1 = ggml_time_ms(); + convert_time_ms.fetch_add(t1 - t0); + t0 = ggml_time_ms(); + ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor)); + t1 = ggml_time_ms(); + copy_to_backend_time_ms.fetch_add(t1 - t0); + } + } } - } - ++tensor_count; - int64_t t2 = ggml_time_ms(); - if ((t2 - t1) >= 200) { - t1 = t2; - pretty_progress(tensor_count, tensor_max, (t1 - t0) / (1000.0f * tensor_count)); - partial = tensor_count != tensor_max; - } + if (zip != NULL) { + zip_close(zip); + } + }); } - if (partial) { - if (tensor_count >= 1) { - t1 = ggml_time_ms(); - pretty_progress(tensor_count, tensor_max, (t1 - t0) / (1000.0f * tensor_count)); - } - if (tensor_count < tensor_max) { - printf("\n"); + while (true) { + size_t current_idx = tensor_idx.load(); + if (current_idx >= file_tensors.size() || failed) { + break; } + size_t curr_num = total_tensors_processed + current_idx; + pretty_progress(curr_num, total_tensors_to_process, (ggml_time_ms() - t_start) / 1000.0f / (curr_num + 1e-6f)); + 
std::this_thread::sleep_for(std::chrono::milliseconds(200)); } for (auto& w : workers) { w.join(); } - if (!success) { + if (failed) { + success = false; break; } + total_tensors_processed += file_tensors.size(); + pretty_progress(total_tensors_processed, total_tensors_to_process, (ggml_time_ms() - t_start) / 1000.0f / (total_tensors_processed + 1e-6f)); + if (total_tensors_processed < total_tensors_to_process) { + printf("\n"); + } } + int64_t end_time = ggml_time_ms(); LOG_INFO("loading tensors completed, taking %.2fs (process: %.2fs, read: %.2fs, memcpy: %.2fs, convert: %.2fs, copy_to_backend: %.2fs)", (end_time - start_time) / 1000.f, process_time_ms / 1000.f, - read_time_ms / 1000.f, - memcpy_time_ms / 1000.f, - convert_time_ms / 1000.f, - copy_to_backend_time_ms / 1000.f); + (read_time_ms.load() / (float)last_n_threads) / 1000.f, + (memcpy_time_ms.load() / (float)last_n_threads) / 1000.f, + (convert_time_ms.load() / (float)last_n_threads) / 1000.f, + (copy_to_backend_time_ms.load() / (float)last_n_threads) / 1000.f); return success; } bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tensors, - std::set<std::string> ignore_tensors) { + std::set<std::string> ignore_tensors, + int n_threads) { std::set<std::string> tensor_names_in_file; + std::mutex tensor_names_mutex; auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; // LOG_DEBUG("%s", tensor_storage.to_string().c_str()); - tensor_names_in_file.insert(name); + { + std::lock_guard<std::mutex> lock(tensor_names_mutex); + tensor_names_in_file.insert(name); + } struct ggml_tensor* real; if (tensors.find(name) != tensors.end()) { @@ -2263,7 +2320,7 @@ bool ModelLoader::load_tensors(std::map<std::string, ggml_tensor*>& tenso return true; }; - bool success = load_tensors(on_new_tensor_cb); + bool success = load_tensors(on_new_tensor_cb, n_threads); if (!success) { LOG_ERROR("load tensors from file failed"); return false; @@ -2310,7 +2367,7 @@ std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std if (type_name == "f32") { tensor_type = GGML_TYPE_F32; } else { - for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + for (size_t i = 0; i < GGML_TYPE_COUNT; i++) { auto trait = ggml_get_type_traits((ggml_type)i); if (trait->to_float && trait->type_size && type_name == trait->type_name) { tensor_type = (ggml_type)i;
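The per-file loop above distributes tensors to workers through an atomic counter instead of pre-partitioning, so faster threads naturally claim more work while the main thread polls `tensor_idx` for progress. A self-contained sketch of the same pattern (illustrative only):

    #include <atomic>
    #include <thread>
    #include <vector>

    void process_items(size_t n_items, int n_threads) {
        std::atomic<size_t> next_idx(0);
        std::atomic<bool> failed(false);
        std::vector<std::thread> workers;
        for (int i = 0; i < n_threads; ++i) {
            workers.emplace_back([&]() {
                while (!failed) {
                    size_t idx = next_idx.fetch_add(1);  // claim the next unprocessed item
                    if (idx >= n_items) {
                        break;
                    }
                    // ... process item idx; on error, set failed = true to stop all workers
                }
            });
        }
        for (auto& w : workers) {
            w.join();
        }
    }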
diff --git a/model.h b/model.h index fef6ace8..0fdc99c0 100644 --- a/model.h +++ b/model.h @@ -119,7 +119,7 @@ struct TensorStorage { size_t file_index = 0; int index_in_zip = -1; // >= 0 means stored in a zip file - size_t offset = 0; // offset in file + uint64_t offset = 0; // offset in file TensorStorage() = default; @@ -164,10 +164,10 @@ struct TensorStorage { std::vector<TensorStorage> chunk(size_t n) { std::vector<TensorStorage> chunks; - size_t chunk_size = nbytes_to_read() / n; + uint64_t chunk_size = nbytes_to_read() / n; // printf("%d/%d\n", chunk_size, nbytes_to_read()); reverse_ne(); - for (int i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) { TensorStorage chunk_i = *this; chunk_i.ne[0] = ne[0] / n; chunk_i.offset = offset + i * chunk_size; @@ -247,9 +247,10 @@ class ModelLoader { ggml_type get_diffusion_model_wtype(); ggml_type get_vae_wtype(); void set_wtype_override(ggml_type wtype, std::string prefix = ""); - bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb); + bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0); bool load_tensors(std::map<std::string, ggml_tensor*>& tensors, - std::set<std::string> ignore_tensors = {}); + std::set<std::string> ignore_tensors = {}, + int n_threads = 0); bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); diff --git a/pmid.hpp b/pmid.hpp index 5e9b0d5b..d7daa419 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -42,41 +42,6 @@ struct FuseBlock : public GGMLBlock { } }; -/* -class QFormerPerceiver(nn.Module): - def __init__(self, id_embeddings_dim, cross_attention_dim, num_tokens, embedding_dim=1024, use_residual=True, ratio=4): - super().__init__() - - self.num_tokens = num_tokens - self.cross_attention_dim = cross_attention_dim - self.use_residual = use_residual - print(cross_attention_dim*num_tokens) - self.token_proj = nn.Sequential( - nn.Linear(id_embeddings_dim, id_embeddings_dim*ratio), - nn.GELU(), - nn.Linear(id_embeddings_dim*ratio, cross_attention_dim*num_tokens), - ) - self.token_norm = nn.LayerNorm(cross_attention_dim) - self.perceiver_resampler = FacePerceiverResampler( - dim=cross_attention_dim, - depth=4, - dim_head=128, - heads=cross_attention_dim // 128, - embedding_dim=embedding_dim, - output_dim=cross_attention_dim, - ff_mult=4, - ) - - def forward(self, x, last_hidden_state): - x = self.token_proj(x) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - x = self.token_norm(x) # cls token - out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens - if self.use_residual: # TODO: if use_residual is not true - out = x + 1.0 * out - return out -*/ - struct PMFeedForward : public GGMLBlock { // network hparams int dim; @@ -122,17 +87,8 @@ struct PerceiverAttention : public GGMLBlock { int64_t ne[4]; for (int i = 0; i < 4; ++i) ne[i] = x->ne[i]; - // print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: "); - // printf("heads = %d \n", heads); - // x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads, - // x->nb[1], x->nb[2], x->nb[3], 0); x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]); - // x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2], - // x->nb[1], x->nb[2], x->nb[3], 0); - // x = ggml_cont(ctx, x); x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); - // print_ggml_tensor(x, true, "PerceiverAttention reshape x 1: "); - // x = ggml_reshape_4d(ctx, x, ne[0], heads, ne[1], ne[2]/heads); return x; } @@ -269,17 +225,6 @@ struct QFormerPerceiver : public GGMLBlock { 4)); } - /* - def forward(self, x, last_hidden_state): - x = self.token_proj(x) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - x = self.token_norm(x) # cls token - out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens - if self.use_residual: # TODO: if use_residual is not true - out = x + 1.0 * out - return out - */ - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* last_hidden_state) { @@ -299,113 +244,6 @@ struct QFormerPerceiver : public GGMLBlock { } }; -/* -class FacePerceiverResampler(torch.nn.Module): - def __init__( - self, - *, - dim=768, - depth=4, - dim_head=64, - heads=16, - embedding_dim=1280, - output_dim=768, - ff_mult=4, - ): - super().__init__() - - self.proj_in = torch.nn.Linear(embedding_dim, dim) - self.proj_out = torch.nn.Linear(dim, output_dim) - self.norm_out = torch.nn.LayerNorm(output_dim) - self.layers = torch.nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - torch.nn.ModuleList( - [ - PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), - FeedForward(dim=dim, mult=ff_mult), - ] - ) - ) - - def forward(self, latents, x): - x = 
self.proj_in(x) - for attn, ff in self.layers: - latents = attn(x, latents) + latents - latents = ff(latents) + latents - latents = self.proj_out(latents) - return self.norm_out(latents) -*/ - -/* - -def FeedForward(dim, mult=4): - inner_dim = int(dim * mult) - return nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, inner_dim, bias=False), - nn.GELU(), - nn.Linear(inner_dim, dim, bias=False), - ) - -def reshape_tensor(x, heads): - bs, length, width = x.shape - # (bs, length, width) --> (bs, length, n_heads, dim_per_head) - x = x.view(bs, length, heads, -1) - # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) - x = x.transpose(1, 2) - # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) - x = x.reshape(bs, heads, length, -1) - return x - -class PerceiverAttention(nn.Module): - def __init__(self, *, dim, dim_head=64, heads=8): - super().__init__() - self.scale = dim_head**-0.5 - self.dim_head = dim_head - self.heads = heads - inner_dim = dim_head * heads - - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_out = nn.Linear(inner_dim, dim, bias=False) - - def forward(self, x, latents): - """ - Args: - x (torch.Tensor): image features - shape (b, n1, D) - latent (torch.Tensor): latent features - shape (b, n2, D) - """ - x = self.norm1(x) - latents = self.norm2(latents) - - b, l, _ = latents.shape - - q = self.to_q(latents) - kv_input = torch.cat((x, latents), dim=-2) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) - - q = reshape_tensor(q, self.heads) - k = reshape_tensor(k, self.heads) - v = reshape_tensor(v, self.heads) - - # attention - scale = 1 / math.sqrt(math.sqrt(self.dim_head)) - weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - out = weight @ v - - out = out.permute(0, 2, 1, 3).reshape(b, l, -1) - - return self.to_out(out) - -*/ - struct FuseModule : public GGMLBlock { // network hparams int embed_dim; @@ -425,31 +263,13 @@ struct FuseModule : public GGMLBlock { auto mlp2 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]); auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]); - // print_ggml_tensor(id_embeds, true, "Fuseblock id_embeds: "); - // print_ggml_tensor(prompt_embeds, true, "Fuseblock prompt_embeds: "); - - // auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); - // auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); - // print_ggml_tensor(id_embeds0, true, "Fuseblock id_embeds0: "); - // print_ggml_tensor(prompt_embeds0, true, "Fuseblock prompt_embeds0: "); - // concat is along dim 2 - // auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0); - // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 0: "); - // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); - // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); - // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds); - // stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); - // stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds); - // stacked_id_embeds = ggml_nn_layer_norm(ctx, stacked_id_embeds, ln_w, ln_b); stacked_id_embeds = mlp1->forward(ctx, 
stacked_id_embeds); stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds); stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds); - // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); - return stacked_id_embeds; } @@ -464,21 +284,14 @@ struct FuseModule : public GGMLBlock { struct ggml_tensor* valid_id_embeds = id_embeds; // # slice out the image token embeddings - // print_ggml_tensor(class_tokens_mask_pos, false); ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos"); ggml_set_name(prompt_embeds, "prompt_embeds"); - // print_ggml_tensor(valid_id_embeds, true, "valid_id_embeds"); - // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos"); struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]); struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); - // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); - // print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds"); - // print_ggml_tensor(left, true, "AA left"); - // print_ggml_tensor(right, true, "AA right"); if (left && right) { stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); @@ -487,15 +300,12 @@ struct FuseModule : public GGMLBlock { } else if (right) { stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } - // print_ggml_tensor(stacked_id_embeds, true, "BB stacked_id_embeds"); - // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); - // print_ggml_tensor(stacked_id_embeds, true, "CC stacked_id_embeds"); + class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask)); class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds); prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask); struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds); ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds"); - // print_ggml_tensor(updated_prompt_embeds, true, "updated_prompt_embeds: "); return updated_prompt_embeds; } }; @@ -551,34 +361,11 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo num_tokens(2) { blocks["visual_projection_2"] = std::shared_ptr<GGMLBlock>(new Linear(1024, 1280, false)); blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048)); - /* - cross_attention_dim = 2048 - # projection - self.num_tokens = 2 - self.cross_attention_dim = cross_attention_dim - self.qformer_perceiver = QFormerPerceiver( - id_embeddings_dim, - cross_attention_dim, - self.num_tokens, - )*/ - blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim, - cross_attention_dim, - num_tokens)); + blocks["qformer_perceiver"] = std::shared_ptr<GGMLBlock>(new QFormerPerceiver(id_embeddings_dim, + cross_attention_dim, + num_tokens)); } - /* - def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds): - b, num_inputs, c, h, w = id_pixel_values.shape - id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) - - last_hidden_state = self.vision_model(id_pixel_values)[0] - id_embeds = id_embeds.view(b * num_inputs, -1) - - id_embeds = 
self.qformer_perceiver(id_embeds, last_hidden_state) - id_embeds = id_embeds.view(b, num_inputs, self.num_tokens, -1) - updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) - */ - struct ggml_tensor* forward(struct ggml_context* ctx, ggml_backend_t backend, struct ggml_tensor* id_pixel_values, diff --git a/preprocessing.hpp b/preprocessing.hpp index 4ea1dbab..9cace2f4 100644 --- a/preprocessing.hpp +++ b/preprocessing.hpp @@ -162,16 +162,16 @@ void threshold_hystersis(struct ggml_tensor* img, float high_threshold, float lo } } -uint8_t* preprocess_canny(uint8_t* img, int width, int height, float high_threshold, float low_threshold, float weak, float strong, bool inverse) { +bool preprocess_canny(sd_image_t& img, float high_threshold, float low_threshold, float weak, float strong, bool inverse) { struct ggml_init_params params; - params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 + params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10MB params.mem_buffer = NULL; params.no_alloc = false; struct ggml_context* work_ctx = ggml_init(params); if (!work_ctx) { LOG_ERROR("ggml_init() failed"); - return NULL; + return false; } float kX[9] = { @@ -192,8 +192,8 @@ uint8_t* preprocess_canny(uint8_t* img, int width, int height, float high_thresh struct ggml_tensor* sf_ky = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1); memcpy(sf_ky->data, kY, ggml_nbytes(sf_ky)); gaussian_kernel(gkernel); - struct ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - struct ggml_tensor* image_gray = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); + struct ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, img.width, img.height, 3, 1); + struct ggml_tensor* image_gray = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, img.width, img.height, 1, 1); struct ggml_tensor* iX = ggml_dup_tensor(work_ctx, image_gray); struct ggml_tensor* iY = ggml_dup_tensor(work_ctx, image_gray); struct ggml_tensor* G = ggml_dup_tensor(work_ctx, image_gray);
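With the reworked signature, the Canny pass now swaps the pixel buffer inside the caller's image instead of returning a new allocation. A usage sketch, assuming the by-reference `sd_image_t& img` parameter from the hunk above and the default thresholds used by the CLI:

    sd_image_t control_image = {512, 512, 3, pixels};  // pixels: RGB buffer, e.g. from load_image()
    if (!preprocess_canny(control_image, 0.08f, 0.08f, 0.8f, 1.0f, false)) {
        // allocation failure: the image is left untouched
    }
    // control_image.data now holds the edge map replicated to 3 channels;
    // the previous buffer was already freed inside preprocess_canny()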
@@ -209,8 +209,8 @@ uint8_t* preprocess_canny(uint8_t* img, int width, int height, float high_thresh non_max_supression(image_gray, G, tetha); threshold_hystersis(image_gray, high_threshold, low_threshold, weak, strong); // to RGB channels - for (int iy = 0; iy < height; iy++) { - for (int ix = 0; ix < width; ix++) { + for (int iy = 0; iy < img.height; iy++) { + for (int ix = 0; ix < img.width; ix++) { float gray = ggml_tensor_get_f32(image_gray, ix, iy); gray = inverse ? 1.0f - gray : gray; ggml_tensor_set_f32(image, gray, ix, iy); @@ -218,10 +218,11 @@ uint8_t* preprocess_canny(uint8_t* img, int width, int height, float high_thresh ggml_tensor_set_f32(image, gray, ix, iy, 2); } } - free(img); uint8_t* output = sd_tensor_to_image(image); + free(img.data); + img.data = output; ggml_free(work_ctx); - return output; + return true; } #endif // __PREPROCESSING_HPP__ \ No newline at end of file diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index db4e07cb..ccd90a00 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -43,7 +43,7 @@ const char* model_version_to_str[] = { }; const char* sampling_methods_str[] = { - "Euler A", + "default", "Euler", "Heun", "DPM2", @@ -55,6 +55,7 @@ const char* sampling_methods_str[] = { "LCM", "DDIM \"trailing\"", "TCD", + "Euler A", }; /*================================================== Helper Functions ================================================*/ @@ -107,10 +108,10 @@ class StableDiffusionGGML { std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds; std::string taesd_path; - bool use_tiny_autoencoder = false; - bool vae_tiling = false; - bool offload_params_to_cpu = false; - bool stacked_id = false; + bool use_tiny_autoencoder = false; + sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0, 0}; + bool offload_params_to_cpu = false; + bool stacked_id = false; bool is_using_v_parameterization = false; bool is_using_edm_v_parameterization = false; @@ -182,7 +183,6 @@ class StableDiffusionGGML { lora_model_dir = SAFE_STR(sd_ctx_params->lora_model_dir); taesd_path = SAFE_STR(sd_ctx_params->taesd_path); use_tiny_autoencoder = taesd_path.size() > 0; - vae_tiling = sd_ctx_params->vae_tiling; offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu; if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) { @@ -265,7 +265,9 @@ class StableDiffusionGGML { } LOG_INFO("Version: %s ", model_version_to_str[version]); - ggml_type wtype = (ggml_type)sd_ctx_params->wtype; + ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT) + ? 
(ggml_type)sd_ctx_params->wtype + : GGML_TYPE_COUNT; if (wtype == GGML_TYPE_COUNT) { model_wtype = model_loader.get_sd_wtype(); if (model_wtype == GGML_TYPE_COUNT) { @@ -293,11 +295,6 @@ class StableDiffusionGGML { model_loader.set_wtype_override(wtype); } - if (sd_version_is_sdxl(version)) { - vae_wtype = GGML_TYPE_F32; - model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); - } - LOG_INFO("Weight type: %s", ggml_type_name(model_wtype)); LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype)); LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype)); @@ -373,7 +370,6 @@ class StableDiffusionGGML { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, model_loader.tensor_storages_types, - -1, sd_ctx_params->chroma_use_t5_mask, sd_ctx_params->chroma_t5_mask_pad); } else { @@ -391,7 +387,6 @@ class StableDiffusionGGML { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, model_loader.tensor_storages_types, - -1, true, 1, true); @@ -417,7 +412,7 @@ class StableDiffusionGGML { clip_vision->get_param_tensors(tensors); } } else { // SD1.x SD2.x SDXL - if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { + if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { cond_stage_model = std::make_shared(clip_backend, offload_params_to_cpu, model_loader.tensor_storages_types, @@ -515,7 +510,7 @@ class StableDiffusionGGML { } } - if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { + if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) { pmid_model = std::make_shared(backend, offload_params_to_cpu, model_loader.tensor_storages_types, @@ -530,15 +525,15 @@ class StableDiffusionGGML { "pmid", version); } - if (strlen(SAFE_STR(sd_ctx_params->stacked_id_embed_dir)) > 0) { - pmid_lora = std::make_shared(backend, sd_ctx_params->stacked_id_embed_dir, ""); + if (strlen(SAFE_STR(sd_ctx_params->photo_maker_path)) > 0) { + pmid_lora = std::make_shared(backend, sd_ctx_params->photo_maker_path, ""); if (!pmid_lora->load_from_file(true)) { - LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->stacked_id_embed_dir); + LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->photo_maker_path); return false; } - LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->stacked_id_embed_dir); - if (!model_loader.init_from_file(sd_ctx_params->stacked_id_embed_dir, "pmid.")) { - LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->stacked_id_embed_dir); + LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->photo_maker_path); + if (!model_loader.init_from_file(sd_ctx_params->photo_maker_path, "pmid.")) { + LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->photo_maker_path); } else { stacked_id = true; } @@ -581,7 +576,7 @@ class StableDiffusionGGML { if (version == VERSION_SVD) { ignore_tensors.insert("conditioner.embedders.3"); } - bool success = model_loader.load_tensors(tensors, ignore_tensors); + bool success = model_loader.load_tensors(tensors, ignore_tensors, n_threads); if (!success) { LOG_ERROR("load tensors from model loader failed"); ggml_free(ctx); @@ -781,7 +776,12 @@ class StableDiffusionGGML { int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, false, -1, {}, 0.f, &out); + DiffusionParams diffusion_params; + diffusion_params.x 
= x_t; + diffusion_params.timesteps = timesteps; + diffusion_params.context = c; + diffusion_params.c_concat = concat; + diffusion_model->compute(n_threads, diffusion_params, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -959,7 +959,7 @@ class StableDiffusionGGML { free(resized_image.data); resized_image.data = NULL; } else { - sd_image_to_tensor(init_image.data, init_img); + sd_image_to_tensor(init_image, init_img); } if (augmentation_level > 0.f) { struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, init_img); @@ -1039,7 +1039,9 @@ class StableDiffusionGGML { SDCondition id_cond, std::vector ref_latents = {}, bool increase_ref_index = false, - ggml_tensor* denoise_mask = nullptr) { + ggml_tensor* denoise_mask = NULL, + ggml_tensor* vace_context = NULL, + float vace_strength = 1.f) { std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); float cfg_scale = guidance.txt_cfg; @@ -1123,34 +1125,31 @@ class StableDiffusionGGML { // GGML_ASSERT(0); } + DiffusionParams diffusion_params; + diffusion_params.x = noised_input; + diffusion_params.timesteps = timesteps; + diffusion_params.guidance = guidance_tensor; + diffusion_params.ref_latents = ref_latents; + diffusion_params.increase_ref_index = increase_ref_index; + diffusion_params.controls = controls; + diffusion_params.control_strength = control_strength; + diffusion_params.vace_context = vace_context; + diffusion_params.vace_strength = vace_strength; + if (start_merge_step == -1 || step <= start_merge_step) { // cond + diffusion_params.context = cond.c_crossattn; + diffusion_params.c_concat = cond.c_concat; + diffusion_params.y = cond.c_vector; work_diffusion_model->compute(n_threads, - noised_input, - timesteps, - cond.c_crossattn, - cond.c_concat, - cond.c_vector, - guidance_tensor, - ref_latents, - increase_ref_index, - -1, - controls, - control_strength, + diffusion_params, &out_cond); } else { + diffusion_params.context = id_cond.c_crossattn; + diffusion_params.c_concat = cond.c_concat; + diffusion_params.y = id_cond.c_vector; work_diffusion_model->compute(n_threads, - noised_input, - timesteps, - id_cond.c_crossattn, - cond.c_concat, - id_cond.c_vector, - guidance_tensor, - ref_latents, - increase_ref_index, - -1, - controls, - control_strength, + diffusion_params, &out_cond); } @@ -1161,36 +1160,23 @@ class StableDiffusionGGML { control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector); controls = control_net->controls; } + diffusion_params.controls = controls; + diffusion_params.context = uncond.c_crossattn; + diffusion_params.c_concat = uncond.c_concat; + diffusion_params.y = uncond.c_vector; work_diffusion_model->compute(n_threads, - noised_input, - timesteps, - uncond.c_crossattn, - uncond.c_concat, - uncond.c_vector, - guidance_tensor, - ref_latents, - increase_ref_index, - -1, - controls, - control_strength, + diffusion_params, &out_uncond); negative_data = (float*)out_uncond->data; } float* img_cond_data = NULL; if (has_img_cond) { + diffusion_params.context = img_cond.c_crossattn; + diffusion_params.c_concat = img_cond.c_concat; + diffusion_params.y = img_cond.c_vector; work_diffusion_model->compute(n_threads, - noised_input, - timesteps, - img_cond.c_crossattn, - img_cond.c_concat, - img_cond.c_vector, - guidance_tensor, - ref_latents, - increase_ref_index, - -1, - controls, - control_strength, + diffusion_params, &out_img_cond); img_cond_data = (float*)out_img_cond->data; } @@ -1201,21 +1187,13 @@ class 
StableDiffusionGGML { if (is_skiplayer_step) { LOG_DEBUG("Skipping layers at step %d\n", step); // skip layer (same as conditionned) + diffusion_params.context = cond.c_crossattn; + diffusion_params.c_concat = cond.c_concat; + diffusion_params.y = cond.c_vector; + diffusion_params.skip_layers = skip_layers; work_diffusion_model->compute(n_threads, - noised_input, - timesteps, - cond.c_crossattn, - cond.c_concat, - cond.c_vector, - guidance_tensor, - ref_latents, - increase_ref_index, - -1, - controls, - control_strength, - &out_skip, - NULL, - skip_layers); + diffusion_params, + &out_skip); skip_layer_data = (float*)out_skip->data; } float* vec_denoised = (float*)denoised->data; @@ -1301,15 +1279,77 @@ class StableDiffusionGGML { return latent; } - ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) { + void get_tile_sizes(int& tile_size_x, + int& tile_size_y, + float& tile_overlap, + const sd_tiling_params_t& params, + int latent_x, + int latent_y, + float encoding_factor = 1.0f) { + tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f); + auto get_tile_size = [&](int requested_size, float factor, int latent_size) { + const int default_tile_size = 32; + const int min_tile_dimension = 4; + int tile_size = default_tile_size; + // factor <= 1 means simple fraction of the latent dimension + // factor > 1 means number of tiles across that dimension + if (factor > 0.f) { + if (factor > 1.0) + factor = 1 / (factor - factor * tile_overlap + tile_overlap); + tile_size = std::round(latent_size * factor); + } else if (requested_size >= min_tile_dimension) { + tile_size = requested_size; + } + tile_size *= encoding_factor; + return std::max(std::min(tile_size, latent_size), min_tile_dimension); + }; + + tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x); + tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y); + } + + ggml_tensor* encode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) { int64_t t0 = ggml_time_ms(); ggml_tensor* result = NULL; + int W = x->ne[0] / 8; + int H = x->ne[1] / 8; + if (vae_tiling_params.enabled && !encode_video) { + // TODO wan2.2 vae support? + int C = sd_version_is_dit(version) ? 
16 : 4; + if (!use_tiny_autoencoder) { + C *= 2; + } + result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]); + } + if (!use_tiny_autoencoder) { + float tile_overlap; + int tile_size_x, tile_size_y; + // multiply tile size for encode to keep the compute buffer size consistent + get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, W, H, 1.30539f); + + LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); + process_vae_input_tensor(x); - first_stage_model->compute(n_threads, x, false, &result, work_ctx); + if (vae_tiling_params.enabled && !encode_video) { + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + first_stage_model->compute(n_threads, in, false, &out, work_ctx); + }; + sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling); + } else { + first_stage_model->compute(n_threads, x, false, &result, work_ctx); + } first_stage_model->free_compute_buffer(); } else { - tae_first_stage->compute(n_threads, x, false, &result, work_ctx); + if (vae_tiling_params.enabled && !encode_video) { + // split the image into 64x64 tiles and compute in several steps + auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { + tae_first_stage->compute(n_threads, in, false, &out, NULL); + }; + sd_tiling(x, result, 8, 64, 0.5f, on_tiling); + } else { + tae_first_stage->compute(n_threads, x, false, &result, work_ctx); + } tae_first_stage->free_compute_buffer(); } @@ -1426,24 +1466,29 @@ class StableDiffusionGGML { C, x->ne[3]); } - int64_t t0 = ggml_time_ms(); if (!use_tiny_autoencoder) { + float tile_overlap; + int tile_size_x, tile_size_y; + get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, vae_tiling_params, x->ne[0], x->ne[1]); + + LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y); + process_latent_out(x); // x = load_tensor_from_file(work_ctx, "wan_vae_z.bin"); - if (vae_tiling && !decode_video) { + if (vae_tiling_params.enabled && !decode_video) { // split latent in 32x32 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { first_stage_model->compute(n_threads, in, true, &out, NULL); }; - sd_tiling(x, result, 8, 32, 0.5f, on_tiling); + sd_tiling_non_square(x, result, 8, tile_size_x, tile_size_y, tile_overlap, on_tiling); } else { first_stage_model->compute(n_threads, x, true, &result, work_ctx); } first_stage_model->free_compute_buffer(); process_vae_output_tensor(result); } else { - if (vae_tiling && !decode_video) { + if (vae_tiling_params.enabled && !decode_video) { // split latent in 64x64 tiles and compute in several steps auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { tae_first_stage->compute(n_threads, in, true, &out); @@ -1467,11 +1512,14 @@ class StableDiffusionGGML { #define NONE_STR "NONE" const char* sd_type_name(enum sd_type_t type) { - return ggml_type_name((ggml_type)type); + if ((int)type < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)) { + return ggml_type_name((ggml_type)type); + } + return NONE_STR; } enum sd_type_t str_to_sd_type(const char* str) { - for (int i = 0; i < SD_TYPE_COUNT; i++) { + for (int i = 0; i < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT); i++) { auto trait = ggml_get_type_traits((ggml_type)i); if (!strcmp(str, trait->type_name)) { return (enum sd_type_t)i; @@ -1502,7 +1550,7 @@ enum rng_type_t str_to_rng_type(const char* str) { } const char* sample_method_to_str[] = { - "euler_a", + "default", "euler", "heun", "dpm2", @@ -1514,6 +1562,7 @@ const char* sample_method_to_str[] = { "lcm",
"ddim_trailing", "tcd", + "euler_a", }; const char* sd_sample_method_name(enum sample_method_t sample_method) { @@ -1561,7 +1610,6 @@ enum scheduler_t str_to_schedule(const char* str) { void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { *sd_ctx_params = {}; sd_ctx_params->vae_decode_only = true; - sd_ctx_params->vae_tiling = false; sd_ctx_params->free_params_immediately = true; sd_ctx_params->n_threads = get_num_physical_cores(); sd_ctx_params->wtype = SD_TYPE_COUNT; @@ -1596,7 +1644,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { "control_net_path: %s\n" "lora_model_dir: %s\n" "embedding_dir: %s\n" - "stacked_id_embed_dir: %s\n" + "photo_maker_path: %s\n" "vae_decode_only: %s\n" "vae_tiling: %s\n" "free_params_immediately: %s\n" @@ -1623,9 +1671,8 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { SAFE_STR(sd_ctx_params->control_net_path), SAFE_STR(sd_ctx_params->lora_model_dir), SAFE_STR(sd_ctx_params->embedding_dir), - SAFE_STR(sd_ctx_params->stacked_id_embed_dir), + SAFE_STR(sd_ctx_params->photo_maker_path), BOOL_STR(sd_ctx_params->vae_decode_only), - BOOL_STR(sd_ctx_params->vae_tiling), BOOL_STR(sd_ctx_params->free_params_immediately), sd_ctx_params->n_threads, sd_type_name(sd_ctx_params->wtype), @@ -1652,7 +1699,7 @@ void sd_sample_params_init(sd_sample_params_t* sample_params) { sample_params->guidance.slg.layer_end = 0.2f; sample_params->guidance.slg.scale = 0.f; sample_params->scheduler = DEFAULT; - sample_params->sample_method = EULER_A; + sample_params->sample_method = SAMPLE_METHOD_DEFAULT; sample_params->sample_steps = 20; } @@ -1692,16 +1739,17 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) { void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { *sd_img_gen_params = {}; sd_sample_params_init(&sd_img_gen_params->sample_params); - sd_img_gen_params->clip_skip = -1; - sd_img_gen_params->ref_images_count = 0; - sd_img_gen_params->width = 512; - sd_img_gen_params->height = 512; - sd_img_gen_params->strength = 0.75f; - sd_img_gen_params->seed = -1; - sd_img_gen_params->batch_count = 1; - sd_img_gen_params->control_strength = 0.9f; - sd_img_gen_params->style_strength = 20.f; - sd_img_gen_params->normalize_input = false; + sd_img_gen_params->clip_skip = -1; + sd_img_gen_params->ref_images_count = 0; + sd_img_gen_params->width = 512; + sd_img_gen_params->height = 512; + sd_img_gen_params->strength = 0.75f; + sd_img_gen_params->seed = -1; + sd_img_gen_params->batch_count = 1; + sd_img_gen_params->control_strength = 0.9f; + sd_img_gen_params->normalize_input = false; + sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; + sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; } char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { @@ -1721,14 +1769,13 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { "sample_params: %s\n" "strength: %.2f\n" "seed: %" PRId64 - "\n" "batch_count: %d\n" "ref_images_count: %d\n" "increase_ref_index: %s\n" "control_strength: %.2f\n" - "style_strength: %.2f\n" "normalize_input: %s\n" - "input_id_images_path: %s\n", + "photo maker: {style_strength = %.2f, id_images_count = %d, id_embed_path = %s}\n" + "VAE tiling: %s\n", SAFE_STR(sd_img_gen_params->prompt), SAFE_STR(sd_img_gen_params->negative_prompt), sd_img_gen_params->clip_skip, @@ -1741,9 +1788,11 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { sd_img_gen_params->ref_images_count, 
BOOL_STR(sd_img_gen_params->increase_ref_index), sd_img_gen_params->control_strength, - sd_img_gen_params->style_strength, BOOL_STR(sd_img_gen_params->normalize_input), - SAFE_STR(sd_img_gen_params->input_id_images_path)); + sd_img_gen_params->pm_params.style_strength, + sd_img_gen_params->pm_params.id_images_count, + SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), + BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled)); free(sample_params_str); return buf; } @@ -1759,6 +1808,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { sd_vid_gen_params->seed = -1; sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->moe_boundary = 0.875f; + sd_vid_gen_params->vace_strength = 1.f; } struct sd_ctx_t { @@ -1794,6 +1844,17 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } +enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) { + if (sd_ctx != NULL && sd_ctx->sd != NULL) { + SDVersion version = sd_ctx->sd->version; + if (sd_version_is_dit(version)) + return EULER; + else + return EULER_A; + } + return SAMPLE_METHOD_COUNT; +} + sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, struct ggml_context* work_ctx, ggml_tensor* init_latent, @@ -1810,9 +1871,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, int batch_count, sd_image_t control_image, float control_strength, - float style_ratio, bool normalize_input, - std::string input_id_images_path, + sd_pm_params_t pm_params, std::vector ref_latents, bool increase_ref_index, ggml_tensor* concat_latent = NULL, @@ -1853,67 +1913,46 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, } } // preprocess input id images - std::vector input_id_images; bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2; - if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { - std::vector img_files = get_files_from_dir(input_id_images_path); - for (std::string img_file : img_files) { - int c = 0; - int width, height; - if (ends_with(img_file, "safetensors")) { - continue; - } - uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3); - if (input_image_buffer == NULL) { - LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); - continue; - } else { - LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str()); - } - sd_image_t* input_image = NULL; - input_image = new sd_image_t{(uint32_t)width, - (uint32_t)height, - 3, - input_image_buffer}; - input_image = preprocess_id_image(input_image); - if (input_image == NULL) { - LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str()); - continue; - } - input_id_images.push_back(input_image); - } - } - if (input_id_images.size() > 0) { - sd_ctx->sd->pmid_model->style_strength = style_ratio; - int32_t w = input_id_images[0]->width; - int32_t h = input_id_images[0]->height; - int32_t channels = input_id_images[0]->channel; - int32_t num_input_images = (int32_t)input_id_images.size(); - init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, w, h, channels, num_input_images); - // TODO: move these to somewhere else and be user settable - float mean[] = {0.48145466f, 0.4578275f, 0.40821073f}; - float std[] = {0.26862954f, 0.26130258f, 0.27577711f}; - for (int i = 0; i < num_input_images; i++) { - sd_image_t* init_image = input_id_images[i]; - if (normalize_input) - sd_mul_images_to_tensor(init_image->data, init_img, i, mean, std); - else - sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); + if (pm_params.id_images_count > 0) { + int clip_image_size = 224; + 
sd_ctx->sd->pmid_model->style_strength = pm_params.style_strength; + + init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, clip_image_size, clip_image_size, 3, pm_params.id_images_count); + + std::vector<sd_image_f32_t> processed_id_images; + for (int i = 0; i < pm_params.id_images_count; i++) { + sd_image_f32_t id_image = sd_image_t_to_sd_image_f32_t(pm_params.id_images[i]); + sd_image_f32_t processed_id_image = clip_preprocess(id_image, clip_image_size); + free(id_image.data); + id_image.data = NULL; + processed_id_images.push_back(processed_id_image); + } + + ggml_tensor_iter(init_img, [&](ggml_tensor* init_img, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = sd_image_get_f32(processed_id_images[i3], i0, i1, i2); + ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3); + }); + + for (auto& image : processed_id_images) { + free(image.data); + image.data = NULL; + } + processed_id_images.clear(); + int64_t t0 = ggml_time_ms(); auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, sd_ctx->sd->n_threads, prompt, clip_skip, width, height, - num_input_images, + pm_params.id_images_count, sd_ctx->sd->diffusion_model->get_adm_in_channels()); id_cond = std::get<0>(cond_tup); class_tokens_mask = std::get<1>(cond_tup); // struct ggml_tensor* id_embeds = NULL; - if (pmv2) { - // id_embeds = sd_ctx->sd->pmid_id_embeds->get(); - id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin")); + if (pmv2 && pm_params.id_embed_path != nullptr) { + id_embeds = load_tensor_from_file(work_ctx, pm_params.id_embed_path); // print_ggml_tensor(id_embeds, true, "id_embeds:"); } id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); @@ -1926,19 +1965,14 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, prompt_text_only = sd_ctx->sd->cond_stage_model->remove_trigger_from_prompt(work_ctx, prompt); // printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str()); prompt = prompt_text_only; // - // if (sample_steps < 50) { - // LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps); - // sample_steps = 50; - // } + if (sample_steps < 50) { + LOG_WARN("It's recommended to use >= 50 steps for photo maker!"); + } } else { LOG_WARN("Provided PhotoMaker model file, but NO input ID images"); LOG_WARN("Turn off PhotoMaker"); sd_ctx->sd->stacked_id = false; } - for (sd_image_t* img : input_id_images) { - free(img->data); - } - input_id_images.clear(); } // Get learned condition @@ -1978,7 +2012,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, struct ggml_tensor* image_hint = NULL; if (control_image.data != NULL) { image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - sd_image_to_tensor(control_image.data, image_hint); + sd_image_to_tensor(control_image, image_hint); } // Sample @@ -2162,8 +2196,9 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx, }
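// The PhotoMaker id-image fill above is the first of several places where this
// diff swaps four hand-written nested loops for the ggml_tensor_iter helper.
// A condensed, hedged sketch of the pattern (src and dst are hypothetical
// placeholder tensors; the helper is assumed, as in the hunks above, to visit
// every element of dst and hand the lambda its four indices):
ggml_tensor_iter(dst, [&](ggml_tensor* dst, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    float value = ggml_tensor_get_f32(src, i0, i1, i2, i3);  // read the source element
    ggml_tensor_set_f32(dst, value, i0, i1, i2, i3);         // write it to the same position in dst
});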
(Got %dx%d)", @@ -2185,19 +2220,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g } struct ggml_init_params params; - params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_version_is_sd3(sd_ctx->sd->version)) { - params.mem_size *= 3; - } - if (sd_version_is_flux(sd_ctx->sd->version)) { - params.mem_size *= 4; - } - if (sd_ctx->sd->stacked_id) { - params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB - } - params.mem_size += width * height * 3 * sizeof(float) * 3; - params.mem_size += width * height * 3 * sizeof(float) * 3 * sd_img_gen_params->ref_images_count; - params.mem_size *= sd_img_gen_params->batch_count; + params.mem_size = static_cast(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; // LOG_DEBUG("mem_size %u ", params.mem_size); @@ -2239,8 +2262,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); - sd_mask_to_tensor(sd_img_gen_params->mask_image.data, mask_img); - sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img); + sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img); + sd_image_to_tensor(sd_img_gen_params->init_image, init_img); if (sd_version_is_inpaint(sd_ctx->sd->version)) { int64_t mask_channels = 1; @@ -2331,7 +2354,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->ref_images[i].height, 3, 1); - sd_image_to_tensor(sd_img_gen_params->ref_images[i].data, img); + sd_image_to_tensor(sd_img_gen_params->ref_images[i], img); ggml_tensor* latent = NULL; if (sd_ctx->sd->use_tiny_autoencoder) { @@ -2358,6 +2381,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); } + enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method; + if (sample_method == SAMPLE_METHOD_DEFAULT) { + sample_method = sd_get_default_sample_method(sd_ctx); + } + sd_image_t* result_images = generate_image_internal(sd_ctx, work_ctx, init_latent, @@ -2368,15 +2396,14 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->sample_params.eta, width, height, - sd_img_gen_params->sample_params.sample_method, + sample_method, sigmas, seed, sd_img_gen_params->batch_count, sd_img_gen_params->control_image, sd_img_gen_params->control_strength, - sd_img_gen_params->style_strength, sd_img_gen_params->normalize_input, - SAFE_STR(sd_img_gen_params->input_id_images_path), + sd_img_gen_params->pm_params, ref_latents, sd_img_gen_params->increase_ref_index, concat_latent, @@ -2432,8 +2459,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } struct ggml_init_params params; - params.mem_size = static_cast(200 * 1024) * 1024; // 200 MB - params.mem_size += width * height * frames * 3 * sizeof(float) * 2; + params.mem_size = static_cast(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; // LOG_DEBUG("mem_size %u ", params.mem_size); @@ -2460,6 +2486,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s ggml_tensor* clip_vision_output = NULL; ggml_tensor* concat_latent = NULL; ggml_tensor* denoise_mask = NULL; + ggml_tensor* vace_context = NULL; + int64_t ref_image_num = 0; // for vace if 
(sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" || sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" || sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") { @@ -2489,23 +2517,17 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int64_t t1 = ggml_time_ms(); ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3); - for (int i3 = 0; i3 < image->ne[3]; i3++) { // channels - for (int i2 = 0; i2 < image->ne[2]; i2++) { - for (int i1 = 0; i1 < image->ne[1]; i1++) { // height - for (int i0 = 0; i0 < image->ne[0]; i0++) { // width - float value = 0.5f; - if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image - value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3); - value /= 255.f; - } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data) { - value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3); - value /= 255.f; - } - ggml_tensor_set_f32(image, value, i0, i1, i2, i3); - } - } - } - } + ggml_tensor_iter(image, [&](ggml_tensor* image, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = 0.5f; + if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image + value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3); + value /= 255.f; + } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data) { + value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3); + value /= 255.f; + } + ggml_tensor_set_f32(image, value, i0, i1, i2, i3); + }); concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8] @@ -2520,21 +2542,15 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s concat_latent->ne[1], concat_latent->ne[2], 4); // [b*4, t, w/8, h/8] - for (int i3 = 0; i3 < concat_mask->ne[3]; i3++) { - for (int i2 = 0; i2 < concat_mask->ne[2]; i2++) { - for (int i1 = 0; i1 < concat_mask->ne[1]; i1++) { - for (int i0 = 0; i0 < concat_mask->ne[0]; i0++) { - float value = 0.0f; - if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image - value = 1.0f; - } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 3) { - value = 1.0f; - } - ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3); - } - } + ggml_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = 0.0f; + if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image + value = 1.0f; + } else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 3) { + value = 1.0f; } - } + ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3); + }); concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8] } else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) { @@ -2542,7 +2558,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s int64_t t1 = ggml_time_ms(); ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - sd_image_to_tensor(sd_vid_gen_params->init_image.data, init_img); + sd_image_to_tensor(sd_vid_gen_params->init_image, init_img); init_img = ggml_reshape_4d(work_ctx, init_img, width, height, 1, 3); auto init_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, 1, h/16, w/16] @@ -2553,22 +2569,95 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const 
sd_vid_gen_params_t* s sd_ctx->sd->process_latent_out(init_latent); - for (int i3 = 0; i3 < init_image_latent->ne[3]; i3++) { - for (int i2 = 0; i2 < init_image_latent->ne[2]; i2++) { - for (int i1 = 0; i1 < init_image_latent->ne[1]; i1++) { - for (int i0 = 0; i0 < init_image_latent->ne[0]; i0++) { - float value = ggml_tensor_get_f32(init_image_latent, i0, i1, i2, i3); - ggml_tensor_set_f32(init_latent, value, i0, i1, i2, i3); - if (i3 == 0) { - ggml_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3); - } - } - } + ggml_tensor_iter(init_image_latent, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = ggml_tensor_get_f32(t, i0, i1, i2, i3); + ggml_tensor_set_f32(init_latent, value, i0, i1, i2, i3); + if (i3 == 0) { + ggml_tensor_set_f32(denoise_mask, 0.f, i0, i1, i2, i3); } - } + }); sd_ctx->sd->process_latent_in(init_latent); + int64_t t2 = ggml_time_ms(); + LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); + } else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-VACE-1.3B" || + sd_ctx->sd->diffusion_model->get_desc() == "Wan2.x-VACE-14B") { + LOG_INFO("VACE"); + int64_t t1 = ggml_time_ms(); + ggml_tensor* ref_image_latent = NULL; + if (sd_vid_gen_params->init_image.data) { + ggml_tensor* ref_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + sd_image_to_tensor(sd_vid_gen_params->init_image, ref_img); + ref_img = ggml_reshape_4d(work_ctx, ref_img, width, height, 1, 3); + + ref_image_latent = sd_ctx->sd->encode_first_stage(work_ctx, ref_img); // [b*c, 1, h/16, w/16] + sd_ctx->sd->process_latent_in(ref_image_latent); + auto zero_latent = ggml_dup_tensor(work_ctx, ref_image_latent); + ggml_set_f32(zero_latent, 0.f); + ref_image_latent = ggml_tensor_concat(work_ctx, ref_image_latent, zero_latent, 3); // [b*2*c, 1, h/16, w/16] + } + + ggml_tensor* control_video = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3); + ggml_tensor_iter(control_video, [&](ggml_tensor* control_video, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = 0.5f; + if (i2 < sd_vid_gen_params->control_frames_size) { + value = sd_image_get_f32(sd_vid_gen_params->control_frames[i2], i0, i1, i3); + } + ggml_tensor_set_f32(control_video, value, i0, i1, i2, i3); + }); + ggml_tensor* mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 1); + ggml_set_f32(mask, 1.0f); + ggml_tensor* inactive = ggml_dup_tensor(work_ctx, control_video); + ggml_tensor* reactive = ggml_dup_tensor(work_ctx, control_video); + + ggml_tensor_iter(control_video, [&](ggml_tensor* t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float control_video_value = ggml_tensor_get_f32(t, i0, i1, i2, i3) - 0.5f; + float mask_value = ggml_tensor_get_f32(mask, i0, i1, i2, 0); + float inactive_value = (control_video_value * (1.f - mask_value)) + 0.5f; + float reactive_value = (control_video_value * mask_value) + 0.5f; + + ggml_tensor_set_f32(inactive, inactive_value, i0, i1, i2, i3); + ggml_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3); + }); + + inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/8, w/8] + reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/8, w/8] + + sd_ctx->sd->process_latent_in(inactive); + sd_ctx->sd->process_latent_in(reactive); + + int64_t length = inactive->ne[2]; + if (ref_image_latent) { + length += 1; + frames = (length - 1) * 4 + 1; + ref_image_num = 1; + } + vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], 
inactive->ne[1], length, 96); // [b*96, t, h/8, w/8] + ggml_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value; + if (i3 < 32) { + if (ref_image_latent && i2 == 0) { + value = ggml_tensor_get_f32(ref_image_latent, i0, i1, 0, i3); + } else { + if (i3 < 16) { + value = ggml_tensor_get_f32(inactive, i0, i1, i2 - ref_image_num, i3); + } else { + value = ggml_tensor_get_f32(reactive, i0, i1, i2 - ref_image_num, i3 - 16); + } + } + } else { // mask + if (ref_image_latent && i2 == 0) { + value = 0.f; + } else { + int64_t vae_stride = 8; + int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride; + int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride; + value = ggml_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0); + } + } + ggml_tensor_set_f32(vace_context, value, i0, i1, i2, i3); + }); int64_t t2 = ggml_time_ms(); LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1); } @@ -2650,7 +2739,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s -1, {}, {}, - denoise_mask); + false, + denoise_mask, + vace_context, + sd_vid_gen_params->vace_strength); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); @@ -2682,7 +2774,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s -1, {}, {}, - denoise_mask); + false, + denoise_mask, + vace_context, + sd_vid_gen_params->vace_strength); int64_t sampling_end = ggml_time_ms(); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); @@ -2691,6 +2786,20 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s } } + if (ref_image_num > 0) { + ggml_tensor* trim_latent = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + final_latent->ne[0], + final_latent->ne[1], + final_latent->ne[2] - ref_image_num, + final_latent->ne[3]); + ggml_tensor_iter(trim_latent, [&](ggml_tensor* trim_latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { + float value = ggml_tensor_get_f32(final_latent, i0, i1, i2 + ref_image_num, i3); + ggml_tensor_set_f32(trim_latent, value, i0, i1, i2, i3); + }); + final_latent = trim_latent; + } + int64_t t4 = ggml_time_ms(); LOG_INFO("generating latent video completed, taking %.2fs", (t4 - t2) * 1.0f / 1000); struct ggml_tensor* vid = sd_ctx->sd->decode_first_stage(work_ctx, final_latent, true); diff --git a/stable-diffusion.h b/stable-diffusion.h index 0f47a763..d1c3c717 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -35,7 +35,7 @@ enum rng_type_t { }; enum sample_method_t { - EULER_A, + SAMPLE_METHOD_DEFAULT, EULER, HEUN, DPM2, @@ -47,6 +47,7 @@ enum sample_method_t { LCM, DDIM_TRAILING, TCD, + EULER_A, SAMPLE_METHOD_COUNT }; @@ -113,6 +114,15 @@ enum sd_log_level_t { SD_LOG_ERROR }; +typedef struct { + bool enabled; + int tile_size_x; + int tile_size_y; + float target_overlap; + float rel_size_x; + float rel_size_y; +} sd_tiling_params_t; + typedef struct { const char* model_path; const char* clip_l_path; @@ -126,9 +136,8 @@ typedef struct { const char* control_net_path; const char* lora_model_dir; const char* embedding_dir; - const char* stacked_id_embed_dir; + const char* photo_maker_path; bool vae_decode_only; - bool vae_tiling; bool free_params_immediately; int n_threads; enum sd_type_t wtype; @@ -176,6 +185,13 @@ typedef struct { float eta; } sd_sample_params_t; 
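// A minimal caller-side sketch of the new sd_tiling_params_t controls declared
// above, assuming an already-created sd_ctx_t* (the name ctx, the prompt, and
// all field values here are illustrative, not library defaults):
#include "stable-diffusion.h"

sd_image_t* generate_tiled(sd_ctx_t* ctx) {
    sd_img_gen_params_t p;
    sd_img_gen_params_init(&p);
    p.prompt = "a photo of a cat";
    p.width  = 1024;
    p.height = 1024;
    // {enabled, tile_size_x, tile_size_y, target_overlap, rel_size_x, rel_size_y}
    // rel_size >= 1 is interpreted by get_tile_sizes() as a tile count per dimension,
    // so this requests a 3x3 tile grid with the default 0.5 overlap.
    p.vae_tiling_params = {true, 0, 0, 0.5f, 3.0f, 3.0f};
    return generate_image(ctx, &p);
}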
+typedef struct { + sd_image_t* id_images; + int id_images_count; + const char* id_embed_path; + float style_strength; +} sd_pm_params_t; // photo maker + typedef struct { const char* prompt; const char* negative_prompt; @@ -193,9 +209,9 @@ typedef struct { int batch_count; sd_image_t control_image; float control_strength; - float style_strength; bool normalize_input; - const char* input_id_images_path; + sd_pm_params_t pm_params; + sd_tiling_params_t vae_tiling_params; } sd_img_gen_params_t; typedef struct { @@ -204,6 +220,8 @@ typedef struct { int clip_skip; sd_image_t init_image; sd_image_t end_image; + sd_image_t* control_frames; + int control_frames_size; int width; int height; sd_sample_params_t sample_params; @@ -212,6 +230,7 @@ typedef struct { float strength; int64_t seed; int video_frames; + float vace_strength; } sd_vid_gen_params_t; typedef struct sd_ctx_t sd_ctx_t; @@ -238,6 +257,7 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); +SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx); SD_API void sd_sample_params_init(sd_sample_params_t* sample_params); SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params); @@ -267,14 +287,12 @@ SD_API bool convert(const char* input_path, enum sd_type_t output_type, const char* tensor_type_rules); -SD_API uint8_t* preprocess_canny(uint8_t* img, - int width, - int height, - float high_threshold, - float low_threshold, - float weak, - float strong, - bool inverse); +SD_API bool preprocess_canny(sd_image_t image, + float high_threshold, + float low_threshold, + float weak, + float strong, + bool inverse); #ifdef __cplusplus } diff --git a/upscaler.cpp b/upscaler.cpp index 2bd62c09..7e765d77 100644 --- a/upscaler.cpp +++ b/upscaler.cpp @@ -69,8 +69,7 @@ struct UpscalerGGML { input_image.width, input_image.height, output_width, output_height); struct ggml_init_params params; - params.mem_size = output_width * output_height * 3 * sizeof(float) * 2; - params.mem_size += 2 * ggml_tensor_overhead(); + params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; @@ -80,9 +79,9 @@ struct UpscalerGGML { LOG_ERROR("ggml_init() failed"); return upscaled_image; } - LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); + // LOG_DEBUG("upscale work buffer size: %.2f MB", params.mem_size / 1024.f / 1024.f); ggml_tensor* input_image_tensor = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, input_image.width, input_image.height, 3, 1); - sd_image_to_tensor(input_image.data, input_image_tensor); + sd_image_to_tensor(input_image, input_image_tensor); ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1); auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) { diff --git a/util.cpp b/util.cpp index b9142e60..5af6b1ec 100644 --- a/util.cpp +++ b/util.cpp @@ -110,56 +110,6 @@ std::string get_full_path(const std::string& dir, const std::string& filename) { } } -std::vector<std::string> get_files_from_dir(const std::string& dir) { - std::vector<std::string> files; - - WIN32_FIND_DATA findFileData; - HANDLE hFind; - - char currentDirectory[MAX_PATH]; - GetCurrentDirectory(MAX_PATH, currentDirectory); - - char directoryPath[MAX_PATH]; // this is absolute path - sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str()); - - // Find the first file in the directory -
hFind = FindFirstFile(directoryPath, &findFileData); - bool isAbsolutePath = false; - // Check if the directory was found - if (hFind == INVALID_HANDLE_VALUE) { - printf("Unable to find directory. Try with original path \n"); - - char directoryPathAbsolute[MAX_PATH]; - sprintf(directoryPathAbsolute, "%s*", dir.c_str()); - - hFind = FindFirstFile(directoryPathAbsolute, &findFileData); - isAbsolutePath = true; - if (hFind == INVALID_HANDLE_VALUE) { - printf("Absolute path was also wrong.\n"); - return files; - } - } - - // Loop through all files in the directory - do { - // Check if the found file is a regular file (not a directory) - if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { - if (isAbsolutePath) { - files.push_back(dir + "\\" + std::string(findFileData.cFileName)); - } else { - files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName)); - } - } - } while (FindNextFile(hFind, &findFileData) != 0); - - // Close the handle - FindClose(hFind); - - sort(files.begin(), files.end()); - - return files; -} - #else // Unix #include #include @@ -194,27 +144,6 @@ std::string get_full_path(const std::string& dir, const std::string& filename) { return ""; } -std::vector<std::string> get_files_from_dir(const std::string& dir) { - std::vector<std::string> files; - - DIR* dp = opendir(dir.c_str()); - - if (dp != nullptr) { - struct dirent* entry; - - while ((entry = readdir(dp)) != nullptr) { - std::string fname = dir + "/" + entry->d_name; - if (!is_directory(fname)) - files.push_back(fname); - } - closedir(dp); - } - - sort(files.begin(), files.end()); - - return files; -} - #endif // get_num_physical_cores is copy from @@ -318,39 +247,6 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) { return result; } -sd_image_t* preprocess_id_image(sd_image_t* img) { - int shortest_edge = 224; - int size = shortest_edge; - sd_image_t* resized = NULL; - uint32_t w = img->width; - uint32_t h = img->height; - uint32_t c = img->channel; - - // 1. do resize using stb_resize functions - - unsigned char* buf = (unsigned char*)malloc(sizeof(unsigned char) * 3 * size * size); - if (!stbir_resize_uint8(img->data, w, h, 0, - buf, size, size, 0, - c)) { - fprintf(stderr, "%s: resize operation failed \n ", __func__); - return resized; - } - - // 2. do center crop (likely unnecessary due to step 1) - - // 3. do rescale - - // 4. do normalize - - // 3 and 4 will need to be done in float format.
- - resized = new sd_image_t{(uint32_t)shortest_edge, - (uint32_t)shortest_edge, - 3, - buf}; - return resized; -} - void pretty_progress(int step, int steps, float time) { if (sd_progress_cb) { sd_progress_cb(step, steps, time, sd_progress_cb_data); diff --git a/util.h b/util.h index 89a990c8..1e8db6e3 100644 --- a/util.h +++ b/util.h @@ -24,14 +24,9 @@ bool file_exists(const std::string& filename); bool is_directory(const std::string& path); std::string get_full_path(const std::string& dir, const std::string& filename); -std::vector<std::string> get_files_from_dir(const std::string& dir); - std::u32string utf8_to_utf32(const std::string& utf8_str); std::string utf32_to_utf8(const std::u32string& utf32_str); std::u32string unicode_value_to_utf32(int unicode_value); - -sd_image_t* preprocess_id_image(sd_image_t* img); - // std::string sd_basename(const std::string& path); typedef struct { diff --git a/vae.hpp b/vae.hpp index 408d32d6..dd982ab7 100644 --- a/vae.hpp +++ b/vae.hpp @@ -588,7 +588,7 @@ struct AutoEncoderKL : public VAE { }; // ggml_set_f32(z, 0.5f); // print_ggml_tensor(z); - GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); } void test() { diff --git a/wan.hpp b/wan.hpp index 48603a95..7e3510a1 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1219,7 +1219,7 @@ namespace WAN { void test() { struct ggml_init_params params; - params.mem_size = static_cast<size_t>(1000 * 1024 * 1024); // 10 MB + params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G params.mem_buffer = NULL; params.no_alloc = false; @@ -1532,13 +1532,13 @@ namespace WAN { blocks["ffn.2"] = std::shared_ptr<GGMLBlock>(new Linear(ffn_dim, dim)); } - struct ggml_tensor* forward(struct ggml_context* ctx, - ggml_backend_t backend, - struct ggml_tensor* x, - struct ggml_tensor* e, - struct ggml_tensor* pe, - struct ggml_tensor* context, - int64_t context_img_len = 257) { + virtual struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* e, + struct ggml_tensor* pe, + struct ggml_tensor* context, + int64_t context_img_len = 257) { // x: [N, n_token, dim] // e: [N, 6, dim] or [N, T, 6, dim] // context: [N, context_img_len + context_txt_len, dim] @@ -1584,6 +1584,59 @@ namespace WAN { } }; + class VaceWanAttentionBlock : public WanAttentionBlock { + protected: + int block_id; + void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { + enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); + params["modulation"] = ggml_new_tensor_3d(ctx, wtype, dim, 6, 1); + } + + public: + VaceWanAttentionBlock(bool t2v_cross_attn, + int64_t dim, + int64_t ffn_dim, + int64_t num_heads, + bool qk_norm = true, + bool cross_attn_norm = false, + float eps = 1e-6, + int block_id = 0, + bool flash_attn = false) + : WanAttentionBlock(t2v_cross_attn, dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps, flash_attn), block_id(block_id) { + if (block_id == 0) { + blocks["before_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim)); + } + blocks["after_proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim)); + } + + std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* c, + struct ggml_tensor* x, + struct ggml_tensor* e, + struct ggml_tensor* pe, + struct ggml_tensor* context, + int64_t context_img_len = 257) { + // x: [N, n_token, dim] + // e: [N, 6, dim] or [N, T, 6, dim] + // context: [N, context_img_len +
context_txt_len, dim] + // return [N, n_token, dim] + if (block_id == 0) { + auto before_proj = std::dynamic_pointer_cast<Linear>(blocks["before_proj"]); + + c = before_proj->forward(ctx, c); + c = ggml_add(ctx, c, x); + } + + auto after_proj = std::dynamic_pointer_cast<Linear>(blocks["after_proj"]); + + c = WanAttentionBlock::forward(ctx, backend, c, e, pe, context, context_img_len); + auto c_skip = after_proj->forward(ctx, c); + + return {c_skip, c}; + } + }; + class Head : public GGMLBlock { protected: int dim; @@ -1680,22 +1733,25 @@ namespace WAN { }; struct WanParams { - std::string model_type = "t2v"; - std::tuple<int, int, int> patch_size = {1, 2, 2}; - int64_t text_len = 512; - int64_t in_dim = 16; - int64_t dim = 2048; - int64_t ffn_dim = 8192; - int64_t freq_dim = 256; - int64_t text_dim = 4096; - int64_t out_dim = 16; - int64_t num_heads = 16; - int64_t num_layers = 32; - bool qk_norm = true; - bool cross_attn_norm = true; - float eps = 1e-6; - int64_t flf_pos_embed_token_number = 0; - int theta = 10000; + std::string model_type = "t2v"; + std::tuple<int, int, int> patch_size = {1, 2, 2}; + int64_t text_len = 512; + int64_t in_dim = 16; + int64_t dim = 2048; + int64_t ffn_dim = 8192; + int64_t freq_dim = 256; + int64_t text_dim = 4096; + int64_t out_dim = 16; + int64_t num_heads = 16; + int64_t num_layers = 32; + int64_t vace_layers = 0; + int64_t vace_in_dim = 96; + std::map<int, int> vace_layers_mapping = {}; + bool qk_norm = true; + bool cross_attn_norm = true; + float eps = 1e-6; + int64_t flf_pos_embed_token_number = 0; + int theta = 10000; // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24 std::vector<int> axes_dim = {44, 42, 42}; int64_t axes_dim_sum = 128; @@ -1746,6 +1802,31 @@ namespace WAN { if (params.model_type == "i2v") { blocks["img_emb"] = std::shared_ptr<GGMLBlock>(new MLPProj(1280, params.dim, params.flf_pos_embed_token_number)); } + + // vace + if (params.vace_layers > 0) { + for (int i = 0; i < params.vace_layers; i++) { + auto block = std::shared_ptr<GGMLBlock>(new VaceWanAttentionBlock(params.model_type == "t2v", + params.dim, + params.ffn_dim, + params.num_heads, + params.qk_norm, + params.cross_attn_norm, + params.eps, + i, + params.flash_attn)); + blocks["vace_blocks."
+ std::to_string(i)] = block; + } + + int step = params.num_layers / params.vace_layers; + int n = 0; + for (int i = 0; i < params.num_layers; i += step) { + this->params.vace_layers_mapping[i] = n; + n++; + } + + blocks["vace_patch_embedding"] = std::shared_ptr<GGMLBlock>(new Conv3d(params.vace_in_dim, params.dim, params.patch_size, params.patch_size)); + } } struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx, @@ -1795,9 +1876,12 @@ namespace WAN { struct ggml_tensor* timestep, struct ggml_tensor* context, struct ggml_tensor* pe, - struct ggml_tensor* clip_fea = NULL, - int64_t N = 1) { + struct ggml_tensor* clip_fea = NULL, + struct ggml_tensor* vace_context = NULL, + float vace_strength = 1.f, + int64_t N = 1) { // x: [N*C, T, H, W], C => in_dim + // vace_context: [N*vace_in_dim, T, H, W] // timestep: [N,] or [T] // context: [N, L, text_dim] // return: [N, t_len*h_len*w_len, out_dim*pt*ph*pw] @@ -1845,10 +1929,35 @@ namespace WAN { context_img_len = clip_fea->ne[1]; // 257 } + // vace_patch_embedding + ggml_tensor* c = NULL; + if (params.vace_layers > 0) { + auto vace_patch_embedding = std::dynamic_pointer_cast<Conv3d>(blocks["vace_patch_embedding"]); + + c = vace_patch_embedding->forward(ctx, vace_context); // [N*dim, t_len, h_len, w_len] + c = ggml_reshape_3d(ctx, c, c->ne[0] * c->ne[1] * c->ne[2], c->ne[3] / N, N); // [N, dim, t_len*h_len*w_len] + c = ggml_nn_cont(ctx, ggml_torch_permute(ctx, c, 1, 0, 2, 3)); // [N, t_len*h_len*w_len, dim] + } + + auto x_orig = x; + for (int i = 0; i < params.num_layers; i++) { auto block = std::dynamic_pointer_cast<WanAttentionBlock>(blocks["blocks." + std::to_string(i)]); x = block->forward(ctx, backend, x, e0, pe, context, context_img_len); + + auto iter = params.vace_layers_mapping.find(i); + if (iter != params.vace_layers_mapping.end()) { + int n = iter->second; + + auto vace_block = std::dynamic_pointer_cast<VaceWanAttentionBlock>(blocks["vace_blocks." + std::to_string(n)]); + + auto result = vace_block->forward(ctx, backend, c, x_orig, e0, pe, context, context_img_len); + auto c_skip = result.first; + c = result.second; + c_skip = ggml_scale(ctx, c_skip, vace_strength); + x = ggml_add(ctx, x, c_skip); + } } x = head->forward(ctx, x, e); // [N, t_len*h_len*w_len, pt*ph*pw*out_dim] @@ -1864,6 +1973,8 @@ namespace WAN { struct ggml_tensor* pe, struct ggml_tensor* clip_fea = NULL, struct ggml_tensor* time_dim_concat = NULL, + struct ggml_tensor* vace_context = NULL, + float vace_strength = 1.f, int64_t N = 1) { // Forward pass of DiT.
// x: [N*C, T, H, W] @@ -1892,7 +2003,7 @@ namespace WAN { t_len = ((x->ne[2] + (std::get<0>(params.patch_size) / 2)) / std::get<0>(params.patch_size)); } - auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, N); // [N, t_len*h_len*w_len, pt*ph*pw*C] + auto out = forward_orig(ctx, backend, x, timestep, context, pe, clip_fea, vace_context, vace_strength, N); // [N, t_len*h_len*w_len, pt*ph*pw*C] out = unpatchify(ctx, out, t_len, h_len, w_len); // [N*C, (T+pad_t) + (T2+pad_t2), H + pad_h, W + pad_w] @@ -1927,7 +2038,19 @@ namespace WAN { std::string tensor_name = pair.first; if (tensor_name.find(prefix) == std::string::npos) continue; - size_t pos = tensor_name.find("blocks."); + size_t pos = tensor_name.find("vace_blocks."); + if (pos != std::string::npos) { + tensor_name = tensor_name.substr(pos); // remove prefix + auto items = split_string(tensor_name, '.'); + if (items.size() > 1) { + int block_index = atoi(items[1].c_str()); + if (block_index + 1 > wan_params.vace_layers) { + wan_params.vace_layers = block_index + 1; + } + } + continue; + } + pos = tensor_name.find("blocks."); if (pos != std::string::npos) { tensor_name = tensor_name.substr(pos); // remove prefix auto items = split_string(tensor_name, '.'); @@ -1937,6 +2060,7 @@ namespace WAN { wan_params.num_layers = block_index + 1; } } + continue; } if (tensor_name.find("img_emb") != std::string::npos) { wan_params.model_type = "i2v"; @@ -1958,7 +2082,11 @@ namespace WAN { wan_params.out_dim = 48; wan_params.text_len = 512; } else { - desc = "Wan2.1-T2V-1.3B"; + if (wan_params.vace_layers > 0) { + desc = "Wan2.1-VACE-1.3B"; + } else { + desc = "Wan2.1-T2V-1.3B"; + } wan_params.dim = 1536; wan_params.eps = 1e-06; wan_params.ffn_dim = 8960; @@ -1974,7 +2102,11 @@ namespace WAN { desc = "Wan2.2-I2V-14B"; wan_params.in_dim = 36; } else { - desc = "Wan2.x-T2V-14B"; + if (wan_params.vace_layers > 0) { + desc = "Wan2.x-VACE-14B"; + } else { + desc = "Wan2.x-T2V-14B"; + } wan_params.in_dim = 16; } } else { @@ -2015,7 +2147,9 @@ namespace WAN { struct ggml_tensor* context, struct ggml_tensor* clip_fea = NULL, struct ggml_tensor* c_concat = NULL, - struct ggml_tensor* time_dim_concat = NULL) { + struct ggml_tensor* time_dim_concat = NULL, + struct ggml_tensor* vace_context = NULL, + float vace_strength = 1.f) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, WAN_GRAPH_SIZE, false); x = to_backend(x); @@ -2024,6 +2158,7 @@ namespace WAN { clip_fea = to_backend(clip_fea); c_concat = to_backend(c_concat); time_dim_concat = to_backend(time_dim_concat); + vace_context = to_backend(vace_context); pe_vec = Rope::gen_wan_pe(x->ne[2], x->ne[1], @@ -2053,7 +2188,9 @@ namespace WAN { context, pe, clip_fea, - time_dim_concat); + time_dim_concat, + vace_context, + vace_strength); ggml_build_forward_expand(gf, out); @@ -2067,10 +2204,12 @@ namespace WAN { struct ggml_tensor* clip_fea = NULL, struct ggml_tensor* c_concat = NULL, struct ggml_tensor* time_dim_concat = NULL, + struct ggml_tensor* vace_context = NULL, + float vace_strength = 1.f, struct ggml_tensor** output = NULL, struct ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat); + return build_graph(x, timesteps, context, clip_fea, c_concat, time_dim_concat, vace_context, vace_strength); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -2108,7 +2247,7 @@ namespace WAN { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - 
compute(8, x, timesteps, context, NULL, NULL, NULL, &out, work_ctx); + compute(8, x, timesteps, context, NULL, NULL, NULL, NULL, 1.f, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out);
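// The recurring mechanical change in the stable-diffusion.cpp hunks above packs
// DiffusionModel::compute()'s long positional argument list into a single
// DiffusionParams struct. A condensed, hedged sketch of the new call shape,
// using only fields that appear in this diff (the tensor variables stand in
// for values the caller has already prepared):
DiffusionParams diffusion_params;
diffusion_params.x         = noised_input;     // latent being denoised
diffusion_params.timesteps = timesteps;        // current timestep tensor
diffusion_params.context   = cond.c_crossattn; // text conditioning
diffusion_params.c_concat  = cond.c_concat;
diffusion_params.y         = cond.c_vector;
diffusion_params.vace_context  = vace_context; // optional; unset fields keep their defaults
diffusion_params.vace_strength = vace_strength;
diffusion_model->compute(n_threads, diffusion_params, &out_cond);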