Merged: changes from all commits
clip.hpp (4 changes: 2 additions & 2 deletions)

@@ -479,9 +479,9 @@ struct CLIPMLP : public GGMLBlock {
 
         x = fc1->forward(ctx, x);
         if (use_gelu) {
-            x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         } else {
-            x = ggml_gelu_quick_inplace(ctx->ggml_ctx, x);
+            x = ggml_ext_gelu_quick(ctx->ggml_ctx, x, true);
         }
         x = fc2->forward(ctx, x);
         return x;
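Note: upstream ggml only exposes separate ggml_gelu / ggml_gelu_inplace entry points (likewise the _quick variants), so ggml_ext_gelu is presumably a project-local wrapper that folds the in-place choice into a bool. A minimal sketch of that assumed shape, not the actual definition from this repo's ggml extension header:

    // Assumed wrapper shape: dispatch on an in-place flag instead of having
    // two spellings at every call site. ggml_gelu/ggml_gelu_inplace are
    // stock ggml API; only the wrapper itself is hypothetical.
    static ggml_tensor* ggml_ext_gelu_sketch(ggml_context* ctx, ggml_tensor* x,
                                             bool inplace = false) {
        return inplace ? ggml_gelu_inplace(ctx, x) : ggml_gelu(ctx, x);
    }

The same pattern would cover ggml_ext_gelu_quick via ggml_gelu_quick / ggml_gelu_quick_inplace.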
common.hpp (8 changes: 4 additions & 4 deletions)

@@ -200,7 +200,7 @@ class GEGLU : public UnaryBlock {
 
         gate = ggml_cont(ctx->ggml_ctx, gate);
 
-        gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
+        gate = ggml_ext_gelu(ctx->ggml_ctx, gate, true);
 
         x = ggml_mul(ctx->ggml_ctx, x, gate);  // [ne3, ne2, ne1, dim_out]
 
@@ -220,7 +220,7 @@ class GELU : public UnaryBlock {
         auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
         x = proj->forward(ctx, x);
-        x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+        x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
         return x;
     }
 };
@@ -536,8 +536,8 @@ class AlphaBlender : public GGMLBlock {
         // image_only_indicator is always tensor([0.])
         float alpha = get_alpha();
         auto x = ggml_add(ctx->ggml_ctx,
-                          ggml_scale(ctx->ggml_ctx, x_spatial, alpha),
-                          ggml_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
+                          ggml_ext_scale(ctx->ggml_ctx, x_spatial, alpha),
+                          ggml_ext_scale(ctx->ggml_ctx, x_temporal, 1.0f - alpha));
         return x;
     }
 };
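For reference, AlphaBlender's output is the convex blend x = alpha * x_spatial + (1 - alpha) * x_temporal. Assuming ggml_ext_scale stays call-compatible with ggml_scale(ctx, tensor, float), the same blend can be written against stock ggml:

    // Convex blend of two same-shaped tensors: a*s + b*(1-s), stock ggml only.
    static ggml_tensor* alpha_blend(ggml_context* ctx, ggml_tensor* a,
                                    ggml_tensor* b, float s) {
        return ggml_add(ctx,
                        ggml_scale(ctx, a, s),
                        ggml_scale(ctx, b, 1.0f - s));
    }

With alpha = 0.5 this reduces to the elementwise mean of the two branches.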
esrgan.hpp (4 changes: 2 additions & 2 deletions)

@@ -51,7 +51,7 @@ class ResidualDenseBlock : public GGMLBlock {
         x_cat = ggml_concat(ctx->ggml_ctx, x_cat, x4, 2);
         auto x5 = conv5->forward(ctx, x_cat);
 
-        x5 = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, x5, 0.2f), x);
+        x5 = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, x5, 0.2f), x);
         return x5;
     }
 };
@@ -76,7 +76,7 @@ class RRDB : public GGMLBlock {
         out = rdb2->forward(ctx, out);
         out = rdb3->forward(ctx, out);
 
-        out = ggml_add(ctx->ggml_ctx, ggml_scale(ctx->ggml_ctx, out, 0.2f), x);
+        out = ggml_add(ctx->ggml_ctx, ggml_ext_scale(ctx->ggml_ctx, out, 0.2f), x);
         return out;
     }
 };
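Both esrgan.hpp call sites are the same ESRGAN residual-scaling pattern, y = 0.2 * f(x) + x. A hypothetical helper (not part of this PR) that names it:

    // Hypothetical convenience, not in the PR: scale the block output by `s`
    // (0.2 throughout ESRGAN) before adding the skip connection.
    static ggml_tensor* scaled_residual(ggml_context* ctx, ggml_tensor* fx,
                                        ggml_tensor* x, float s = 0.2f) {
        return ggml_add(ctx, ggml_scale(ctx, fx, s), x);
    }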
flux.hpp (97 changes: 37 additions & 60 deletions)

@@ -103,7 +103,7 @@ namespace Flux {
             auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
 
             auto qkv = qkv_proj->forward(ctx, x);
-            auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);
+            auto qkv_vec = ggml_ext_chunk(ctx->ggml_ctx, qkv, 3, 0, true);
             int64_t head_dim = qkv_vec[0]->ne[0] / num_heads;
             auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);
             auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);
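split_qkv gives way to what looks like a generic chunk helper: ggml_ext_chunk(ctx, t, chunks, dim, ...) presumably splits t into equal views along dim, torch.chunk-style, with the trailing bool plausibly requesting contiguous copies (a dim-0 slice at a nonzero offset is strided, not contiguous). A sketch of the assumed dim == 0 case, with the signature inferred from the call site rather than the real header:

    // Assumed semantics for dim == 0: n equal slices along ne[0], each
    // optionally made contiguous. Requires "ggml.h" and <vector>.
    static std::vector<ggml_tensor*> chunk_dim0_sketch(ggml_context* ctx,
                                                       ggml_tensor* t,
                                                       int n, bool cont) {
        std::vector<ggml_tensor*> parts;
        const int64_t step = t->ne[0] / n;  // elements per chunk along dim 0
        for (int i = 0; i < n; i++) {
            ggml_tensor* v = ggml_view_3d(ctx, t,
                                          step, t->ne[1], t->ne[2],
                                          t->nb[1], t->nb[2],
                                          (int64_t)i * step * t->nb[0]);  // byte offset
            parts.push_back(cont ? ggml_cont(ctx, v) : v);
        }
        return parts;
    }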
@@ -153,7 +153,7 @@ namespace Flux {
             if (use_mlp_silu_act) {
                 x = ggml_ext_silu_act(ctx->ggml_ctx, x);
             } else {
-                x = ggml_gelu_inplace(ctx->ggml_ctx, x);
+                x = ggml_ext_gelu(ctx->ggml_ctx, x, true);
             }
             x = mlp_2->forward(ctx, x);
             return x;
@@ -376,26 +376,23 @@ namespace Flux {
             auto k = ggml_concat(ctx->ggml_ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx->ggml_ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
-            attn = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_txt_token + n_img_token, n_head*d_head]
             auto txt_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
                                              attn->ne[0],
-                                             attn->ne[1],
                                              txt->ne[1],
+                                             attn->ne[2],
                                              attn->nb[1],
                                              attn->nb[2],
-                                             0);  // [n_txt_token, N, hidden_size]
-            txt_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_attn_out, 0, 2, 1, 3));  // [N, n_txt_token, hidden_size]
+                                             0);  // [N, n_txt_token, hidden_size]
             auto img_attn_out = ggml_view_3d(ctx->ggml_ctx,
                                              attn,
                                              attn->ne[0],
-                                             attn->ne[1],
                                              img->ne[1],
+                                             attn->ne[2],
                                              attn->nb[1],
                                              attn->nb[2],
-                                             attn->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-            img_attn_out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img_attn_out, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+                                             txt->ne[1] * attn->nb[1]);  // [N, n_img_token, hidden_size]
 
             // calculate the img bloks
             img = ggml_add(ctx->ggml_ctx, img, ggml_mul(ctx->ggml_ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
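The point of this hunk: attn stays in its native layout (ggml ne = [hidden_size, n_token, N]), so the txt/img split becomes a plain slice along the token dimension ne[1], expressed as a ggml_view_3d with byte offset txt->ne[1] * attn->nb[1]. The two permute+cont round trips per output disappear. The pattern in isolation, as a sketch:

    // Split a [hidden, n_token, N] tensor into the first k tokens and the
    // rest without permuting; only the byte offset differs between views.
    static void split_tokens(ggml_context* ctx, ggml_tensor* t, int64_t k,
                             ggml_tensor** head, ggml_tensor** tail) {
        *head = ggml_view_3d(ctx, t, t->ne[0], k, t->ne[2],
                             t->nb[1], t->nb[2], 0);             // tokens [0, k)
        *tail = ggml_view_3d(ctx, t, t->ne[0], t->ne[1] - k, t->ne[2],
                             t->nb[1], t->nb[2], k * t->nb[1]);  // tokens [k, end)
    }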
@@ -492,43 +489,29 @@ namespace Flux {
             }
 
             auto x_mod = Flux::modulate(ctx->ggml_ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
-            auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
-            qkv_mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv_mlp, 2, 0, 1, 3));  // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
-
-            auto qkv = ggml_view_3d(ctx->ggml_ctx,
-                                    qkv_mlp,
-                                    qkv_mlp->ne[0],
-                                    qkv_mlp->ne[1],
-                                    hidden_size * 3,
-                                    qkv_mlp->nb[1],
-                                    qkv_mlp->nb[2],
-                                    0);  // [hidden_size * 3, N, n_token]
-            qkv = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, qkv, 1, 2, 0, 3));  // [N, n_token, hidden_size * 3]
-            auto mlp = ggml_view_3d(ctx->ggml_ctx,
-                                    qkv_mlp,
-                                    qkv_mlp->ne[0],
-                                    qkv_mlp->ne[1],
-                                    mlp_hidden_dim * mlp_mult_factor,
-                                    qkv_mlp->nb[1],
-                                    qkv_mlp->nb[2],
-                                    qkv_mlp->nb[2] * hidden_size * 3);  // [mlp_hidden_dim*mlp_mult_factor, N, n_token]
-            mlp = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, mlp, 1, 2, 0, 3));  // [N, n_token, mlp_hidden_dim*mlp_mult_factor]
-            auto qkv_vec = split_qkv(ctx->ggml_ctx, qkv);  // q,k,v: [N, n_token, hidden_size]
+            auto qkv_mlp = linear1->forward(ctx, x_mod);  // [N, n_token, hidden_size * 3 + mlp_hidden_dim*mlp_mult_factor]
+
+            auto q = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], 0);
+            auto k = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * qkv_mlp->nb[0]);
+            auto v = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, hidden_size, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 2 * qkv_mlp->nb[0]);
 
             int64_t head_dim = hidden_size / num_heads;
-            auto q = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]);  // [N, n_token, n_head, d_head]
-            auto k = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]);  // [N, n_token, n_head, d_head]
-            auto v = ggml_reshape_4d(ctx->ggml_ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]);  // [N, n_token, n_head, d_head]
-            q = norm->query_norm(ctx, q);
-            k = norm->key_norm(ctx, k);
-            auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
+
+            q = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, q), head_dim, num_heads, q->ne[1], q->ne[2]);  // [N, n_token, n_head, d_head]
+            k = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, k), head_dim, num_heads, k->ne[1], k->ne[2]);  // [N, n_token, n_head, d_head]
+            v = ggml_reshape_4d(ctx->ggml_ctx, ggml_cont(ctx->ggml_ctx, v), head_dim, num_heads, v->ne[1], v->ne[2]);  // [N, n_token, n_head, d_head]
+
+            q = norm->query_norm(ctx, q);
+            k = norm->key_norm(ctx, k);
+            auto attn = Rope::attention(ctx, q, k, v, pe, mask);  // [N, n_token, hidden_size]
 
+            auto mlp = ggml_view_3d(ctx->ggml_ctx, qkv_mlp, mlp_hidden_dim * mlp_mult_factor, qkv_mlp->ne[1], qkv_mlp->ne[2], qkv_mlp->nb[1], qkv_mlp->nb[2], hidden_size * 3 * qkv_mlp->nb[0]);
             if (use_yak_mlp) {
                 mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp, false);
             } else if (use_mlp_silu_act) {
                 mlp = ggml_ext_silu_act(ctx->ggml_ctx, mlp);
             } else {
-                mlp = ggml_gelu_inplace(ctx->ggml_ctx, mlp);
+                mlp = ggml_ext_gelu(ctx->ggml_ctx, mlp, true);
             }
             auto attn_mlp = ggml_concat(ctx->ggml_ctx, attn, mlp, 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
             auto output = linear2->forward(ctx, attn_mlp);  // [N, n_token, hidden_size]
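Same idea in the single-stream block: q, k, v, and the MLP pre-activation become strided views into the fused linear1 output along dim 0, replacing the permute/view/permute dance. The one cost that remains is visible above: a dim-0 view at a nonzero offset is not contiguous, and ggml reshapes require contiguous input, hence the ggml_cont before each ggml_reshape_4d. The pattern, factored into a sketch:

    // Carve `width` features out of a fused projection starting at
    // `offset_elems` along dim 0, then split into attention heads.
    static ggml_tensor* take_heads(ggml_context* ctx, ggml_tensor* fused,
                                   int64_t offset_elems, int64_t width,
                                   int64_t num_heads) {
        ggml_tensor* v = ggml_view_3d(ctx, fused,
                                      width, fused->ne[1], fused->ne[2],
                                      fused->nb[1], fused->nb[2],
                                      offset_elems * fused->nb[0]);
        // The view is strided, so it must be made contiguous before reshape.
        return ggml_reshape_4d(ctx, ggml_cont(ctx, v),
                               width / num_heads, num_heads,
                               fused->ne[1], fused->ne[2]);
    }

Under these assumptions, the k above would be take_heads(ctx, qkv_mlp, hidden_size, hidden_size, num_heads).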
@@ -580,13 +563,10 @@ namespace Flux {
             } else {
                 auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 
-                auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, 2 * hidden_size]
-                m = ggml_reshape_3d(ctx->ggml_ctx, m, c->ne[0], 2, c->ne[1]);  // [N, 2, hidden_size]
-                m = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, m, 0, 2, 1, 3));  // [2, N, hidden_size]
-
-                int64_t offset = m->nb[1] * m->ne[1];
-                shift = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0);  // [N, hidden_size]
-                scale = ggml_view_2d(ctx->ggml_ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1);  // [N, hidden_size]
+                auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, 2 * hidden_size]
+                auto m_vec = ggml_ext_chunk(ctx->ggml_ctx, m, 2, 0);
+                shift = m_vec[0];  // [N, hidden_size]
+                scale = m_vec[1];  // [N, hidden_size]
             }
 
             x = Flux::modulate(ctx->ggml_ctx, norm_final->forward(ctx, x), shift, scale);
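For context on where shift/scale go: adaLN modulation in Flux-style blocks is conventionally x * (1 + scale) + shift. Flux::modulate's body is not part of this diff, so the following is only a sketch under that assumption; the reshape makes a [hidden, N] tensor broadcast across the token dimension of x ([hidden, n_token, N] in ggml order):

    // Assumed adaLN modulation: y = x * (1 + scale) + shift. Not confirmed
    // against Flux::modulate itself.
    static ggml_tensor* modulate_sketch(ggml_context* ctx, ggml_tensor* x,
                                        ggml_tensor* shift, ggml_tensor* scale) {
        shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]);  // [hidden, 1, N]
        scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [hidden, 1, N]
        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));  // x * (1 + scale)
        return ggml_add(ctx, x, shift);                 // ... + shift
    }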
@@ -1034,16 +1014,14 @@ namespace Flux {
                 txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
             }
 
-            txt_img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, txt_img, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
-            img = ggml_view_3d(ctx->ggml_ctx,
-                               txt_img,
-                               txt_img->ne[0],
-                               txt_img->ne[1],
-                               img->ne[1],
-                               txt_img->nb[1],
-                               txt_img->nb[2],
-                               txt_img->nb[2] * txt->ne[1]);  // [n_img_token, N, hidden_size]
-            img = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, img, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]
+            img = ggml_view_3d(ctx->ggml_ctx,
+                               txt_img,
+                               txt_img->ne[0],
+                               img->ne[1],
+                               txt_img->ne[2],
+                               txt_img->nb[1],
+                               txt_img->nb[2],
+                               txt->ne[1] * txt_img->nb[1]);  // [N, n_img_token, hidden_size]
 
             if (final_layer) {
                 img = final_layer->forward(ctx, img, vec);  // (N, T, patch_size ** 2 * out_channels)
Expand Down Expand Up @@ -1196,9 +1174,8 @@ namespace Flux {
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]

if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
out = ggml_view_3d(ctx->ggml_ctx, out, out->ne[0], img_tokens, out->ne[2], out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx->ggml_ctx, out);
}

// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
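The rearrange comment above is the unpatchify step. As a reference for the index bookkeeping only (the actual graph does this with ggml ops), a plain CPU rendering of "b (h w) (c ph pw) -> b c (h ph) (w pw)" with ph = pw = p:

    // CPU reference for the einops pattern above; contiguous row-major
    // buffers, out: [B, h_len*w_len, C*p*p] -> img: [B, C, h_len*p, w_len*p].
    static void unpatchify_ref(const float* out, float* img,
                               int64_t B, int64_t C,
                               int64_t h_len, int64_t w_len, int64_t p) {
        for (int64_t b = 0; b < B; b++)
            for (int64_t c = 0; c < C; c++)
                for (int64_t h = 0; h < h_len; h++)
                    for (int64_t w = 0; w < w_len; w++)
                        for (int64_t ph = 0; ph < p; ph++)
                            for (int64_t pw = 0; pw < p; pw++) {
                                int64_t token = h * w_len + w;          // (h w)
                                int64_t feat  = (c * p + ph) * p + pw;  // (c ph pw)
                                int64_t src   = (b * h_len * w_len + token) * (C * p * p) + feat;
                                int64_t dst   = ((b * C + c) * h_len * p + h * p + ph) * (w_len * p)
                                              + w * p + pw;
                                img[dst] = out[src];
                            }
    }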