From 72113b1a993a11009d47a7ad1a2af3643ee4fef0 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 24 Jan 2026 22:25:55 +0800 Subject: [PATCH 1/7] make flux faster --- flux.hpp | 70 ++++++++++++++++++++------------------------------------ 1 file changed, 25 insertions(+), 45 deletions(-) diff --git a/flux.hpp b/flux.hpp index 77a65c557..f37760a78 100644 --- a/flux.hpp +++ b/flux.hpp @@ -103,7 +103,7 @@ namespace Flux { auto norm = std::dynamic_pointer_cast(blocks["norm"]); auto qkv = qkv_proj->forward(ctx, x); - auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); + auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true); int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); @@ -376,26 +376,23 @@ namespace Flux { auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_txt_token + n_img_token, n_head*d_head] auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], txt->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - 0); // [n_txt_token, N, hidden_size] - txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + 0); // [N, n_txt_token, hidden_size] auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], img->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + txt->ne[1] * attn->nb[1]); // [N, n_img_token, hidden_size] // calculate the img bloks img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate)); @@ -492,37 +489,23 @@ namespace Flux { } auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); - auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] - qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] - - auto qkv = ggml_view_3d(ctx->ggml_ctx, - qkv_mlp, - qkv_mlp->ne[0], - qkv_mlp->ne[1], - hidden_size * 3, - qkv_mlp->nb[1], - qkv_mlp->nb[2], - 0); // [hidden_size * 3 , N, n_token] - qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3)); // [N, n_token, hidden_size * 3] - auto mlp = ggml_view_3d(ctx->ggml_ctx, - qkv_mlp, - qkv_mlp->ne[0], - qkv_mlp->ne[1], - mlp_hidden_dim * mlp_mult_factor, - qkv_mlp->nb[1], - qkv_mlp->nb[2], - qkv_mlp->nb[2] * hidden_size * 3); // [mlp_hidden_dim*mlp_mult_factor , N, n_token] - mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3)); // [N, n_token, mlp_hidden_dim*mlp_mult_factor] - - auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv); // q,k,v: [N, n_token, hidden_size] + auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, 
hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor] + + auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0); + auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]); + auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]); + int64_t head_dim = hidden_size / num_heads; - auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] - auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] - auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] - q = norm->query_norm(ctx, q); - k = norm->key_norm(ctx, k); - auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size] + q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]); // [N, n_token, n_head, d_head] + k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]); // [N, n_token, n_head, d_head] + v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]); // [N, n_token, n_head, d_head] + + q = norm->query_norm(ctx, q); + k = norm->key_norm(ctx, k); + auto attn = Rope::attention(ctx, q, k, v, pe, mask); // [N, n_token, hidden_size] + + auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]); if (use_yak_mlp) { mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false); } else if (use_mlp_silu_act) { @@ -580,13 +563,10 @@ namespace Flux { } else { auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] + auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0); + shift = m_vec[0]; // [N, hidden_size] + scale = m_vec[1]; // [N, hidden_size] } x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); From c7d4a6035de8f20264a4675edc475db481ef2ccc Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 24 Jan 2026 22:58:33 +0800 Subject: [PATCH 2/7] make flux a litter faster --- flux.hpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/flux.hpp b/flux.hpp index f37760a78..b2e742282 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1014,16 +1014,14 @@ namespace Flux { txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods); } - txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] 
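The recurring change in PATCH 1 and PATCH 2 is to replace the permute -> ggml_cont -> view -> permute -> ggml_cont chain with a single strided ggml_view_3d over the token dimension (ne[1]), so each half of the joint attention output becomes a zero-copy view instead of being materialized through extra copies. A minimal sketch of the pattern, assuming the usual ggml layout [hidden_size, n_token, N]; the helper name view_tokens is ours for illustration and is not part of the patch:

    // Select tokens [start, start + len) along ne[1] without moving any data:
    // only the sizes and the byte offset change, the strides stay those of x.
    static struct ggml_tensor* view_tokens(struct ggml_context* ctx,
                                           struct ggml_tensor* x,
                                           int64_t start,
                                           int64_t len) {
        return ggml_view_3d(ctx, x,
                            x->ne[0], len, x->ne[2],
                            x->nb[1], x->nb[2],
                            start * x->nb[1]);
    }

    // Usage matching the hunks above:
    //   auto txt_attn_out = view_tokens(ctx->ggml_ctx, attn, 0,          txt->ne[1]);
    //   auto img_attn_out = view_tokens(ctx->ggml_ctx, attn, txt->ne[1], img->ne[1]);

The resulting views are not contiguous, so operators that require packed input still need an explicit ggml_cont() right before use, which is exactly what PATCH 4 adds in qwen_image.hpp.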
img = ggml_view_3d(ctx->ggml_ctx, txt_img, txt_img->ne[0], - txt_img->ne[1], img->ne[1], + txt_img->ne[2], txt_img->nb[1], txt_img->nb[2], - txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size] if (final_layer) { img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) @@ -1176,9 +1174,8 @@ namespace Flux { auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size] if (out->ne[1] > img_tokens) { - out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] - out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); - out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx->ggml_ctx, out); } // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) From 6f4b49239c7be4d414383845cfa5aaca9e79aef2 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 24 Jan 2026 23:46:54 +0800 Subject: [PATCH 3/7] make z-image a litter faster --- flux.hpp | 16 ++++++++-------- ggml_extend.hpp | 41 +++++++++-------------------------------- z_image.hpp | 38 ++++++++++++++++++++++++++++++-------- 3 files changed, 47 insertions(+), 48 deletions(-) diff --git a/flux.hpp b/flux.hpp index b2e742282..83a4a2242 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1014,14 +1014,14 @@ namespace Flux { txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods); } - img = ggml_view_3d(ctx->ggml_ctx, - txt_img, - txt_img->ne[0], - img->ne[1], - txt_img->ne[2], - txt_img->nb[1], - txt_img->nb[2], - txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size] + img = ggml_view_3d(ctx->ggml_ctx, + txt_img, + txt_img->ne[0], + img->ne[1], + txt_img->ne[2], + txt_img->nb[1], + txt_img->nb[2], + txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size] if (final_layer) { img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 9d5ea316b..692ba857d 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -687,42 +687,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx, struct ggml_tensor* x, int dim, int64_t start, - int64_t end) { + int64_t end, + bool cont = true) { GGML_ASSERT(dim >= 0 && dim < 4); - if (x->ne[dim] == 1) { - return x; - } - while (start < 0) { - start = x->ne[dim] + start; - } - while (end < 0) { - end = x->ne[dim] + end; - } - GGML_ASSERT(end > start); - GGML_ASSERT(start >= 0 && start < x->ne[dim]); - GGML_ASSERT(end > start && end <= x->ne[dim]); - int perm[4] = {0, 1, 2, 3}; - for (int i = dim; i < 3; ++i) - perm[i] = perm[i + 1]; - perm[3] = dim; - - int inv_perm[4]; - for (int i = 0; i < 4; ++i) - inv_perm[perm[i]] = i; - - if (dim != 3) { - x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]); - x = ggml_cont(ctx, x); - } + int64_t slice_size = end - start; + int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]}; + slice_ne[dim] = slice_size; - x = ggml_view_4d( - ctx, x, - x->ne[0], x->ne[1], x->ne[2], end - start, - x->nb[1], x->nb[2], 
x->nb[3], x->nb[3] * start); + x = ggml_view_4d(ctx, x, + slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3], + x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]); - if (dim != 3) { - x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); + if (cont) { x = ggml_cont(ctx, x); } diff --git a/z_image.hpp b/z_image.hpp index 0abc78320..505fa7e8a 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -54,15 +54,37 @@ namespace ZImage { auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim] qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim] - qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim] - auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim] - auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim] - auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim] - - q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim] - k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] - v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] + auto q = ggml_view_4d(ctx->ggml_ctx, + qkv, + qkv->ne[0], + num_heads, + qkv->ne[2], + qkv->ne[3], + qkv->nb[1], + qkv->nb[2], + qkv->nb[3], + 0); // [N, n_token, num_heads, head_dim] + auto k = ggml_view_4d(ctx->ggml_ctx, + qkv, + qkv->ne[0], + num_kv_heads, + qkv->ne[2], + qkv->ne[3], + qkv->nb[1], + qkv->nb[2], + qkv->nb[3], + num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim] + auto v = ggml_view_4d(ctx->ggml_ctx, + qkv, + qkv->ne[0], + num_kv_heads, + qkv->ne[2], + qkv->ne[3], + qkv->nb[1], + qkv->nb[2], + qkv->nb[3], + (num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim] if (qk_norm) { auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); From e2600bd442bf6c588b5e03e15dcc3df3e32f3bf0 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 25 Jan 2026 00:25:06 +0800 Subject: [PATCH 4/7] make qwen image a litter faster --- ggml_extend.hpp | 15 +++++++++++++++ qwen_image.hpp | 15 +++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 692ba857d..f61f7c6c7 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -690,6 +690,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx, int64_t end, bool cont = true) { GGML_ASSERT(dim >= 0 && dim < 4); + if (x->ne[dim] == 1) { + return x; + } + while (start < 0) { + start = x->ne[dim] + start; + } + while (end < 0) { + end = x->ne[dim] + end; + } + GGML_ASSERT(end > start); + GGML_ASSERT(start >= 0 && start < x->ne[dim]); + GGML_ASSERT(end > start && end <= x->ne[dim]); int64_t slice_size = end - start; int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]}; @@ -944,6 +956,9 @@ __STATIC_INLINE__ struct ggml_tensor* 
ggml_ext_linear(struct ggml_context* ctx, bool force_prec_f32 = false, float scale = 1.f) { if (scale != 1.f) { + if (!ggml_is_contiguous(x)) { + x = ggml_cont(ctx, x); + } x = ggml_scale(ctx, x, scale); } if (x->ne[2] * x->ne[3] > 1024) { diff --git a/qwen_image.hpp b/qwen_image.hpp index ec2231b01..dfa539788 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -162,26 +162,25 @@ namespace Qwen { auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] - attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] + auto attn = Rope::attention(ctx, q, k, v, pe, mask, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], txt->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - 0); // [n_txt_token, N, hidden_size] - txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3)); // [N, n_txt_token, hidden_size] + 0); // [N, n_txt_token, n_head*d_head] auto img_attn_out = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], img->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - attn->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size] - img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] + txt->ne[1] * attn->nb[1]); // [N, n_img_token, n_head*d_head] + img_attn_out = ggml_cont(ctx->ggml_ctx, img_attn_out); + txt_attn_out = ggml_cont(ctx->ggml_ctx, txt_attn_out); img_attn_out = to_out_0->forward(ctx, img_attn_out); txt_attn_out = to_add_out->forward(ctx, txt_attn_out); From 10fe4b094a11d350d4ebd256a909d1e909ffc6a1 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 25 Jan 2026 13:52:24 +0800 Subject: [PATCH 5/7] automatically make the parameters of some unary ops contiguous --- clip.hpp | 4 ++-- common.hpp | 8 +++---- esrgan.hpp | 4 ++-- flux.hpp | 4 ++-- ggml_extend.hpp | 64 +++++++++++++++++++++++++++++++++++++++---------- llm.hpp | 2 +- lora.hpp | 10 ++++---- mmdit.hpp | 2 +- pmid.hpp | 6 ++--- t5.hpp | 4 ++-- tae.hpp | 13 +++++----- unet.hpp | 4 ++-- vae.hpp | 4 ++-- wan.hpp | 8 +++---- z_image.hpp | 2 +- 15 files changed, 90 insertions(+), 49 deletions(-) diff --git a/clip.hpp b/clip.hpp index 7a6ebe9e7..4eec0241c 100644 --- a/clip.hpp +++ b/clip.hpp @@ -479,9 +479,9 @@ struct CLIPMLP : public GGMLBlock { x = fc1->forward(ctx, x); if (use_gelu) { - x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); } else { - x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true); } x = fc2->forward(ctx, x); return x; diff --git a/common.hpp b/common.hpp index 13ab1038e..7183eb82e 100644 --- a/common.hpp +++ b/common.hpp @@ -200,7 +200,7 @@ class GEGLU : public UnaryBlock { gate = ggml_cont(ctx->ggml_ctx, gate); - gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); + gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true); x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out] @@ -220,7 +220,7 @@ class GELU : public UnaryBlock { auto proj = std::dynamic_pointer_cast(blocks["proj"]); x = proj->forward(ctx, x); - x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); 
return x; } }; @@ -536,8 +536,8 @@ class AlphaBlender : public GGMLBlock { // image_only_indicator is always tensor([0.]) float alpha = get_alpha(); auto x = ggml_add(ctx->ggml_ctx, - ggml_scale(ctx->ggml_ctx, x_spatial, alpha), - ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha)); + ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha), + ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha)); return x; } }; diff --git a/esrgan.hpp b/esrgan.hpp index 961e84f89..f740c2bc4 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -51,7 +51,7 @@ class ResidualDenseBlock : public GGMLBlock { x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2); auto x5 = conv5->forward(ctx, x_cat); - x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x); + x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x); return x5; } }; @@ -76,7 +76,7 @@ class RRDB : public GGMLBlock { out = rdb2->forward(ctx, out); out = rdb3->forward(ctx, out); - out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x); + out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x); return out; } }; diff --git a/flux.hpp b/flux.hpp index 83a4a2242..ff8c18997 100644 --- a/flux.hpp +++ b/flux.hpp @@ -153,7 +153,7 @@ namespace Flux { if (use_mlp_silu_act) { x = ggml_ext_silu_act(ctx->ggml_ctx, x); } else { - x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); } x = mlp_2->forward(ctx, x); return x; @@ -511,7 +511,7 @@ namespace Flux { } else if (use_mlp_silu_act) { mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp); } else { - mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp); + mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true); } auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] diff --git a/ggml_extend.hpp b/ggml_extend.hpp index f61f7c6c7..fedab3809 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -949,6 +949,49 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_group_norm_32(struct ggml_context return ggml_group_norm(ctx, a, 32, eps); } +__STATIC_INLINE__ struct ggml_tensor* ggml_ext_scale(struct ggml_context* ctx, + struct ggml_tensor* x, + float factor, + bool inplace = false) { + if (!ggml_is_contiguous(x)) { + x = ggml_cont(ctx, x); + } + if (inplace) { + x = ggml_scale_inplace(ctx, x, factor); + } else { + x = ggml_scale(ctx, x, factor); + } + return x; +} + +__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu(struct ggml_context* ctx, + struct ggml_tensor* x, + bool inplace = false) { + if (!ggml_is_contiguous(x)) { + x = ggml_cont(ctx, x); + } + if (inplace) { + x = ggml_gelu_inplace(ctx, x); + } else { + x = ggml_gelu(ctx, x); + } + return x; +} + +__STATIC_INLINE__ struct ggml_tensor* ggml_ext_gelu_quick(struct ggml_context* ctx, + struct ggml_tensor* x, + bool inplace = false) { + if (!ggml_is_contiguous(x)) { + x = ggml_cont(ctx, x); + } + if (inplace) { + x = ggml_gelu_quick_inplace(ctx, x); + } else { + x = ggml_gelu_quick(ctx, x); + } + return x; +} + __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* w, @@ -956,10 +999,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, bool force_prec_f32 = false, float scale = 1.f) { if (scale != 1.f) { - if (!ggml_is_contiguous(x)) { - x = ggml_cont(ctx, x); - } - x = ggml_scale(ctx, x, scale); + x = ggml_ext_scale(ctx, x, scale); } if (x->ne[2] * x->ne[3] > 1024) { // workaround: avoid 
ggml cuda error @@ -978,7 +1018,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_linear(struct ggml_context* ctx, } } if (scale != 1.f) { - x = ggml_scale(ctx, x, 1.f / scale); + x = ggml_ext_scale(ctx, x, 1.f / scale); } if (b != nullptr) { x = ggml_add_inplace(ctx, x, b); @@ -1047,7 +1087,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx, bool circular_y = false, float scale = 1.f) { if (scale != 1.f) { - x = ggml_scale(ctx, x, scale); + x = ggml_ext_scale(ctx, x, scale); } if (w->ne[2] != x->ne[2] && ggml_n_dims(w) == 2) { w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], w->ne[1]); @@ -1065,7 +1105,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_conv_2d(struct ggml_context* ctx, x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1); } if (scale != 1.f) { - x = ggml_scale(ctx, x, 1.f / scale); + x = ggml_ext_scale(ctx, x, 1.f / scale); } if (b != nullptr) { b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); @@ -1163,7 +1203,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_full(struct ggml_context* ctx, int64_t ne2, int64_t ne3) { auto one = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:one"); - auto t = ggml_scale(ctx, one, value); // [1,] + auto t = ggml_ext_scale(ctx, one, value); // [1,] t = ggml_repeat_4d(ctx, t, ne0, ne1, ne2, ne3); // [ne0, ne1, ne2, ne3] return t; } @@ -1263,7 +1303,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context k_in = ggml_pad(ctx, k_in, 0, kv_pad, 0, 0); } if (kv_scale != 1.0f) { - k_in = ggml_scale(ctx, k_in, kv_scale); + k_in = ggml_ext_scale(ctx, k_in, kv_scale); } k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16); @@ -1273,7 +1313,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0); } if (kv_scale != 1.0f) { - v_in = ggml_scale(ctx, v_in, kv_scale); + v_in = ggml_ext_scale(ctx, v_in, kv_scale); } v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16); @@ -1305,7 +1345,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention_ext(struct ggml_context auto out = ggml_flash_attn_ext(ctx, q_in, k_in, v_in, mask_in, scale / kv_scale, 0, 0); ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32); if (kv_scale != 1.0f) { - out = ggml_scale(ctx, out, 1.0f / kv_scale); + out = ggml_ext_scale(ctx, out, 1.0f / kv_scale); } return out; }; @@ -1515,7 +1555,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_timestep_embedding( int dim, int max_period = 10000, float time_factor = 1.0f) { - timesteps = ggml_scale(ctx, timesteps, time_factor); + timesteps = ggml_ext_scale(ctx, timesteps, time_factor); return ggml_timestep_embedding(ctx, timesteps, dim, max_period); } diff --git a/llm.hpp b/llm.hpp index 781774db7..7feb8d3c8 100644 --- a/llm.hpp +++ b/llm.hpp @@ -638,7 +638,7 @@ namespace LLM { x = ln_q->forward(ctx, x); x = ggml_reshape_2d(ctx->ggml_ctx, x, hidden_size, ggml_nelements(x) / hidden_size); x = mlp_0->forward(ctx, x); - x = ggml_gelu(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x); x = mlp_2->forward(ctx, x); return x; } diff --git a/lora.hpp b/lora.hpp index 7d83ec5cd..e5d9906ff 100644 --- a/lora.hpp +++ b/lora.hpp @@ -195,7 +195,7 @@ struct LoraModel : public GGMLRunner { scale_value *= multiplier; auto curr_updown = ggml_ext_merge_lora(ctx, lora_down, lora_up, lora_mid); - curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); + curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true); if (updown == nullptr) { updown = curr_updown; @@ -235,7 +235,7 @@ struct LoraModel : public GGMLRunner { float 
scale_value = 1.0f; scale_value *= multiplier; - curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); + curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true); if (updown == nullptr) { updown = curr_updown; @@ -340,7 +340,7 @@ struct LoraModel : public GGMLRunner { struct ggml_tensor* updown_1 = ggml_ext_merge_lora(ctx, hada_1_down, hada_1_up, hada_1_mid); struct ggml_tensor* updown_2 = ggml_ext_merge_lora(ctx, hada_2_down, hada_2_up, hada_2_mid); auto curr_updown = ggml_mul_inplace(ctx, updown_1, updown_2); - curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); + curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true); if (updown == nullptr) { updown = curr_updown; } else { @@ -456,7 +456,7 @@ struct LoraModel : public GGMLRunner { scale_value *= multiplier; auto curr_updown = ggml_ext_kronecker(ctx, lokr_w1, lokr_w2); - curr_updown = ggml_scale_inplace(ctx, curr_updown, scale_value); + curr_updown = ggml_ext_scale(ctx, curr_updown, scale_value, true); if (updown == nullptr) { updown = curr_updown; @@ -634,7 +634,7 @@ struct LoraModel : public GGMLRunner { forward_params.conv2d.scale); } - auto curr_out_diff = ggml_scale_inplace(ctx, lx, scale_value); + auto curr_out_diff = ggml_ext_scale(ctx, lx, scale_value, true); if (out_diff == nullptr) { out_diff = curr_out_diff; diff --git a/mmdit.hpp b/mmdit.hpp index cb4be7bf0..e2636e713 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -33,7 +33,7 @@ struct Mlp : public GGMLBlock { auto fc2 = std::dynamic_pointer_cast(blocks["fc2"]); x = fc1->forward(ctx, x); - x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); x = fc2->forward(ctx, x); return x; } diff --git a/pmid.hpp b/pmid.hpp index 5b70dab66..8ce78d3a6 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -33,7 +33,7 @@ struct FuseBlock : public GGMLBlock { x = layer_norm->forward(ctx, x); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b); x = fc1->forward(ctx, x); - x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); x = fc2->forward(ctx, x); // x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b); if (use_residue) @@ -129,8 +129,8 @@ struct PerceiverAttention : public GGMLBlock { k = reshape_tensor(ctx->ggml_ctx, k, heads); v = reshape_tensor(ctx->ggml_ctx, v, heads); scale = 1.f / sqrt(sqrt((float)dim_head)); - k = ggml_scale_inplace(ctx->ggml_ctx, k, scale); - q = ggml_scale_inplace(ctx->ggml_ctx, q, scale); + k = ggml_ext_scale(ctx->ggml_ctx, k, scale, true); + q = ggml_ext_scale(ctx->ggml_ctx, q, scale, true); // auto weight = ggml_mul_mat(ctx, q, k); auto weight = ggml_mul_mat(ctx->ggml_ctx, k, q); // NOTE order of mul is opposite to pytorch diff --git a/t5.hpp b/t5.hpp index 2e3e4b560..fdac3475f 100644 --- a/t5.hpp +++ b/t5.hpp @@ -515,7 +515,7 @@ struct T5DenseGatedActDense : public UnaryBlock { auto wi_1 = std::dynamic_pointer_cast(blocks["wi_1"]); auto wo = std::dynamic_pointer_cast(blocks["wo"]); - auto hidden_gelu = ggml_gelu_inplace(ctx->ggml_ctx, wi_0->forward(ctx, x)); + auto hidden_gelu = ggml_ext_gelu(ctx->ggml_ctx, wi_0->forward(ctx, x), true); auto hidden_linear = wi_1->forward(ctx, x); x = ggml_mul_inplace(ctx->ggml_ctx, hidden_gelu, hidden_linear); x = wo->forward(ctx, x); @@ -608,7 +608,7 @@ class T5Attention : public GGMLBlock { } } - k = ggml_scale_inplace(ctx->ggml_ctx, k, ::sqrtf(static_cast(d_head))); + k = ggml_ext_scale(ctx->ggml_ctx, k, ::sqrtf(static_cast(d_head)), true); x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, mask); // [N, 
n_token, d_head * n_head] diff --git a/tae.hpp b/tae.hpp index a22db1967..831525781 100644 --- a/tae.hpp +++ b/tae.hpp @@ -161,9 +161,9 @@ class TinyDecoder : public UnaryBlock { // z: [n, z_channels, h, w] // return: [n, out_channels, h*8, w*8] - auto h = ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f); + auto h = ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f); h = ggml_tanh_inplace(ctx->ggml_ctx, h); - h = ggml_scale(ctx->ggml_ctx, h, 3.0f); + h = ggml_ext_scale(ctx->ggml_ctx, h, 3.0f); for (int i = 0; i < num_blocks * 3 + 10; i++) { if (blocks.find(std::to_string(i)) == blocks.end()) { @@ -400,10 +400,11 @@ class TinyVideoDecoder : public UnaryBlock { auto first_conv = std::dynamic_pointer_cast(blocks["1"]); // Clamp() - auto h = ggml_scale_inplace(ctx->ggml_ctx, - ggml_tanh_inplace(ctx->ggml_ctx, - ggml_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)), - 3.0f); + auto h = ggml_ext_scale(ctx->ggml_ctx, + ggml_tanh_inplace(ctx->ggml_ctx, + ggml_ext_scale(ctx->ggml_ctx, z, 1.0f / 3.0f)), + 3.0f, + true); h = first_conv->forward(ctx, h); h = ggml_relu_inplace(ctx->ggml_ctx, h); diff --git a/unet.hpp b/unet.hpp index 6e15e1f45..2dd79e0e1 100644 --- a/unet.hpp +++ b/unet.hpp @@ -529,7 +529,7 @@ class UnetModelBlock : public GGMLBlock { } } if (controls.size() > 0) { - auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[controls.size() - 1], control_strength); + auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[controls.size() - 1], control_strength, true); h = ggml_add(ctx->ggml_ctx, h, cs); // middle control } int control_offset = static_cast(controls.size() - 2); @@ -542,7 +542,7 @@ class UnetModelBlock : public GGMLBlock { hs.pop_back(); if (controls.size() > 0) { - auto cs = ggml_scale_inplace(ctx->ggml_ctx, controls[control_offset], control_strength); + auto cs = ggml_ext_scale(ctx->ggml_ctx, controls[control_offset], control_strength, true); h_skip = ggml_add(ctx->ggml_ctx, h_skip, cs); // control net condition control_offset--; } diff --git a/vae.hpp b/vae.hpp index 232500295..fdddc8ae5 100644 --- a/vae.hpp +++ b/vae.hpp @@ -253,8 +253,8 @@ class VideoResnetBlock : public ResnetBlock { float alpha = get_alpha(); x = ggml_add(ctx->ggml_ctx, - ggml_scale(ctx->ggml_ctx, x, alpha), - ggml_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); + ggml_ext_scale(ctx->ggml_ctx, x, alpha), + ggml_ext_scale(ctx->ggml_ctx, x_mix, 1.0f - alpha)); x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w) x = ggml_reshape_4d(ctx->ggml_ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w diff --git a/wan.hpp b/wan.hpp index 3ade14bfe..1f1b2cf4b 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1576,7 +1576,7 @@ namespace WAN { y = modulate_add(ctx->ggml_ctx, y, es[3]); y = ffn_0->forward(ctx, y); - y = ggml_gelu_inplace(ctx->ggml_ctx, y); + y = ggml_ext_gelu(ctx->ggml_ctx, y, true); y = ffn_2->forward(ctx, y); x = ggml_add(ctx->ggml_ctx, x, modulate_mul(ctx->ggml_ctx, y, es[5])); @@ -1723,7 +1723,7 @@ namespace WAN { auto x = proj_0->forward(ctx, image_embeds); x = proj_1->forward(ctx, x); - x = ggml_gelu_inplace(ctx->ggml_ctx, x); + x = ggml_ext_gelu(ctx->ggml_ctx, x, true); x = proj_3->forward(ctx, x); x = proj_4->forward(ctx, x); @@ -1910,7 +1910,7 @@ namespace WAN { e0 = ggml_reshape_4d(ctx->ggml_ctx, e0, e0->ne[0] / 6, 6, e0->ne[1], e0->ne[2]); // [N, 6, dim] or [N, T, 6, dim] context = text_embedding_0->forward(ctx, context); - context = ggml_gelu(ctx->ggml_ctx, context); + context = ggml_ext_gelu(ctx->ggml_ctx, context); context = text_embedding_2->forward(ctx, context); // [N, 
context_txt_len, dim] int64_t context_img_len = 0; @@ -1949,7 +1949,7 @@ namespace WAN { auto result = vace_block->forward(ctx, c, x_orig, e0, pe, context, context_img_len); auto c_skip = result.first; c = result.second; - c_skip = ggml_scale(ctx->ggml_ctx, c_skip, vace_strength); + c_skip = ggml_ext_scale(ctx->ggml_ctx, c_skip, vace_strength); x = ggml_add(ctx->ggml_ctx, x, c_skip); } } diff --git a/z_image.hpp b/z_image.hpp index 505fa7e8a..cee23833a 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -517,7 +517,7 @@ namespace ZImage { out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] out = ggml_ext_slice(ctx->ggml_ctx, out, 0, 0, W); // [N, C, H, W] - out = ggml_scale(ctx->ggml_ctx, out, -1.f); + out = ggml_ext_scale(ctx->ggml_ctx, out, -1.f); return out; } From a78d7ec9a0d1f15ee8e76060714cdc2c6c962cbc Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 25 Jan 2026 22:06:57 +0800 Subject: [PATCH 6/7] simplify mmdit code --- mmdit.hpp | 75 ++++++++++++++++++++++++------------------------------- 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/mmdit.hpp b/mmdit.hpp index e2636e713..47cdec0d0 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -284,23 +284,19 @@ struct DismantledBlock : public GGMLBlock { auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - int64_t n_mods = 9; - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] - m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] - m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] - auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] - - auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] - auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] - auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] - - auto shift_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size] - auto scale_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size] - auto gate_msa2 = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size] + int n_mods = 9; + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] + auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0); + + auto shift_msa = m_vec[0]; // [N, hidden_size] + auto scale_msa = m_vec[1]; // [N, hidden_size] + auto gate_msa = m_vec[2]; // [N, hidden_size] + auto shift_mlp = m_vec[3]; // [N, hidden_size] + auto scale_mlp = m_vec[4]; // [N, hidden_size] + auto gate_mlp = m_vec[5]; // [N, hidden_size] + auto shift_msa2 = m_vec[6]; // [N, hidden_size] + auto scale_msa2 = m_vec[7]; // [N, hidden_size] + auto gate_msa2 = m_vec[8]; // [N, hidden_size] auto x_norm = norm1->forward(ctx, x); @@ -322,22 +318,20 @@ struct DismantledBlock : public GGMLBlock { auto attn = 
std::dynamic_pointer_cast(blocks["attn"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - int64_t n_mods = 6; + int n_mods = 6; if (pre_only) { n_mods = 2; } - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] - m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] - m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] + auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0); - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto shift_msa = m_vec[0]; // [N, hidden_size] + auto scale_msa = m_vec[1]; // [N, hidden_size] if (!pre_only) { - auto gate_msa = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] - auto shift_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] - auto scale_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] - auto gate_mlp = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] + auto gate_msa = m_vec[2]; // [N, hidden_size] + auto shift_mlp = m_vec[3]; // [N, hidden_size] + auto scale_mlp = m_vec[4]; // [N, hidden_size] + auto gate_mlp = m_vec[5]; // [N, hidden_size] auto attn_in = modulate(ctx->ggml_ctx, norm1->forward(ctx, x), shift_msa, scale_msa); @@ -500,26 +494,24 @@ block_mixing(GGMLRunnerContext* ctx, qkv.push_back(ggml_concat(ctx->ggml_ctx, context_qkv[i], x_qkv[i], 1)); } - auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size] - attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size] + auto attn = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, nullptr, false, false, ctx->flash_attn_enabled); // [N, n_context + n_token, hidden_size] + auto context_attn = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], context->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - 0); // [n_context, N, hidden_size] - context_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, context_attn, 0, 2, 1, 3)); // [N, n_context, hidden_size] + 0); // [N, n_context, hidden_size] auto x_attn = ggml_view_3d(ctx->ggml_ctx, attn, attn->ne[0], - attn->ne[1], x->ne[1], + attn->ne[2], attn->nb[1], attn->nb[2], - attn->nb[2] * context->ne[1]); // [n_token, N, hidden_size] - x_attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x_attn, 0, 2, 1, 3)); // [N, n_token, hidden_size] + context->ne[1] * attn->nb[1]); // [N, n_token, hidden_size] if (!context_block->pre_only) { context = context_block->post_attention(ctx, @@ -604,13 +596,10 @@ struct FinalLayer : public GGMLBlock { auto linear = std::dynamic_pointer_cast(blocks["linear"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx->ggml_ctx, 
m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, 2 * hidden_size] + auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0); + auto shift = m_vec[0]; // [N, hidden_size] + auto scale = m_vec[1]; // [N, hidden_size] x = modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); From 833240a0c8e98babfa995a86169fdb20164376e6 Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 25 Jan 2026 22:43:55 +0800 Subject: [PATCH 7/7] simplify wan code --- mmdit.hpp | 4 ++-- wan.hpp | 7 ++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/mmdit.hpp b/mmdit.hpp index 47cdec0d0..086b444dc 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -285,8 +285,8 @@ struct DismantledBlock : public GGMLBlock { auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); int n_mods = 9; - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] - auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0); + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, n_mods * hidden_size] + auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, n_mods, 0); auto shift_msa = m_vec[0]; // [N, hidden_size] auto scale_msa = m_vec[1]; // [N, hidden_size] diff --git a/wan.hpp b/wan.hpp index 1f1b2cf4b..c56e1f926 100644 --- a/wan.hpp +++ b/wan.hpp @@ -1442,11 +1442,8 @@ namespace WAN { int64_t dim = x->ne[0]; int64_t context_txt_len = context->ne[1] - context_img_len; - context = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context, 0, 2, 1, 3)); // [context_img_len + context_txt_len, N, dim] - auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_img_len, context->nb[1], context->nb[2], 0); - auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, N, context_txt_len, context->nb[1], context->nb[2], context_img_len * context->nb[2]); - context_img = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_img, 0, 2, 1, 3)); // [N, context_img_len, dim] - context_txt = ggml_ext_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, context_txt, 0, 2, 1, 3)); // [N, context_txt_len, dim] + auto context_img = ggml_view_3d(ctx->ggml_ctx, context, dim, context_img_len, N, context->nb[1], context->nb[2], 0); // [N, context_img_len, dim] + auto context_txt = ggml_view_3d(ctx->ggml_ctx, context, dim, context_txt_len, N, context->nb[1], context->nb[2], context_img_len * context->nb[1]); // [N, context_txt_len, dim] auto q = q_proj->forward(ctx, x); q = norm_q->forward(ctx, q);
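Taken together, the series leans on two ideas: zero-copy splits via strided views, and contiguity guards inside the unary-op wrappers. PATCH 2, PATCH 3, PATCH 4 and PATCH 7 apply the strided-view split sketched earlier, while the adaLN shift/scale/gate extraction (PATCH 1 and PATCH 6) and the PATCH 1 qkv split go through ggml_ext_chunk. That helper already lives in ggml_extend.hpp and takes extra options (for example the trailing flag in the PATCH 1 qkv split), so the sketch below is only an illustration of what a chunk along dim 0 is expected to reduce to, not its actual implementation:

    #include <vector>
    // ggml.h is assumed to be available on the include path.

    // Split dim 0 of x into n equal strided views, e.g. for the adaLN modulation
    // output m of logical shape [N, n * hidden_size]
    // (ggml order: ne0 = n * hidden_size, ne1 = N).
    static std::vector<struct ggml_tensor*> chunk_dim0(struct ggml_context* ctx,
                                                       struct ggml_tensor* x,
                                                       int n) {
        std::vector<struct ggml_tensor*> parts;
        const int64_t step = x->ne[0] / n;  // hidden_size in the adaLN case
        for (int i = 0; i < n; ++i) {
            parts.push_back(ggml_view_2d(ctx, x, step, x->ne[1],
                                         x->nb[1],               // keep the parent row stride
                                         i * step * x->nb[0]));  // byte offset within a row
        }
        return parts;
    }

Each returned view aliases the buffer of x and is therefore not contiguous, which is why PATCH 5 routes unary ops through ggml_ext_scale / ggml_ext_gelu / ggml_ext_gelu_quick: those wrappers insert a ggml_cont() when the input is a non-contiguous view, so call sites can keep passing views around without tripping contiguity requirements.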