From 766b531dcb86331a804fa450e6dfbee9702ca8e3 Mon Sep 17 00:00:00 2001
From: leejet
Date: Thu, 11 Sep 2025 01:48:52 +0800
Subject: [PATCH] add sd3 flash attn support

---
 diffusion_model.hpp  |  3 ++-
 mmdit.hpp            | 53 +++++++++++++++++++++++++++-----------------
 stable-diffusion.cpp |  1 +
 3 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 89f31b13..138ca67a 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -95,8 +95,9 @@ struct MMDiTModel : public DiffusionModel {
 
     MMDiTModel(ggml_backend_t backend,
                bool offload_params_to_cpu,
+               bool flash_attn                     = false,
                const String2GGMLType& tensor_types = {})
-        : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
+        : mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
     }
 
     std::string get_desc() {
diff --git a/mmdit.hpp b/mmdit.hpp
index acb55e60..d9d19340 100644
--- a/mmdit.hpp
+++ b/mmdit.hpp
@@ -147,14 +147,16 @@ class SelfAttention : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     std::string qk_norm;
+    bool flash_attn;
 
 public:
     SelfAttention(int64_t dim,
                   int64_t num_heads   = 8,
                   std::string qk_norm = "",
                   bool qkv_bias       = false,
-                  bool pre_only       = false)
-        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
+                  bool pre_only       = false,
+                  bool flash_attn     = false)
+        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
         int64_t d_head = dim / num_heads;
         blocks["qkv"]  = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
         if (!pre_only) {
@@ -206,8 +208,8 @@ class SelfAttention : public GGMLBlock {
                                 ggml_backend_t backend,
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x        = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
-        x        = post_attention(ctx, x);                                                  // [N, n_token, dim]
+        x        = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, true);  // [N, n_token, dim]
+        x        = post_attention(ctx, x);                                                                            // [N, n_token, dim]
         return x;
     }
 };
@@ -232,6 +234,7 @@ struct DismantledBlock : public GGMLBlock {
     int64_t num_heads;
     bool pre_only;
     bool self_attn;
+    bool flash_attn;
 
 public:
     DismantledBlock(int64_t hidden_size,
@@ -240,16 +243,17 @@ struct DismantledBlock : public GGMLBlock {
                     std::string qk_norm = "",
                     bool qkv_bias       = false,
                     bool pre_only       = false,
-                    bool self_attn      = false)
+                    bool self_attn      = false,
+                    bool flash_attn     = false)
         : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always Flase
         // scale_mod_only is always Flase
         // swiglu is always Flase
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-        blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
+        blocks["attn"]  = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
         if (self_attn) {
-            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
         }
 
         if (!pre_only) {
@@ -435,8 +439,8 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv2          = std::get<1>(qkv_intermediates);
             auto intermediates = std::get<2>(qkv_intermediates);
 
-            auto attn_out  = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);     // [N, n_token, dim]
-            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads);  // [N, n_token, dim]
+            auto attn_out  = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn);     // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
             x              = post_attention_x(ctx,
                                               attn_out,
                                               attn2_out,
@@ -452,7 +456,7 @@ struct DismantledBlock : public GGMLBlock {
             auto qkv           = qkv_intermediates.first;
             auto intermediates = qkv_intermediates.second;
 
-            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads);  // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
             x             = post_attention(ctx,
                                            attn_out,
                                            intermediates[0],
@@ -468,6 +472,7 @@ struct DismantledBlock : public GGMLBlock {
 __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
 block_mixing(struct ggml_context* ctx,
              ggml_backend_t backend,
+             bool flash_attn,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -497,8 +502,8 @@ block_mixing(struct ggml_context* ctx,
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
     }
 
-    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads);  // [N, n_context + n_token, hidden_size]
-    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                              // [n_context + n_token, N, hidden_size]
+    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn);  // [N, n_context + n_token, hidden_size]
+    attn      = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));                                                              // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx,
                                      attn,
                                      attn->ne[0],
@@ -556,6 +561,8 @@ block_mixing(struct ggml_context* ctx,
 }
 
 struct JointBlock : public GGMLBlock {
+    bool flash_attn;
+
 public:
     JointBlock(int64_t hidden_size,
                int64_t num_heads,
@@ -563,9 +570,11 @@ struct JointBlock : public GGMLBlock {
                std::string qk_norm = "",
                bool qkv_bias       = false,
                bool pre_only       = false,
-               bool self_attn_x    = false) {
-        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
+               bool self_attn_x    = false,
+               bool flash_attn     = false)
+        : flash_attn(flash_attn) {
+        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
+        blocks["x_block"]       = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
     }
 
     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@@ -576,7 +585,7 @@ struct JointBlock : public GGMLBlock {
         auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
         auto x_block       = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
 
-        return block_mixing(ctx, backend, context, x, c, context_block, x_block);
+        return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
     }
 };
 
@@ -634,6 +643,7 @@ struct MMDiT : public GGMLBlock {
     int64_t context_embedder_out_dim = 1536;
     int64_t hidden_size;
     std::string qk_norm;
+    bool flash_attn = false;
 
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F32;
@@ -641,7 +651,8 @@ struct MMDiT : public GGMLBlock {
     }
 
 public:
-    MMDiT(const String2GGMLType& tensor_types = {}) {
+    MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
+        : flash_attn(flash_attn) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -709,7 +720,8 @@ struct MMDiT : public GGMLBlock {
                                                                           qk_norm,
                                                                           true,
                                                                           i == depth - 1,
-                                                                          i <= d_self));
+                                                                          i <= d_self,
+                                                                          flash_attn));
         }
 
         blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {
 
     MMDiTRunner(ggml_backend_t backend,
                 bool offload_params_to_cpu,
+                bool flash_attn,
                 const String2GGMLType& tensor_types = {},
                 const std::string prefix            = "")
-        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }
 
@@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
     // ggml_backend_t backend = ggml_backend_cuda_init(0);
     ggml_backend_t backend    = ggml_backend_cpu_init();
     ggml_type model_data_type = GGML_TYPE_F16;
-    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
+    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
     {
         LOG_INFO("loading from '%s'", file_path.c_str());
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 08894731..51dc0b8a 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -349,6 +349,7 @@ class StableDiffusionGGML {
                                                             model_loader.tensor_storages_types);
             diffusion_model = std::make_shared<MMDiTModel>(backend,
                                                            offload_params_to_cpu,
+                                                           sd_ctx_params->diffusion_flash_attn,
                                                            model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;
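
Usage note (not part of the patch): a minimal sketch of how the new flag is passed when the runner is constructed directly, mirroring the MMDiTRunner signature the patch introduces. The CPU backend and the hard-coded flag value below are illustrative assumptions, not something the patch prescribes.

    // Sketch only: build an MMDiT runner with the flash-attention path enabled.
    // Assumes mmdit.hpp (and its ggml headers) are included; the signature
    // MMDiTRunner(backend, offload_params_to_cpu, flash_attn, tensor_types, prefix)
    // is the one added by this patch.
    ggml_backend_t backend = ggml_backend_cpu_init();
    std::shared_ptr<MMDiTRunner> mmdit(new MMDiTRunner(backend,
                                                       /*offload_params_to_cpu=*/false,
                                                       /*flash_attn=*/true));

Through the public API the same switch comes from sd_ctx_params->diffusion_flash_attn (see the stable-diffusion.cpp hunk): StableDiffusionGGML forwards it into MMDiTModel, which hands it to MMDiTRunner and on to every JointBlock, DismantledBlock, and SelfAttention, where it is passed as the last argument of ggml_nn_attention_ext to select the flash-attention path.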