From bff75f0db3b3e71f2f642813f6f39be13c3e8347 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 27 Jun 2025 20:20:43 -0400 Subject: [PATCH] add omnigen2 Signed-off-by: Vladimir Mandic --- .pylintrc | 1 + .ruff.toml | 1 + CHANGELOG.md | 5 + TODO.md | 1 - html/reference.json | 7 + models/Reference/OmniGen2--OmniGen2.jpg | Bin 0 -> 51113 bytes modules/model_omnigen.py | 15 - modules/model_omnigen2.py | 49 + modules/modeldata.py | 2 + modules/omnigen2/__init__.py | 3 + modules/omnigen2/image_processor.py | 265 ++++ modules/omnigen2/import_utils.py | 46 + modules/omnigen2/models/__init__.py | 0 .../omnigen2/models/attention_processor.py | 357 +++++ modules/omnigen2/models/embeddings.py | 126 ++ .../omnigen2/models/transformers/__init__.py | 3 + .../models/transformers/block_lumina2.py | 217 +++ .../models/transformers/components.py | 4 + modules/omnigen2/models/transformers/repo.py | 129 ++ .../transformers/transformer_omnigen2.py | 617 ++++++++ modules/omnigen2/pipeline_omnigen2.py | 718 ++++++++++ modules/omnigen2/pipeline_utils.py | 62 + modules/omnigen2/triton_layer_norm.py | 1257 +++++++++++++++++ modules/processing_args.py | 6 +- modules/sd_detect.py | 3 + modules/sd_models.py | 5 + modules/sd_offload.py | 2 +- modules/sd_samplers_common.py | 2 +- modules/shared.py | 8 +- wiki | 2 +- 30 files changed, 3886 insertions(+), 27 deletions(-) create mode 100644 models/Reference/OmniGen2--OmniGen2.jpg create mode 100644 modules/model_omnigen2.py create mode 100644 modules/omnigen2/__init__.py create mode 100644 modules/omnigen2/image_processor.py create mode 100644 modules/omnigen2/import_utils.py create mode 100644 modules/omnigen2/models/__init__.py create mode 100644 modules/omnigen2/models/attention_processor.py create mode 100644 modules/omnigen2/models/embeddings.py create mode 100644 modules/omnigen2/models/transformers/__init__.py create mode 100644 modules/omnigen2/models/transformers/block_lumina2.py create mode 100644 
modules/omnigen2/models/transformers/components.py create mode 100644 modules/omnigen2/models/transformers/repo.py create mode 100644 modules/omnigen2/models/transformers/transformer_omnigen2.py create mode 100644 modules/omnigen2/pipeline_omnigen2.py create mode 100644 modules/omnigen2/pipeline_utils.py create mode 100644 modules/omnigen2/triton_layer_norm.py diff --git a/.pylintrc b/.pylintrc index 2c9bbc0c2..5ecbbef48 100644 --- a/.pylintrc +++ b/.pylintrc @@ -27,6 +27,7 @@ ignore-paths=/usr/lib/.*$, modules/meissonic, modules/mod, modules/omnigen, + modules/omnigen2, modules/onnx_impl, modules/pag, modules/pixelsmith, diff --git a/.ruff.toml b/.ruff.toml index ab5e0601c..023678c33 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -20,6 +20,7 @@ exclude = [ "modules/meissonic", "modules/mod", "modules/omnigen", + "modules/omnigen2", "modules/hidream", "modules/pag", "modules/pixelsmith", diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ed2c89c0..6c11b9be5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ - **Models** - [Models Wiki page](https://vladmandic.github.io/sdnext-docs/Models/) is updated will all new models + *note* all new image models larger than 30GB, so [offloading](https://vladmandic.github.io/sdnext-docs/Offload/) and [quantization](https://vladmandic.github.io/sdnext-docs/Quantization/) are necessary! 
+ - [OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) + - OmniGen2 is a powerful unified multimodal model that supports t2i and i2i workflows and uses 4B transformer with Qwen-VL-2.5 4B VLM + - available via *networks -> models -> reference* - [nVidia Cosmos-Predict2 T2I](https://research.nvidia.com/labs/dir/cosmos-predict2/) *2B and 14B* - Cosmos-Predict2 T2I is a new foundational model from Nvidia in two variants: small 2B and large 14B - available via *networks -> models -> reference* @@ -11,6 +15,7 @@ - *note*: this is a gated model, you need to [accept terms](https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/) - [Black Forest Labs FLUX.1 Kontext I2I](https://bfl.ai/announcements/flux-1-kontext-dev) *Dev* variant - FLUX.1-Kontext is a 12B model billion parameter capable of editing images based on text instructions + - model is primarily designed for image editing workflows, but also works for text-to-image workflows - requirements are similar to regular FLUX.1 although 2x slower - available via *networks -> models -> reference* - *note*: this is a gated model, you need to [accept terms](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) and set your [huggingface token](https://vladmandic.github.io/sdnext-docs/Gated/) diff --git a/TODO.md b/TODO.md index c8d6183e9..17963a7af 100644 --- a/TODO.md +++ b/TODO.md @@ -52,7 +52,6 @@ Main ToDo list can be found at [GitHub projects](https://github.com/users/vladma - [SkyReels-v2](https://github.com/SkyworkAI/SkyReels-V2)(https://github.com/huggingface/diffusers/pull/11518) #### External:Unified/MultiModal - [Bagel](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)(https://github.com/bytedance-seed/bagel) -- [OmniGen2](https://huggingface.co/OmniGen2/OmniGen2) - [Ming](https://github.com/inclusionAI/Ming) - [Liquid](https://github.com/FoundationVision/Liquid) #### External:Image2Image/Editing diff --git 
a/html/reference.json b/html/reference.json index 599ad1c5b..e90aad1d1 100644 --- a/html/reference.json +++ b/html/reference.json @@ -252,6 +252,13 @@ "skip": true }, + "VectorSpaceLab OmniGen v2": { + "path": "OmniGen2/OmniGen2", + "desc": "OmniGen2 is a powerful and efficient unified multimodal model. Unlike OmniGen v1, OmniGen2 features two distinct decoding pathways for text and image modalities, utilizing unshared parameters and a decoupled image tokenizer.", + "preview": "OmniGen2--OmniGen2.jpg", + "skip": true + }, + "AuraFlow 0.3": { "path": "fal/AuraFlow-v0.3", "desc": "AuraFlow v0.3 is the fully open-sourced flow-based text-to-image generation model. The model was trained with more compute compared to the previous version, AuraFlow-v0.2. Compared to AuraFlow-v0.2, the model is fine-tuned on more aesthetic datasets and now supports various aspect ratio, (now width and height up to 1536 pixels).", diff --git a/models/Reference/OmniGen2--OmniGen2.jpg b/models/Reference/OmniGen2--OmniGen2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d8fce4558842e4d53b78983701a4ec069f1391a9 GIT binary patch literal 51113 zcmbTdbyyrt(>}U{;1*ni6Wm!G7D8Z=;O@cQH4r=r?(QMD>*DSZ+}+)RyPQp)=Y7AQ z{BzD3u36^l>FMsOo~gU)?&_Dhmo>l}X$dI_01ONa;1l!@@Ujjt7k9TX1ps7a0So{D z01@ySh6n%$y@G+B!Z5`Dy_SHX1Hk?*e+9kC4uJjwFrlX$^!#&tf}Ve#`TO7Az{1v< z#n!}`9LU1P!p;fpX%Qd>KtV)AMnphCMnXnIMM1+P!NSDAz@#K3!XcrjW@MnFrln=( z5aD5F6Jn#K<^9MfBqlB?CCSJwuPP^@A_A5a|I-KzDjFIl1||g-7KJzqEsOa7=i{Xv zfQk4DI(}FfEWj&F7+6f0mu>(Vbo6jgz5G$lzX!}K*w=9I2#83?D9{R!H-J|#u&}RQ z!@|M6ehsbd1-%b=jR}WE!7d7qt!RKi>43xG7oCMj^}ezLS84o|n$ytH9|;)`pMa2v z<{d2^Jp&hzn}?T=U+jap1Xxl^T3JO^O*^aCJG;7jdi(kZ1}7$`re|j7<`>pCHn+BS zcK7xV&dx6`udZ)y@9zJ|1p}4yKgU0U{U35+Lgjk(`Zerpgg8M`u-bAW?BDo#Glgjw9nybFI^y{Soc2WdH93^Z)-P*?$H5AGsC) zsIV~5$%DlNfB<`!_?|ZWxRKxHcpvHSrC$I)#GcolRlZu3WeF)XakS&PcbLex2Gn1X zZurMA4VUVG-O!pXnpoxy`?L_`7Szg9)%u;M#8rc56As zYh0Gl(4P+eyYz5R1B>E4gn7jZbc=P+U~tSm`3#M*ixScZO# 
zTegjN*dqD-#f7vcH0D6K_uS&-_uF9h;;*8MzMS{5YNWGxRV7BK^MVe;c!pofje#t0 zU69^JKU!Nu{hGlLcm!Ao^`aZW#YKESP8p*MooKdq2R%MiCbw?Nav4?asUR&%h#TPdGvCT8Df-1r6HJcbbX zYj|$76w9Y??S8uSt4E@5nQY|u8D~O@rCR>Xr13Q%+IoG*o^Ic>pcTE&1YT>tl#`EJ zp+qRkQJq5~`MFnNqceP$W51}#yne1}{0+ZOoG_ZK5I;?2ErNLj8?R!IT4QqFfKx5- zShX@dEhAUg<6gUk;egA~x3EivQ{Q4j#3*40sJ@xd!?Zq9IcF)wBu5l)`` z)gWg9jx$Ez2Tr z4W@)nsz_Ot>UQOtP+9DR+060B@D>+a7D$Q1($eBk&ukR@T;3>A^uNE}wOWP4aygGX zZQE37&P&vbS=X(V$k5Q6Kae#Wz&p{7m+W!o5r16&EX0s4d~1wZVMgMGrtPXlP*8it zsO27fMZX*Vlc7RVQRwa4KJ*-2k- zyF=CVnH}`uAH{s+^FYt+j`XgpcMf7TP!zv``_ePNCrzKwI`upahi|wQY3KtH5*43w z1XqFXF1B|6>0rxSbc`2( zUR8!pYxnn}vVxm72v4q~H};w2%+Cj*nd#;*38(XZ8twn%DW)1&kHK9vh(>^FasBq~9QhO4#@;LE)vlCnMZGc@;{$lb1poIK> zasp0%wOjlUobz)OE%M2e%t$ixFm=yrpPz)b<<{hR_@@^54YEHC0NFO+1wbAuf4{@H z!^m??EJWACTb!zI>G8el-)^RG{`7+VRCs{Q(iI~V{5*B_Uwu1~GM%l+%!r z3*mUAEW{JV>;C-MK-JLTRj##P`Maed1oL^(uvvI#!*gKhhJ&lBa(gkxr?Z8?0SD2-lFNezG^&fd!r>@ZOjsAIm)S+(;)m+Ty1xZ~wHVMS5g3KpvFZ#T$ zcvnuY_!mI%7qY7%ABIG*&i<*-Ec;{e{wS~4)PR+(u?rIiK=nZCSb4zwUBWv<2SQpq zLoah8nL?DDwa-zQX+Fi?nVFXL`6fpSXMt*V-t5m@G=(LGOSU47o=e>*NEAi`=Q&jH zwVPyg$Rd4Olkwzj@8c>4sW5Y77%CPzTKLs`FZ+kzMbXoYqeA@fVRKoWVpC8Dqk<`;N=i%y+O0qN z{eA|T)W)OE5$p=bajDMCaQ>!DJieZgfceNFEHCnT^Sl_}iRlejNv{7TGKW-wIkqQm zj-F+1R=sPCcEXO>c`rZtNIcBrPA*cHi|cv&$zuM7=hv zVwMi_8>UxjF$W1{v-Wr9BVQn+0-o4f4BW5g_IPnJ@`4N6T#|q19bi^@g=|?P>_~lI z1J6LX1V)dE6*6WAv)07;XYmc`N*)u8MuJLC+C3dP4C8QEOh+ocJ zPk!4H0JM7px7FWsH~7V8Y?q1B+RB*Cu4FQ3C(B^%Jk8Br`4xz|y-AhLATF4CeAa$t z)nfx;OlP?`cICiHp&H?5yA+0C`K|}aYy=;jb)k5lle5XWUzH5=-^(hEjH*o5-eHHy zBY|9+HutQAwTVX)WKx(UsWA_ET-Pvf^e@PsEXUJ7MY?pSs%$0*$8*al^YO{`Tj0V$ zHkKJVvQ$TKGB-$T-evSE1w?EaJqxeyOShd}zW}6-F}OBAe((D6*J-` zbtr>l>Ua=9P0#yQ)(bvp6Y?u@#&l$)W69to+X4ibdjTLHLVVCxEE5*A7pioFj$Pq` z;O6Yn(PiJq%xu>;F_OIF_`JW;*yIy0If6kA@wC?PP@dgfq-TSzac*XH|@)^IDew8Kqtu4IBPGV)MM{f^l)sDEXr; ze=#cNho3nm*R(35PBjfOza-139;)aJZfAj~&o6*uen#!eU)l|1e!KUxqlV`EoONQg z*gIfNaVI<+1-75QY>)q#?1wn_(M%M*7eIrqhWzG-k^Ih4RQapnOD@{`oqM8txxft1 zk-xTD|9*#V%x4SJ>IIWtd1QIWY%-(7WG6$qzpD;?jnCN@cJ{wgD 
zUN5E3XW=IEGf(g9e02`T$f7p8q-@lFOS{1)$9lCgLaNV)mB)g2J&nDTX^N0i(j9Z`@>Olf7YB`whXX&yv%ZE%1FEMx{Y913Gkw?S$o(??80ZydeA@{_pz zU-+g3TjZ>iR+;u&!1aeujD5Z9aiGs{BT|5@C)E=cR>Sld{L;z`pye4flyO!)Ar;(;?* zj^$`P1fY7A+C;5P9^S*Kj0ubE^v}<$TCCrVO^GXbF^W*~s%Y|nyLL(N zSJRlIB}oi0oUbLCIm~{ayH&%hfS^|6K5d4((#bK8Zw9WRs@FKuZb+y{pL0*gIJPDt z1cW`G`n6W-Lrk4WHZa-n`dkwwl!m}$@GQjAu23zX);N4-NC@28VAW_Q6{aMy&{`+3fC)Fv{&TRTLbtTD=4H8}hY zh8%@zSTd?pJ7lo6HBJf2pyK;<9wH2~A@Wj@WiYgb*_2~le++ij28?~|5~#dzxsRQl zC_7OUTmmx`yC3fxd44>tsdA4HSY&hyK~b9lf;=ao6rk@PNeO%~~;u^U$X6;eV)*Dm6)A5huN#z)j^h zN>~6H=3AI+BY6ef$cRCD!?tH$nQC*Qrhoh3lM8hL$Q-@sFzJ6xzFE~0GKU>()`X(V zugFr`P|y~i!&T`IF{;zm$M4o{huffZ+Cb12C)8T%(?1>xf={>i_x%%}{iHt~(1cjHFFnk-sCV|$5CA#n@$SpCY= zowOVE2%fc|%~6&Wf5>$_(RKqe*KG5+ZwAif39IdP#ss}vdSy+v>l-=}52wt*%14^Gbk>Pmba2kfyrW0^IHl%B-3KN?b6`UPnm@KQH^{US)c!NU76Fe=vL z$Zx%B+zd&=YSo5~KVbYQOJ*gvtT>9ape{6#jOwou-{+t&I9(%vB=^d%E~6oD(oP1> zKP2iusP*K;<^Jng+-miH_b+UPWTFVNgJyZgCi9$okXTH z3%3|+QmaSXxp)et23P^A==-cl9yrs5%H4Ow=m7b`K0 zo7&=RmqZV&-kQnRPh~HgN4`Sa;U?7W^Xpsl;s5!}ov1ibesVb-fg?&DBGF+`gO(HN zZH1&_X^G)#L#pkh{lU<3Y|-7R7th8ViyDFBr#*hA-299{* zSec!YSYX>3!*@_PWsGanTnw)Yz$0T?vreP;7)PGo6+oFG+}fEGsME!w`6DWu^mYlv zkWLU!o!J{7b#iP_RB{Rz_ua7GmJMTF)y_B`72lrEJns1nfnw_Bk$>zG5o&QYYl$f| zZsFW}z%*4-97_zxQ7o|KIts)saz1^Hs?Wt)NtfN=BvLVww=L~uVI?td2rG(b$$F&M zty}~rKh#cNKG2my%v0D2p60gE3*-peq*ZrI8)YS) zm?a&F%KQNmYX9!*{{5?Rd2>}TfQ~et;S5PivY15M#z@Y-2tj9x2dgFS0Ouq`ER1w& z3Q|p*aQ~B0GdFr$GmeoeiXG}6q%oj;2X(ibmSP3%;|LiY-?`sg9^ zyr26&Fh>oAxyt{-9OslE#;YZFL*BVGLb9yBGosA(jg!+$s^=8Y^H(c+>Km@FX)a(f ztDhf;Zyds|21>1hCps6&dZz^|#rb0sm4&i22E-x+R3+*)aLZVX5s|hs{J)~s6cn)e z!t605VuA>F^t6M zx#e@0287~lgay~wZm;q?ESa~wzOFBVq~pEw;z_(aoF2$6$3HvPk}vUunb}{K9x^)E z*%qh84gh}!2c*)aV-3>QYJ8IHre*kH1k8Jox5p5%G?C>gV2IKEn412seb6L@d<0#j z@3NRbeu~h^=%7V$>%M8eN9e@z%QK zaN_w6qkANxAZ6Kfs!vmyJYzR`8MP!?#yQCxK7C7>(%8?!nfo7y;>Ik32D z4R5ylW@39X&q?afieP%!BQ{u0g~ z8meQTLz*57M*&vTpWc)gV4U*n3~ekx&io>PMU<+_FZLc*OF3_s5m~PPywvsKK}Nf& zfpG2|SHTme?_$_&n4DwFnYC{c;zdoKg^Od)_5 
zdu)cWN!31Y)Q$Yw*WFr0-g&QwCaZaNqa}<+a&{st<;wE~zbIS`9Py4X0e7xz&K_fP zra>QtxE%d9K}LJ-`ny8RMN|Wbp0xW+&y*9jEg05P*#||4KLWCq4kHdQxcwJlz+r7$ zYAr8-GMDME&s>s)r`26XP?vsH?1Rt0L8EdY zWUjpu2^<9&g%usIm$9T0YQS^SK1I%$4S1A%NZN(VRl1Cj{!y;NlK9#Ba)Y zrv64ez-DKv>}$(_2J6SB{`AKyNN>30nxyYTC`(R0?hcj2N2Xc#BdUWuGg_=BS)Q@9 zey`>oV2f}@#Q+Yn$K0N*vNLEwjDp*#Rt2CQ-J?}xV{)?NPj5ZE4@3AiNMgL;zX(mg z;+nxCeHe?KyC#bSMhZt3y31=pqNij7mx+J0$G!zu0)GW(5smL?XG=a@r=b)|KbrXz@aE^Uni-PhhLe3ktloXHs*eUxEtJ4|^2Fx&?WGnJsg zE!}(^aw`h+FcJdsyKKojSdRBmga;$oU-SL2xS_ScRI%fQ_iq||<*9(YEWQ%c`O&%z z@gN^ag0~O<6q5aUpuJZX#vbxj+yqr)0rS=7L{^MdlGOS5JHm&C1&7-+NjqKf3nE$Y zFQxbIlWo~!JiI)}CQ3LBoWj^~G;m5vsH?1XrqLvhv=REKw}#Y>x&J-;%rP-4xT^)40BNOlJlAekKcO<;c%lgf(~39@|w zY)Ildq^;tr*BywVm5T2?Jc>u9BsNY!B{{_7{|fZ z0iV?{ZF4L^%@S>}NQAixOp}ri~ON3M=@AcuGZ$M}1bEL-34J#An)X_AHll zKVC_Fyi^tw3!Vw*l}8l)6P)=ch1N!pZFNB~Fd|r!j@7Y~L+2;tB_~D-Yx}oEulJU6&8N9#KeT#0qVVb)*Uw z1ix-DKaXxpts0pGMqXr_=?YuZB@ftZwlsyTLp&F+KCR~|W;n7dQ_t&lvd^Xyzk&aL zGL?}Fl6DF8Dy->yzN8ZFt(wmv)-;=|6to{R<1XP*}kG3k#9x zKL3(mGwzY=zs8)OsXd4E?G^EtaL;#h1{ml~F~goA=nmTcp@zePt_MVF$Vk1i#`;wU z_SPY)?>J-YQVQhI7PPQ=hZ%baGdOG88r{|9BZuJe(E=u;C$9=kr%pD;#Jll^su|9R z&qek#bL}Krq+Nv4u}+YD z8^V^k(@SJX*gO3w#Xr2ol>q$Y%Z_?1!`M{IC$#RK6gyL0xyLo+!M?;zEtfLi3Q>sN z(gp0W>`tZX^>N7#fxVB2{fbKP8M3iKCmq|wp9N%&XbstB>N?AO44G}a8i^$I z=`Zz!V#KiXIw%~A;S19#lz{jXp8UhA=>@|4)qCp$c+kS;H$mtq%;*Z<5<_F(T5(`&O*YXhrwfq(Sw~!vG&wZBcKMWgAoM zGojPCRIhjZECbN6oUW~Y5Lh^{z#$RAThapV_OMN7jpthH=PyF*_Hvj=W4p3s>|}oY z#&EfQCPk&x5U^Rwr~QifJLwOMv(nb=;AaWrvp7)F0fO`{XdO5)k1M$I;e&X^T%+hpMR()hjj6-$QW6=ux9PoKz7n$^`T@SozE3?Nz&K# zyU&KU&0j{a;8LPD-R|#Yd7)tEahKFPLkVLg;d8IwT@xj!iUv!*grDjpH;&jFC78fh zK(kl5#o+>Pw5bbGkZ5AmjM9hw?u>pEQ>K4a3*Y6T(V-HHKw9^EhFHP!{cH#Gh?u7x z2gR7fCY@3L_$+Cxz{mSZ&W9np40LY*=g5Xg6ZHH&PF-0%E5-lq5ebXEN4J-7&9vd@ z4V56N3_DO3{rA1G^&Hce+{H;sfH+&4F@w|7Y8n(T)yO!XlP5LgOvl@)AAR#2BH~HY z{G7Sy>IazSUK?E$ct+iD|7Z^ZBJe88kiM3=#QlJ7^zmKN_gL9?Nr=5Ey}p!j0LD|g^aMCwNzE{-_ zQmOQVk~7KF&aZ^VqnBm5EY+YNqy__t;SCFCoGX3*eX~=3!rtlb4tn#mi~WJx>MG8J 
z%%`}eINsG%$$ftYK=NN2K3h)|JpKyy+$7)~g&Qu_@}dfS#I zJ^2SY30#VBX1H361R1A{WA=ok7{-uVUZvXN{UfSXMdm_6==S`2e8ZB21(0v@!La-8 z6x$z(4KqXw{k-jw553&F)LK5)RH%!T;c4nDRDEb>jkNx23K#Lt@m>5Qvvc0eJdDz( zRExzi*og|RbXn^w1PlsusR86v3DFTN!o3sY=~2G#PIHv@Nl9uacde3gV$RLyewH|#mj2dO&kszqkzjca zQ;f?#saU@RU-zptW8$ zeS;|&AdEsqjdA!8NN!YSPLC{wEU%*g9;or6!&go}E9zZ&n<5d!nWa=BA$`uTHvRN- zc~0K8{mofC;QqQSjalDBRqDt@P4Lu$gvS?xZB|U|#v8HsruB&LS+iuBg*H;QIfKJ{ zGIylPZ>cCV$y-ppDfp{pE<)4XMPc1rc>@G0AX?g?ixtkEs3r~%?iyt9Xrz0n{6E;V zcb51MS~ap|_6Fk$FG9>_X60)Ybh{WHwgm%ME#nqHxDsp_o@6RjG&j*c0dv!%9m z^D(^}jH#hUw(y8yPva3z_cL=L5QbqV+F$e7xCARBy8`%;0eVNCHWj;sqm-d`NESQC zB)ZLeDv}f`cggb72M!&)^Wv-HOL4niT9x^w=?p8{R=;jxjUQutoFgQBi29(tTnX6< zcqXN+Okg$&`lCcTOw@FI%I@C|P7Tu+Z}?7yP1mHXZUnSZf6#JYSkYNz*tJmXa@#%6 zt)E22#1GmAD@Md$HSqsBz4?apFxnbDrC~IR0{$R?J(!dxH5f>vz^$Fdv4;J@Q zt}hj5rC;TOc*k`s?S3rgOZMd3JvAh~%@Cg&AAZ4(Lka&WtC%1Sk~QaCQ#FaV?pa4D zUJ_zd{UMTocm&e8+5;sx=9{1a=u)C$OW9z0LrKtz$Z5TEKc9}v2!rH$w-3(>bNz&a zW9Ub>FcW9fJ-1|UxSDzaqxAj#%n+`%gl#1w75&eXZI+0C@bOC%M`iUlK zf`R$baDtkC=&Ns=&J0QrJWhe2Z!`?9_zI=CD-|s&a{Oc$=#Lho#>g*TC(AF z@E?`>xD4rY2`IJ=LHzO5HQa^xUay0|I|hzO2ssVIw)<-A!=}|7-ashY7u4*Log}z1Fh$?Qipz%Gu6!k6!mqV za|OD%al>U|!K$dNkl8G6BvcI4nDwRa{k9_e?Ioq}uhOIsnkXa~zFzrsexq$@k!H@t zc>#!3SFgGm+EQL9;)X=Um?4~%_`Et4{kEJIh}_O$e6tIhh$|WqlKZSweIi@%<)(H5 zI8{{1PV35ftIo5j-4m+xow@G>OXPksS1o6HsHpdiEB{R9u>wz`6DNjZKBLf_dRJMxusC1Xh?jP)!OpMFYY98Z&6(&D&JvM*(+3TqnOO--il@JENC zVV8&`gxb!+#LGF#(}wv-ZzOP=^#9g!v~pr`vC=t6RjVO8pHgSu%G7c-+zFoGt~TeF z(%FsvEjOgozZV#^8F?akHS(0?k;2q@#&3M)E?1=C#fTtR6Jixr<*$~W%oAytQSHoX zN3@2DAVd#zpJ2?B?8PC4vx6nHm0utIPwJLDG)OZ2n;|e|2nu8soxF|ytr%6gbqUM$ z2aX!{Ig=alJ#r^;!v|r%5v(Eav*fV1FT!h_GP-ETYf`}l?k1EBPJCfUEA}cgYG{Jn z#m%+y+SdF`4IdBwIO|PnpZ2-DFgs_zRihNe^^Dpxq`}yzT?k>Kp+#FKrE`g3S+R9X|D?*|pJG=spaa!%~`xxnmZxe&)(p_q@^Q6g7)I16r{wjNFyNBNlo zvHNpO1Vl+TeAV7YH-`iTqxyBGAlM)<1~8(Y2Ba7(W(rG@k+TzMeL6Pm-;I@&b0499 z!?K>H@CMyujEK~JR?%FKY1*Kk(Eor#*f#=WrzQhfD}O8V>9bCb5}2q7{f8|dD7rkW z6QYgY&^oX)trO36;JSk7>VsslcoF@ErXzk*U}(8!3v@?uG6$twfSF?1HT 
zurxLGYfWv!XLKwelVu_ILKv0nkmdXPfj!5WXS`U&spy<&qYy0A1B{_1WBNx#pngXN zWi(KGdkTG~HFhKzY0I^53(gCtLX~YQE)U{)N8)U$T*&?7T0p~QLrV2$bnp0}Gs2-J z|Mybn8pnQQF^bW|wv&_NlucP{52XOeA`tSHXjB7&sY^jlQcO<~cOS6OP>yjM;aOaF zk>Br#C_xpez_6)U8b*RMqOTIGe0me4Egz%LF(0v)A{sO3^%MbK*S!f9RN3@vKvKU*jNJJ8J zLl(TvBbGGSlW@2!5AQi|dN0t|@Oz?h`bq%d1)#DjO)M-ms-vC}KFH4pT|V+^dkDBw z=5l%gU_JQUAe@&M3s!5&N}vJch_cZ{ux$(KlnA+xQ14a5PxC_IM&HfQes=N9Ps7}i z>h*Lj3+>4uOF36`vouuT`?N$d;~0FP)GgY%NRYy2l}rk1A*xX5YA9_(v+!1%?)_lq znW+Y3;4u@@D$bhfBkDqy2tdw$mDWD9i#ckg5ZJuLej{O#(wu8cGpYk%L=HLH&pXBj zzo$KdecDvBCdZpm1TV3Iw8m|W2nCuFs%2?I4B>%lNF z9t*{xQMCV+fAfnk)BSpjD$-Yq&+zdPz4M2$=iHUi-Nc>Ez6jdD71tRy5kSKT3-9%% zlEtsrnVlHxB5(5p%YNvR@>0fYi7Uq;M9G6j>Ti!(aJwX`;NKz#r1Yws{C1k^+CdM) z71l+MRW?2E_`kC$7PHHipS3)7r-msoDzCnG^cohkQa z%6ky-%GZ-Y3cU@zwo;0&UwUQ}l5UjPQx{_$dwdcfRM+K794&|c}7f{t^n0DlSZ{$Fu*&rOZ z%>#64K5igOY{H%P85~1N%@(^kq?5iZz>!itlMN6vAj2V~O%Xs$5%1QnvSRCVga{s$af8;z7!lhG zcjJV-io_<8-hBufuN!pyo(5Fm6`dCHJolp9eLWQ9N=Rwuj;m($W}))z91wUJ8LVhz;c z^`AU2RnDb;IsMi*P<#LvZcRj1tARy5Hw1>wWTS20+xXJcAgu7=hU0EkER0%zMa=I# zi>T)#8*DyE5T(_|s^HS!V#zWU`IzE-Jm#NhmCr6l6eQyYMDj+s^LhLz_ zP9(x_e5w~Odps(-6iVT==(!%Dv3AFXY0mzt>XvzH?(1(k9R3QInugE6c^GeG@BSKfco;U5THTQwP7*;H~N$QhQ&1C>KNi z7=fmi^HA=n*Ct&OHcrAL%|c7sm8YqUfi%LObs|*ssXZ)&sGc6!AbPyDdL0Tr9C+_u ztYaGVps1v)%jV1WsxU!=<&1{3#E01AQ80O#Y$?$r~eCumRbV z@;@9|nX_hM4%sEjAB$cl*VV1As4d}}hbywLulCthV9Y8?A4k^*>c|qY;0cgV50nR= z1dYaHKJLMmF6$KCW_X{9-=)0jQ|V)XT5($M?@L%-M81`tT1hln2xU$h9y{5sh(kdl zPB)?@TfcvDf3^9|nLVC`sX_G$DR$>w6|;UWYB0OdFKiAW5CB*#sfrtNo~8#Q>V%3Z zzK9w6M2KFAshL_~Yntg-fX)}m6h*?ZthlT^Je?cmyT>@rfHDtC1C2VpK5Hs4)+t4H zHj6m}cav`(VxL24AM%Qqbl=|4&}?p{zA@ip3Mm0~)A}pDbs5 zeKzzml^+iivcL4HQm4seOI67w4BV+U4j#?M<9>^8e4116#x=sRRFVwR?&X)rJ}qWQ z{@Ftp6%^;Jc&9d!nx%7~&jE;r4WROo&c(aNjx%av*AyzQkL;|)+tXXm5Xq*cg}p4g zmMS9h7CXMnP~foo$+a9tqeh>$l1yl$G)1>NtiP*UII&)3+{o5L z816S1kv0*>r#hFdD0`lrc4l-k9g?}ODVo`iGf>akOq|C`51la`=3JF~yI z&O>Op`JMgOo&%4W=91B5926Bax~tqz?^adA{L_OQKDDT?rWiQ3dqV(Rh&wK^seQti zDQ%BlJTr}jM~!EgB1m26-%AvD 
z2l|YVwUIVyx0Zy|>X6#0uXBx5^(1h(zi#I$@wF?f&z*e9EH&M-8ktwpz|BYTg&`#S zrp15)PZd4!?Ol*-Oij?1AE6BH7CM6xM~+hcqB7W{xQDn z|256o|61$+GdsBOziC!8e*f;|pcfQ&%3dJ5H->KM@WIT8^TyGds<{PfBiXP>7{%K= zOnHVtl;s@rX?O*=&>5tM&eP@F`Ez4qE#aVEZWVhYLAeC4G-ttsbA zq|-4tXi@R0q{ha2RHdD=UR#G0!6OL?9gORab9e;EcJR(Y6blOp9ONAgdvMm!3ziv8 z&BsqQaiR|xLB^Y>)%tdXUli%eYE}HGX)_Kiho^HKJGp^7-E@D>3(Iph%oijYd zJHD9^zb{3}6G0f#j4U&M+RLoCV6m707cQ5Urotz8P{mH>zw9k}U0`x+9Jg6$pu+G?0YCzvx9YIHub!Nv`|x|6x65SydQ_?>n`h`vhYv7_v);=BGtN`D44NXU}pAj)Q6d{JyJP&dS zifQ(hHMS+KL9KHt)zRhUM5Ws+UZ_t0#)!U?=8=TVQF=F{Rq;dOi2qz0-T*s6h|M9S zn~7aaOs+T73JusLy7E&Cxw5af(d&fsx08S|pkv@5MJnF8! zyt!T>i*b)Dfd~Otr9!Y8x}Xm;L%(q*s^8C?E^%*6*lzm@rlIAZ=G${|EqXSgF>lpF zoHg`~PRQTLlN*+5;7vWgMYwlqTi9v%To!i0P%BE|es5TjrtHR(|1LgMF?Xw?X2q?XXX7zm@w-^aaaI-bxuoWlK2bBfKW+lXwmF z#HHyV?`+FTNkP+fyP8I}D-fN6R#OiNzYTNB38S#~zSq02)nTk0&zx|+5nlpxbmJqm zqHWcMxb-o$2E)Sb-pvqDQPxw>n2J91$3IQ2P3sb#Ln>#r^(3uvBwYoCqNhLz?I^PJkK zV$YC^)O1rk1oNn8g2LEPH4|}Zmc702S$T6Sz=ie+X^+oUsH|&@-|O2fkQVVkE5H;5 zXdb5VE+}~E&;>-u+!?ht5ZM^Q{p02e*jd#){Ls5Orm&S1XRW7wd?YvFXbDtR0N~Ux zDlIHWo)DGxFv9)vKl7=#%mA0-p(M`SA)FHZM7;Xh&oZh!6gd?l>8EL2?r`u@;nP5W( z(m{{ZV=sVFgvY~_B_Rt>g3#o2)RlLkrmtB?Tb!_XnP+o97K#~etnA#L;b?*ypbw9A z&1BDsRptvugg#@@we5yJyNucr9?H*kRWxilQ*}F-J;^+~jKC8&PF%Ac34eY|`$+BJ zGm8Ism=5jX0nrs-Fv9mXg1PI4^=_75!}d^SPORc?+bBnvGW6NKZCB@H9om@KpT_i0 zqs3nUg8CO_ZuVEuPAi=H_b`GCmosUgleGH_LrdtZ0LFskLH3!TIK07^NRRCedt$;v zl|nHiKAwgd)}T#S(VcSt_%w6vc7NGh-5)OHaiJ#-`ey;V7iUe)|Zs`CLBu;}6C zO@T|ew5;&p{q>G}UeqENJ&9dCI~5`CzV>)F;*=~VeI4Fs&INfpJGiaNWB)_<;5v*{ zJA{a4qCJ;-%GOD)0JJB+x|%A5v!`6O_gk+H@MkGowA1>sZar0EbQzWnsuM&C;MH@f zmcYDq2@5B$CzY+EUt4p8VI^D4R$|P0bC@? 
zT-tq;Y)+f3Xf%TCH}(88*Pu|dDX=CQy?DdGz=|i?hElq?-Dv5uOu4sCvtIIfoMiD1 zp0i!CoouklUIS63cxbjU;)@(bC!TgS8U)UQ2Ukd{ zB6cq5-dvo>dMiE+8ywHiH~3Y-KY?(te;xqWW!hgWvJ8AFfkAZ4r!3e~IF!92%Exij z!#U#@eVtGxaRjrJ7J-oKP>WBwaLvc*i;gii8=1V0NAD@)`Rh!e!lfm2Yj0Hb1@JY6 z>~8G^knW`2K{))StupxNf;0EO#Ytsm(3fid17Gq4DA<6#?u6{3%({Z^^WUPK7l2he z3>bHM?*)JcVb&V3RkyZW@;&wkPqaLc@bJmkW=fy5tt%e*T+FaOoP<*7lJG2QkEugR z8su1IXvSK!&#JQ~+<2E)LkpZ6=(G?By4%Npx)G1kB>^p_j;TBRZxs|SjkwI4e*`wzL;5BYc(!cIKXVv%u_$E~d{b)7Xem61&B$LoCguD6t z?cxys=}J7xiR$ZvX4&C0?z?>G&ICi~PQbk?bYuHfYv2+zddr<^+bz+*S;i9`ioicM zn`kP(RcTCoyb67O_}THNNTR9oG*a0%2Kwni`^0bPNKLygyL_aaMtH1&R>~G#S|fGr z3OQl*#s4Ilfh=5TcxbipSy_sFX5=KBJOQ zA^FrlT}L+b`B=+Fgx;w`vl5XL|2zZLv0;P?>G%!r>hWjM{vKET(?S>4zj_8UV3l8<}_nP;zg&mhu$qwr=7hgRacK zIv;F;-spKNzF0sO@7ng?`e4c<5VQlnR$Ay#>!Iz-LWlaNeVyo|KLc{`02}DGJu=w* z+s@=xUEvn z3p!hDdN%gUhs>sm)o-xCse*rs9D$#%>tiG$^_%7C3@z7#)=O3w&M07kGd^#)$_277 zHN}(|4T_^V6oKfPXW)Bmu=o+Ys)*j^7fyOx6(b(>}LN*8KTeL9KE3L3pXV58IC*(kqRO z%#!!F8jcT;tn`6X$H{5yco!|Oqz+ufP}dQ~;Zz7r#i-^K_>iOa_xBgMKHo1@r-ogw z&gXb#bc`?bgx{hcbua(Pe0lS4I8t8IJ|lx&SZtoawv-PQ<9-w2_fm$^rS00NJQ38H zbI95Kfg2|;X>*>Wv!K;SqU8EJ1P`mn-H>eUNodl6!<_;qX+U zt-zZCet~x@Lp0``a5wPT7I$>?Str2I2iulnl=yz(4{IAG`sXpY5%Isd+YS%ZZf9=n z-o|zRML77B{x8BQv}OMHv+Lg9kT`uuc!&V}nj{F?DcMl&3=q5*Y-lf#GWoiL`PtmH zj?F4LpCr*ieE014>c04TAc`}{S9SxaWf)9*SF%3)-~Ty(tOv5H{Lw@TW5x?Ge1MN* z@YB$WU5M*4&{K$%DQTV(*UWDXU~wq)P==m%i6)YVlNux@1+kcrqZ@35ybiq7REc2)*nhb> zc)O~2q_}-da&NK#A3DZj;8k+ zmM`)pS`PxCyyVYv`4QnJU(}CEBx@&m&1@aS4&F+>HkrdV&~WxJ(}}?+KHBan3j+fwORbuZBAP zANpvW)ryDJiuA;mkPt)FSScCVVfuIXnIc!raHz94WU?EuKA8)55GH{}gB9ThGRsrt z|EEu^v}#-JixA`Mk6VEnaQi=iFYHBv`0u{N9lW1>OKSvAxGubQ?BIjQr+TP9ls>Rv zYY{=wql0vS-zPqVTiur)CM}A@ae`3j#ss`IwM@9QuZ_DRZzLLd`pAIw z=oa9uWEo^1g`!t`AwHlP>P#Lj;P;^!{8^5L`!U8hPDxc))gISR8Qmi#Y@1cgZi4cn z^Uq5_Jy{fYQ(N{;CQ8@QI)%^G*h-oDP1Qx48-e~yqz4S{#cJ5d z>BCLxv9?9*-YHk!bj&zygOQNjwB^WG_kkXDZw2Ya9pX6Om?Mhf_^bzc(}Fj6yu&jH zy!#*Et5ym0r%zqd~N)G>>KN$+H!0QV#Z1zi>CN1MP5{@P{Hfkw_A7 
zxp`no1SJqNkDf&*f$1hGx+m}A0r!RNuK4bg=wP!9l*B4tJ9dRU$!PqqQyI5nEXDM7 zV=hbpOGn&ZsfxnhT`Mp>)Dt5{aw0`Y;?;g)G2g+kB#1K!OW~h(qe%(t5XAAPoA7w3 zFKwHhg7*Vgk9+Vp=UYDW3XjDYIYvcR+iH$89xv*g{oYj9-N7tlV4Z}bdGuK27oC2mRJsun;I znit`=VE>NOu#8gW+P=zIA$OzI;n0%Kr=m#wihgtEl0pgJZD~|m;?=95T82(=qp z;k{T@>38p#e3}L8&Q{FBuVacBC69})C)&IrldVBz)5N~4(0(-|CRAL8OX`8@Pv&!&={a4@?Qw$&T?^cvHk?aYDob|bno z%hhtx)I1UIiXgP6FdJXBJ$G>B4;2LMLkl+`wkb@ov4i%_FY!PNY+qGwjO)jhAn{?I zhad8p_1BGRg;_T}I4h2MSRR29F+J%q8zbssjUCcT{dQxy9$VOAO-#*7O8y$JP`No0 zM}-|aD0I6Qd5hde>Cg_;9oQO{D#QQyUSmIe3-W1>y~?eDU>7^+kyS^ys$n zt&YkncD>9T%PwU)UD2i*Kr!K?GK%2z8MvaPlTc5_ZIba_qbOJ9BkKV$ZYVBAB3y5A zf4zk(SGd$WTSFMY!cjsCfsObufI}10_Mdg=fyA^iY@ufk*<=?ZC=w~cB>cpfLssjP z$s-FdRsUHEW}~_xPq{VTi-{x|40;ufirASKi+dq_`#k{XAPBy_;WU+{Fd@$F`q4(ol5@!-g?t+D+m#@#sO8^b|MZc>ij!G3 zfu?Ac6abOyju~V_yqwq3wUsE!ppC zOV7@n@IBHp2QG2LFc%O5PYVrP=$m$g)66!cQ0v$gEaSRiUg#Dw0+5mw$no2-Z$#W# zq^R4e+WCi3Box2G3 zet0027Eq>c|1XPR4)E~Hp9Pyp^WvehX zm+XYca`VOf&U&~9*3J?nz!n(r!9sXEI)+aG`>+p`YUtg~2=(sD2_FFnZ3b2+W9I0& zv8E?S>9PkajyTY2E!J)5Mz9xHx7rCyNUqZqQI#w?T5s}UzlsZyesK(XU980>++0v#0vHU;(`zP`L#b%s^b#YqtGS!lG(#`=? 
z1XNmRDfiG7OEs{-Gfr!3;s;10g1Uq*_jwM{IGgtC2#_&U%wAa$-<32(l3aNFz)1e% zO)Hj|eW+@1I0bXmoQ|E8jnLSPJrrVmOq;;>6U#Wi7~TW%CRV95A>2nLt~#eSpyjI0 ztMwb7<;^icZO{r}!dSSl9N}b{WeSYq76!->*c#^Lo-HH+_QbD0j3sxu>@F ziBKqNZu)zu%Sw0qlhZF97Fbq{82GIKoZsjE6-foPjdZv;*~Qw=&+(C|5^@u3U1&J$p=|Tmf zWMLnrekUwhkm{{d=rdA!1$;2`A4v|Y*nrcYB+^OX6+^h<$gz zJ|a;FsvD?je9HimP{zgsDsIQu`51@^uu;1CQTk@KS$%}s#%=u5ipMq-o8nHS=xmxC z_xQ*5{73$%8TAUjMEb>yg>Sz?XYf)2JW^k(Q>2|B4_7%(NPie@iD#ED7{#k`>ovtz3MNYjeWF0& zsw=%Au^I0dWGFumn;T7Y=|azqs4@$$vF^$^tRil{Pqh{~i9)c`DM*jn9d$FoEVZ<5 z4Npuhnzk>N$;c28z?bYqiix04_|m!C^vC|p+U=lOK0xbJhzcrCJ0~ZK^-0&1uu=IP z$AzzMSs|#~Q?VuO)5wCs+`kys)x+pLuCL*tqcP?l!*Es}9JNuJ`X)ZQT;!-NX1jZc z9a)U}<`2)Kiy&XV7hJQ;z{_@UtMc5m;2DG0FAOd1kGt1*yqh-(w}1y5%e$@JVR=>O>P4dC)K;#djt(7!5&;> z>3y+|Aqi1t{oOL#$39%`|&~*-cVNa?-+nUI#o9!tB zW<`!L##rE=K)>x)Ft{AdRyT=prh#x@=frRDJGh_yz0dyu83|Cz2u>c0&J;NVMFhpu z1c$+s-P0kcF@*Q z6T8#_Ogg&I9S0_H)$kCsXTkfUE_A*Fpd_JzOGvk|hsccxtMK{veW5KED)U&(WN(mnH30bVy0JSo}FgD+r=_09|t)oZ2UfSZLfcm1cc z^|GDm=|~c$tq62(EhAQXACLlsZJ_kt`D>VmxQ>hbLIU^85Kk-L+s*V zPfng|nC;we7uK7KP(Ui&3rS|PbZNvFUpU}G3KlBVhbv+5LF8QeOx*yJ)x)ONbEYW% z(L==X&I~eJ@*4iFTZ^WZjepQGQtuHU+5NN!4p@7AiWtFFUDUH$yLt0#d>)cR2iKP0 z9G7qI5#TFylrImZuHV%pn}%3Ks-$2_NY$n{;c4B^pwh=wpZuf}R~p9q9-ETmA~KWL zB_Q1uZLhk%bcWip%0?pfU7aIa_7r=EVo(}s% z@NQ)2v6MHZkqTj&!vU{oDJ}bA+cn) zn9aSrSi;~^{tM!rs?ja{Z@68V(Z8aFqQBX`NIdP3B>!5iJHy-FR#V0} z!wzn1`FhU#^ZhKqdcFR_IN5t6JJJvwV{hyjN^ zn(Y=8n3EN~QR6U>)VSNvKy`WehaJ52^dXI7oF}=zPYsn_JonQlnERdA;BWhD`KEAX zU$)%L>V8~uoFy&A9Z;@amN`QVokUGDps94;ND;$_D7Ha<(5{Xj?r>w(e$lUY6nQ(^kzVgiza8QoI*7b@V{9BR*E}xQ+M2MXj=8x! zut3}6<>%4r{?SXqC^9#@+Ma0y+Ll~f8aW^PgX@UmXwF|>d%u(uf4v7QFTzf4MnZLg zgPl!}G#@=RF82+8?o0h+F^jt{8pTMtRx_FIjI!>+k8z}~F$!tdLag$<`b|X3_UIN? 
z;f4hvJZ*g+%lsVdzq8@nd$a+5ek^R?FGf}|O?W$78bjA9ds|r&)r~eFPqeILi7~uZ z3xKYL3U|wc5Yfc1%lRb$1|Vh9phJ++*XkiLHLkA;mA?vFo512AlEy9^Jw~Q?XK@LV zwXdZCiOJd5rqlbP+#rz}vZPVbrNNJXBP`Pf9pe2#C=-lGjM#nRV5 z*UclRhV$-U7zzj@1gZbF%K{*^VKNbKP4 zzu$;*3sa`Vi&WyQ<@7Dj#{@Wnn@U>c|2t~qm%BEZAfFJ{fiSwST;ZcXw$PZJv3zxQq^9=;9D9o6dFE=sdhHUHgwqbXc2EX360QJvstZMLW_Km=U|ShfQ+?Eb9Omif*a+2elyuYL!n*sG z^P&$0F@?eGm&rxC9KM4XG5Wz9O0Vt-x8)1kMmqH%%&-=8bVK%Tro3*5 zFans0_k{J5Ejki7eaJlW+tld6eg2yW{dgcfdhkox_Re>FrGG%_uGghg$uspDliuO) zmqcV7fPSLd?&y3Vf77EJW~d@j(|>gwyZ>j2QjipI`aaR8Wb78MUIv^rmedRb`fDSe z7vlCu#z48gu96?uDW*0H(?_8>TWmz7NE6h`xY>%RaL$jN_tGcbP&2vV`gqMXiZn%dVa*^+1=2*c(8#-2<@Xu+;4H4GO*+suL%}Aq zj8D2il}u8N`u-9K2SdHl4bKf3%87j&4EUGV?GE*}(3?H5_psS+5P`rW4*SR;)Q^B) z<(DP0rt5)mUPOq5O(NwMUKePr{W$d|0)J>}t%9{}4J^ZBh82`2GP6w_IV8r+R8hHU zgg<|EgIKYtgH6{kmx8J-(;;aO>SzyJOihv^x~u3-?{DL26juk$PdbO*u^Y?g!mb+)Z;77L0s9VV%?Y)e8L1 z%mZ(zh4&6&;BKOEMLc#=+>HeG;#te1_EOAHPfedT;F`2vL9>QpIJLyTZ$n9YwNv!# zeaprv2Bv1_`A|(hnOpW=@8@^k)l+PMGK5M~esh0U&APAmTX5Q3=Fw|f`-im*w(aQ_ zJr_4Z+RShJ9R8wSzIKdGl#O$9gZeAc$U=C{{8`j7_GX%KTRgOd@e0|w%z_0fw3QDm z&k}5o*ozxa2!=#o%%@#-YNx94Q!~2fU*_zz2`ObGSQ9xD=1Y+5?)&f_=x>7A`uu!n zk+o*7hWV%WBvH*rz?~b?r73s#VZE(6`mP@o%hfd5lw4?Rw3B?N`AF?YiUdeEWXAt* zD*>nWy-*sR8WWMvMvRl__`co1t#+(540#{Fovj{iL+u?{35wUD-8Y`2rgx0%NaJyg zWEXs7(Z_Dry~A9cU$TL>f|W*w#s_+vioeO`U`CkCyk{IrKQ<}r#Mfs|4v6=VE0Oug ze#jOBX?TD&BO*7tfNaYQxv_Y>8(E8e82;o_ekp=-fC5W$5Rqu=IjFObSY)yBzCnq% z2$yU_L_{6GRkL~=;4a+gYI!inlg>^ohS}(?$-o3a& zsU$?xg`S;flT!^myx7gl@m(!@hU{t^bbgYa8XEoN{DGAv?Xs!VCgVu8beN-LZy)*K z6Su$|_SLY+&u_Giq4$U2Y>V{A!G|ZEQdr|gw-{jXfU;Hn1{kzdQD)s2ZI&)GG!Hco zew@KXu$g$6`1-s15-0WFjjm1KCnbZtGyh~g2oE?7wA<)!i`OjNow0F^G=t4I%;ezJ zRMxr)-g#(0a9b~MwTuy56(oIJ{gHnwXsq?9C52||NHs|TC}xb-Jn&{NfF_`P5R?9x z)*1sn{K1Os;1S9xV~Lo(=kka?U0`%v(Ug{+m#s9}@w^wyv2z(yg+?01UAR!P>kh^U zpteHRsmkkN9tO~TW^?PJEcWcaVh>#e=y{OxEyVW!G zlS~A$<59?19YKZYMHiv2g_k40+Xzc-ik$V=fgOa$FYsOEO6i`Ctp}lDhM+U*+PJhl z5z^p?ipq(WFOxLExZDO$7&<&w)+anwY{-_WD+_)2aD%C69m9Jzk5(*HSZOfG$MP3q$|}^B>e+&E`3O%d 
z0!2GQh@woCb=*#tSo@0aeGNWJ4bUAyI%*%??KwV1=((ajTmp-pO!Fkhd5T+|##o()-Py@Kr+n&f&IH z(^}sjXFCsvkzQwxy8U1`)rYoS>R3D6=DG0V$*Vm4j{ND&0rJO5Sn57fia`PN6jM{lV%*9z=3NO#-BU(j(Mg+NoO_H#vQMUG ztI0G_=4r27Xr0)Vd`E7MOWrOYxpK-{@u3<*A0cl~EeP^2sZ>i*ZmG!NKTz1Nw23M| zPvI{Q!dQ^u{u^E;ULKOrVV!V*QNlz-0JxU-!yA#4iQh9D+8xa~rpWPAPkKlZ{sO})D`9|;Mfpg3-SWw2mYXx2=Bvi4M zEW`idG)nxWyAe3D(G5|Qs}i_1#54cpmES~Wl;AN z`bikspQ(*4+r?f2e*F$JBFs2ak3fzo+h}ktY155+OkF+5rzySorPHBJocD@Q#dOM> zgFarDT-0jW2~gVxKCQUz*8HpY%=+3fvj})fU`1zKoZlRX9rU>o$K4o<51!5Hzd1K? z>(Lmpf_$6!tZ7n*krU6MDxnfj6*u+AT{_=jId&&lZAq|ZUv8u%dc+hoEu04ZK9Q{U zo)%eeyQ@T_jpHNR_?U~Pa$3C?`+TC8IYXUIBi-_|QZSmcUXxdBVzVJid{j`E*tAK? z@>9}=zgQfx-r)?GnuTa&SXCLRHHt#b;Eu7^b6^WJqBJwiT$T>CV%4knnY^w?T(IE2 zx7uHPQxwGkerU!=!iRysp)K*bHjsS-p}rX@j#4bKmx_=JguSp^Y*WUZT;tnhp>C@; z2-Y{5b--HxIVsx+E^=`7*XEoV39+j!b6fd_am>mr*KOn5rh*ZKgzigWz?F6Wml%F* z2PfnzDxuVnFx`oZaV@;hT7iM5`5xdknjAGf>F(wb0^mJ99w8|;RniG?GazgS@>5N* z)Mg>{p_N|QgW*Gd;Nvu?J%xUYxznIT?q_wmF<-HAKs23hk8aEkJ?RYSe#E!3EXAR7 zf0NFfBo|6t<)Poq=MK?}K7a*{!o^MgaPZMtpaay=1HQ9n@?{16xTSrSY1rQYCmqFmP_`XN=1fHab&f%1dj1o)Z%o zA!iWqxKJOuy)YKu*P;HsguBj=+AmlH{r)BMaBH8K9{(PQKOgl@- z8sxcfzI{Y5kQOVf=w5iLgt~vAu)Ge#?7_Ktyo@ME8}F>&UvSdo=5XqGy#0S zU8W}Wu4ZGFeUL{*ag>&?R|quY1`IVkGQW{abkjTXW^U(15L2vs&;6dmT~%!^-2$68 zxPC)JuRdpYTE-x-T_XdjhO$e7M0M4hUld{Y4ScRzlu#g^KxJ<*b*D;$W!S+#`dolFaDK{dZuJo2xY zI&S>XkfLqCYpw(a%ZhoS)Lvla2PsWaD;aDNfaQoHqQP)bw*wn4*9gi3l%s zvv+zRk|D*)<2r>?;%V`Kx1CRQ*pxBHH$C>GnE;vM9X2cxH?*s@|A^`)*@QQ5&zyBf zD?M?W@-l|)8yqI`Rkw+@==;I11m=0?D}zZd%(<8{k|PXKS~&2Jq2B3dvp^dX=hp)s-dTz$!Nd7;)vtsg4H zf=#i6u|~FWOFVX8^j#?>D&)gUu$-v8l^{}z9m1f4?Ew`+@}&`fm(zeXiu}M^W3IUz z1_$@DO>4Xr*(lWr^U6;_9v7DksTbK>-g*=QDf)Fb>nO&m!`ijr1f?W8iq`%Of-8gd zyb5m4gG6l&8DA)(9BMnJMy!$Cqyn?PpfuXnu?~HDzJM-*e1JEeqq087W_p*+5cE7l zjFobRwO=#<#xi)ZCMzeW-|4C=4`0xy|F41;^j`%n5ZTkakA?caVvUQb&T7Zhy?qpl`gb}|-e0kX?^Qp9qGchExOd9O5bj7y^jt^=VG{;*nHCy9EZ0bhs z)o$AHY{0IuDLh^u5|=*e9qRs`Gk9e!8F`cKbeFtU5>T(lVFZcaS5qa;RI&>A3RRbK4f{`UeDd7 
zy&jfLJyEZDVX8^74ZoS~zf6GJuPU=u;>V9XJf^=?iAT<)R}5pI`?&i|4bswHU`sML zZKa7=4Lpb)mxgc9w?mU%9_p1+!){Ac_IXPt2_uvF8NUm$;+f)wKEqltM90HdG(rV+6 zd%)!zCOte#*+d7Cs7gQk+7>Q;b8f8jri8^sFyxV`w|XB~&&0%Z+p1DVw>or=-B7#w zQh9)!3n~UnOlniWC(SscIr*$r%{mq{^S6Nnclhze_$Q@qkNRQ^P##OqV1zj=K%pA{ zYRYLptabnR0C)9Tpo40$-Ds#oQ}6~&S4^s|h#@`Brj6L;m*7x!M`qaJpC&AYBlSR_ z9EYBlyy^aJ&0*jepVMs~wmdq`dTm0j_Eb<#fzj`~&m`-{Ed3$%gsD*fPq2zbTPME; zjt~8kAj_ER@g7yTfPxmuM#hKeO|XD@>DuOObvKkjj*)2Ku;*~nD4+Bto14gKb93$y z(uHbu=1tfE%UC@T?HB>#jX)o00BflCpk;Ja9bI)XFFh`&kSUua*%5ladyGVVuri-*mMHnGeAyP3MNnIGw35=n*S-YK!pA9@q=5L%|uB z(>rfER8_;SB*tJ*o~rS67mhk;O4nI}@f|K1{NWHawzb@kjRvrB%3%Dys;!A??PD6= zZyQWba!XOQRwwkCoS4od=UC_eb>*Ds)}MjsmWu?^vU4+<;ILkZ(%2*%d)HlWRr0Be zpo_mUZf{u#i83oBaR#2Z?4vjF%TZdznQ4*67!!D5zf%zT4%1M`CpY=`!h$a5Y{+;W zUXi@f21YadW6M{#_?I_E#$;rUQqHC#=x3Fi{_Hx%3SlIa4*$~t&AL6a8 zx-!pfuV^aPQf9zL@Anna)`w@@PL2MW8jxBGa0ydu`E&|d>Kt#`&3ow5^*(}%BDb-A z8vpZtg7_+uP>SOk@les)QT@)}UP>)u{3JMBr;JY2hEkZl46=|)U_qlWw0_i~22|RP zk_*u^>Nvl`@h%mN_!l)4K6{$+MqhG#gvKvbt`O4}%k%=zuo=l!{dH0_n>jPO(;d4K zABNe3CoAvV^r>l!+I|r*cOsm|Y41KtnWkZQZt&9-b2sg$_%B@0*hsE1x*)BG;7CrB zhOF1hW~=9V0Im0B&A5i8@FIk7&s8vGJ2?vhx3$;PpV4zIA;`AA#kX5TMW$gH;;kFq z0=X7~Yx_3qF(r(62w^RJi^JuHc5{ARKNz4V?Uh>+I}&pmgiObf4-a;^>^u@f|M>HYX=2te!cc@yCYHj1t! 
zAOP@RKu@Ls@>xhxrtW-m1&~=Tr-vSX{ZC>g^IywoM@GG{D-Sa!wFSMKEcsN+2%Cdf zZ0rov?}P+e5wm~bc7Oy2oTp_e+d2#wAjcvu|w z=yd%$uY$=$IVKZX?ZhvSf{5v(?Q#aCk-6y{2}IH8F-RkXnk6H2)7|mrrjW$iis5~x z#wpIQT-BNNSIRm|8zw%bex82A7jUsMN!d!UYF~ineA>K^^88Y#@!?~xm)NYfa(eHo z2YusdPV~i&=XvrPs+5X_WO<5g`x+&2*@!$iDxU|+>Hoozy{(gsq7}hgFfR$-Q-!1qOKsVoBr|Hmf)4zrNuxsf|8rb}PItpcUKtb~3!$Agal`rvA0X_gN&zx6)n z9y$3werijx@-U%&bu{~nWGA$@S9K+Uw0z)?gYd%;;dD(HX*fK4y*yx4w z=cRDXpzgEktuGoW&?Q%S4>_DOI>cg&t>YQ)2`I534f4irzx((;)MuWae8L z{-YS8y}InYFA*GPTh_`;>?{}vC#>0_yR{d4P?e8ZL=<~aE$FWxAohx?lsyQw?wWkH^cV1 zVNrL*Lksa>tXnml)Q9hCwZBVM<*yK0+Bku@auyy`HBV{&dO&D35R*;=Q1wEXRlal-XR$!pg5=Wn~YuBEr?Lu&{>T21NvkPs|kBy+4?oh)NwI8XB3B~^b3c7?tAas@FM>fFkM$X=?RTXp1A&TXC-O$ zh)?Y(>Sd|NxI=Wbl5U^pZU?f@-9FS}!W9>G#tMF9#i()k$svCf8XNfFhvwA19;h{N z&(q4UA1BA@9D6pO4=v+Pl>Ee z1?4xqL3eXn#|5Ww>Oy^Jn^j5uH3DEM+++d;AQ#sl#xfShqT>NUagP=;UN6hX3(GDw zSlOw1VIm*l00m2YknLtEIcqwTxLKA;{r3U!uq`reCQ>=`nwBiZKlw~#S*S1RyFP>b zxB5|r?9W+!)eahKBjXiCag>>U&!bo8py}JPMeMIIl4z~HJ;k-`)!YT87igikEWj#R zOSZ&|I@uHOP&I^MoY<{`68br3+N1wis}&cSp1XvIF2D7CRD}Cbejtm$7d8NtAtWJV zT{hZ=4%s8(s7|(WK`>s?mJz`Yfj~4BzIipuKUu0X7#Te>?L3#OCG1>I>qIX|^R5g{ zLKqmWqvC+7x5Rl>1~ZUB044BBADl1?Hz!eI-ts=7SFIWSq*xZ02-mqMTUU7dQ;5p6 z&Y!SkO>;yGe=^}0v|1-TA-Pz0V#7^R1xBduOKJW=Q9Wp09xg0v_Ik~@XO}CeA+r`- z|5mW8C6@)hx#Lo2USjpny=%Ddw`er1Sg#mmACnV@i}H&nsk)LzD#qPXfS2b$O#WsO zBZtSY$F4xipM6ql-MNp>V)9JQ6y?fydubd};%Nv7Gf%FXJx5K2?@d%8G zw;H_Pi<;7BgquPt&sPiAD||rD+KoA+3~Tg&Mx+o;VPur-=KZ<%Zf$s_Wome(+Z*ZXc#g>*Cu+upQ;JT|h@HXJxogEbWhSSuj%s zIdm`p&7x17#sU5MZ;NJdm+6_-{cDIJ5rc-Z8JbP&RX3K6il!h*%Yi4 zW`x&9>K!_+lMo>vsK%;vja13AKBABn%pJDV=DB3gF9hc=^B``%rPJv(SLD`~8H#J? 
zFDcf5riZr&ZcUtKXOW=~cD8RQ+S0W?sY3wksS4oIfOzml_Pb5UP7- z$LZUph{m(9d8c`IuDH-XK$7|GY8Y+`$TtJy>x+Wes!8-g>!c*;y*c?G8I>jo0S8k;rfOiT)PVllWeNW4~dL zL|R1d0ZO;LPqsTdSb%H>rmXJ?aXv>=22gp@6vk{eD+V6!DOdoAXN?Y>j02Taz1Z7o zxvHa0wmfFASl)GWtJ3+rFZ=rAi>UK`y}nXUo8xUdNYt{a{i%oG$+1Z;Y$692X1FCtuOii)tcIj=|oUnA(t}_T|Ir z&fr4i#H|q;KMpx@fVkH2Kbn5`BA(NH$S7aV4sx(0DU}V~l7Y8LLBQc{L1J0!=z)EN zb3;+q)Q7E6QB@Hmrh^`BiFL24gAe*1o?>z`4JgLRfkzTuzFg@)s49?gJe&c3U_ z&vv4G^~3}J|JRfD_<#L+mVZ!al<+*(gBcJ{v>!6Ig4P~$URb={=ijM!`7h8-rmXWO zB`NnkZ0^r{8XElao~DW9Vns*%20E;gUoZ(v?CP}x9Ly-lXtROCZk$+z{TDvjcuky- z@$n4>Y=}i*fYXaZwh2TObx+)2<&wi=0~@b9Mr@erh2QBeAGMe3TSX#?(-4l=^CH|; zDJEEvT&ul+t-IGUMTL_5RDncM9D}zQ(DSjPVys$eRP_THLXxz8?&pech z3wq1yttLvz=g1L^bg;?i<-AGDd}Fde6fH%xZR7NY!vqBvDjweN^jVg-)$WM3N~yCR z+hxsrwR|O1m{V#ZI_51@9B9OcrOmeev8O6&Rg5;?N5TGkNqBH~YGUMzUE-TK=~ekX zp*C~@ZVyS99D6OZBA;Y|XXl8uH=s$>@7zl0>ruj&T;>~|hDWg=ER zhk>H;IB)qtlcJgI&bgd|ypept$hDGyAGtbJ0A%@n+E0@tW;)Qg%JJJOXKmkgp@-t1 zDzhCC2dVbV7XAPpNVW{K9H^lRualT$uo979LkGI_wzvQlsig>p{xXZA^9C z>j81%w<*myVhyN>a)fzxDH6|!wuv?nA9h9c&|6PEtgu<2hHKvNF}~BP!e2b%eWk9* zrq7oJc8aVXNk|!WB>()3i|(q&6iNH#;u&JkjLH*$W8cAFu4QS-D$mb>I>E|uz(vh6 z{MY|CVyb1h7B$xotiyiM-)%CqLV`kj#fRbZyBf(b$W$a(<43nK#$IYjm+XV)$Tp}t z#y>@fGF}So`n9))`2A^In>i~U96ata>4Z+ligP^*@~Wl#{$4h|&10w6VD=Cl zOchwKW6-y9$+4FDenLNvnR23KctO_iW8~-rLLIdl4vcBSi(bHw8+Ky67Z*jVALqoL z(586c^7%-k6f0QYglR*Y|FTGJ?=7gqZyu_M_g6no*E|wR+7jwl<9~rJy>$1wdBmd# z&G`YCJuqSK2IENou&)gx-nUw8!!Ih!S!&5^+*}wY0vv=0)#U(pb9A@|h=8+l)xE7e zb|(wQopCHtdq3+@QC}Pg9D{Mh2N(hJ+}vCW zZGXNj1@NRt?iP1p2~r&vRMEM8X^bU>pJWc+{ORiIZL3Mq=HYmihJKKkZSYV_WSwE` zrNLdz6)IS7s1?68PjV}acIVX@PPBxypE1UmdI=eaITyG0MbvBXwO<`pne@Z+lu9aS zdpjg*?y%20+R3N-<{{cLq+^4x!L7!4>ll(IkTiiTXThz(!S&G z!@o=W>@Iv;OYsf)-=x^$A6-kowTYK_mF$3@r1@U5TjIl$*i=li+I=nf--K#pArP0y zyBx%8mF4?M!YklF2)O_U!tVn6Ci6X!KxO4_Rq~?$5oD~l%g|}Jxg8^;F(dH8&)af*)Esoj3@=OaRhW9-Ms4hT z+Exlt$dF>BVy>;oGFFb-^1c%yBcWEx^#W=)2wjX9FJ_HgsUi{!IAbWNyBO8lCQoL6 zWU@?xGozN&K0I(izv%r)vDFnMx~2h4Y_QB;)uKuEGpYJ%lB#%I4^-avdSkXkt 
zac!(w8ek<{8uxyJMl~tSywcPtKcWZTUL>rAhSxhy7mdH-4$FW{Sw$0zfr_7-uvY^v{mRK%R_s7JcxzWZ&->LOKQs9zlmgKY{sK>#c4p zx3#$=3vJkR`B&s8!Zd)``@0gzp%=5DvxA}i3eMWdT<{<%s9?8NnYr|ON->epeNj{5vF^ z6v=Y>DM$4xlM&k1g`QRxZkW?CD0c-G6@^NCqTsJ4HRg#rq@b?OG}BRQ%wcWwsNRn0 z9+m5CKHJM-qff){{9+vBuEt+Xff!LMgvQ9`cZ=u2RarLIhB>`7{Zo$kGh2KF=r`S~ zFTNH%J5BtdXs!$8;HA=84Y|r2-jZ^wi-*|?8g|}3T8{VZm+r*n*1Io+;Pp6xE{Abm zelx~3XQu^>ek|RPBjnUajd!vT(H8Li!{Mb$dt2I>By2pTA2BHWW~^qxH(|yFef9Y7 z3C{+pFNS-6qt#f@_5eB8c5suSD0Agl`(ysd=Qi;%fw3K-6Jcf~T1bxO)tA&#Fn^cr zTFLgQe7+&~aZLz<8qA(&i^!d_~zZ>uPN}Tqd(T&HbTUsO8 znQ=1RrFF0S_e1VRPbKj@b8?DY?#D^Cx7Tva{KgRy3rE0o-LTKM#^}uT3?2L^>25eK zd;t`>Wc~}C5Ooir5ePKsCa_+M{PFdWZIxJ*0!608K`+DT|JT-6hSjkw?Jfd=1WgF; z!9771t|7R)ySuwABuIdTySqzpx8Uv)++Bh@xwH4T&$;J(_x_n@p6==CnyP-Pd%COM zmv=T!)luE1aW&jvRYK3s?#-BB&Std;&qPwics#v~(w>)l|B;aFuGWB3VkCL#k1l?V zJV`5m&HHd)9d?#vLPm z>N#`C^a!5WuXr+d+cIvo@5MQ)MMvAy+72dz3j~56Ig{?m7OMD49eMJ%y*@11(MI7H zYsVXDJsM5>C*#0U)u}2Y&4pDh)H3$}O27<+g^eYlqUS{et9@SSg)!()>2$)xIz=@{ zmaUfn!Gys603S3H|4g!VYP4ZI9dQ`Qt|tws#p9iy(8w^c-QyUp>TmP7eWE!+yMKv& zd5%2KmHElzw%lahllCzG13)s#04|va{AUY)KzJT7TYWnNT>wetG5_zvGXA{>*~c6H zzOXDF$I-4a&Prfe7Ramr2gE-GS*v6>vv(z3jauOx(Ae@7R!4 zgN$Eg-X24sbXh#%hwO0MLu=_Cd;ckv0MGZjx6Ar@27Gl#9sk|6pRKWgOBL=F_~;Ml zKWdPyuh#uW9ki0*j4#Ow23C{fG^Xm1qc)0m$$bWq!V_@;t zpu>5w5;^P);tlbg;vDhiGu87W@s@~R@z#+m=^nC($NUe-&&>Z&4D0)CX5D2;?1ZWI zajA2i(jqVNdjCFi>(fPXd2z%c5UVzt;s9G{g4j7*?>13#$%IgDa}4-r6*1J659!V) zgX{-;A~v6taA+1R#Sc4(yu!VjdYbIg9$jsV;WYSY?RVkvzm%QWvVcZxSxkI}Zn?n! 
zAFAU8o&N#V4TjtBr+Xs&TXms~zUSSHGGhN9)BHT>T+bY93x2`44u1Lr0wU~v`tPQj zVF%DJ)W+mPGvzD+_UZrV;BVWgGR&#)L;tsJ!X8D|!`-2g zLj9CC-i*i?T)ZV5Mi8r^@e=3ct>MYW)6?tIel4*AH?qIYlos(UENyCKQz+K2qL+JU z9%A0?xOQDBTyy7X2(D&rFybWe%?+(IX6*A>Bpl~{)XH4tPxml=qJwh3K=5DgF+6Ep z(5zASl2e<|RQ`9(&0jSsVfXVm(LuaNQym+c2WyviYqtZrzhr-Snz`~!tV(!P0lLqz+20mp}K2$C+t!9M90AX0?~ZA!|>$%TLtm=na82m zx_2rXw*S1x;g}y#K>S;#1=Ui>&#Gz29stO2<|kU?9kAHlw1fQERyx5WzRoR8C*fJ@ z8B6zZN308WE;dY~7OcQL$ubV=8Ysr%8BofItEG2q0`eW=m(6?->f5W;Z7O&g^xnv^ z_^A{R9!95@V&>m@|k3A`-ODv`X{V_Jdozb$A9*gf?iwY1@XK=}E8fA%L+z_f1JHr`n|6Fme^e>~qr|iT zK*suen3y~w)I~ZDjeLuUKbX%%DKh}jAX*qmaFyx8CkGy1m156aL9wC#!^oQp>?cS4 z#nY?11FSqFq%n}0&9Yt(Ug!P$lH-emr!T0jce?UZzkVZ)SD?WY5 zU)1cxee6AS=54_K?QsBV=I1^BS{EJUTW9 zGVo>TrGKN=JE|4Wi)^Rz>!qmCLy8MEOcP&0dMHRLQ&FZD z#7;FcT_x}~;hObW?v-(yG&ctF-xVb9J{+1aWWg;99Gcl+WLn_NZ3gibgxm{r^a+!n z`jv8{4tu1?QG;H$NmFBhUac0Sy#^V%6}|{B<-YWC_X@qx40qhp(RRRs^Uulpa{;C? z0BHdrE!Ew(gdV$;kE=~H({FUWz{pPZKVa~mBA!L`XEF@C+X6DllLB%uQo}H>{w-!h zaegm|QWFqU3=Zhig&OI5?BAjRU2|iAl5G_wL*YpbL>0xcM+*Hv%W8nDA%^XjMav&TGn70WI z%q0C)`ic|q+ABNta2nsNZ8QU24l-|pUsWz^_PWA$b0I+c5L5INruk+xwoYuy9N;`< ze{TT?jqg^6vn(=yxyp>&bf%hX77;J=Px*3fN0Z}SZQ<~wFEHl>`Zv094KPSeDg=#&HiyuH$IeyaeRUilAGgTSFo0P$+M zN)OTW?~WH?YW}0LjtAWQI{>bbG>-UsPWKskPO8@Bsx}CqiWLT7zHbVbm^3<-kL?9X z$|t`CCXZhRY(*28ztb5@w69!ce0Np zT;nGj6C4 zHB3_BrY~-2ECy&q@mv6xl; zh|B=Xd>_G~z;Zt`_zQ13%wZv;Pi?mjn(MuejP9|UC?#N(sP|gQLCIcyI;F0{QYS8U zx5^OXH=SZqgs57sDWo}G<6rf1wbj4LjY#2xsHi>vV{}PjF$Y! 
zD;&r&CS@G|1hDqDQ+0CTplGthq+HB-bFN^WWu)*Kiqs%DnCEPc`~y`gIzjN1+ebN2 zqcw%4AP~*Ga-B!-{uhwstTx^VHqNpo-4qrjRZ}H!y{Zs1$R9nka#iZ;PMFkK4loB1 z^%yB#^zY=ovCpJOld-IT>bPC91iu$nEFLBlW0JA@*!>RGe$?8zY25V3vQiT7{=2^v zjBHpZN#tvi?Cgx0rXUk9WFyn#l&EZx>p}J zZ{^p#LR^}nSVySOs-k31a^DPFS$W(4po56?%6yxfdfO)R!z&vL#tTEL{p^-j0{IRE%2gRXEIoB23iR zNdUvm{<@D=0t9%E(7!JN7>Y0vy_{rpYlOwi4+RESgS?ji2Ny?- z$ps<%FaI$AyXD{hA@5E-HoX}VLnBW;26twxe}BNjx=yQA0%iESH_qkpxiZ; z{wmS#Dt?P=6jqo`*mVQMT$P_@JamLJ;3=V)7G%enwifzbZ}I0_sY~JpSyqJhb4S{T z!nuj|uw>%s65j->oX2I@hE(i~eT{7OMg!*wlNJEJ^&_|L!)YzRx{U$AVOAVH9sZBw zFa#V&exQ$_b_35NJgs!?DHhq1f_2V@0dF01&xjjFMNrN!)Wv3cBf98-yO~F^Te@G8 zRkGHGcL@(_uxgq$qCi|~z@2J44ObN-E9C>GGq2#sx$*@e@(j6;DP(D$1WqM4E)(Us z%p?AvIhRodpYFk3IuZ`DCA<R%WiS@IqgaIwImreyY7dPTRR3RSbNUHAi8q~Y9L z*SGL>5wv4dA`Zi;ssUhu0i4qMlXn99u13xkkye@A#eKi0!)mWMv~{S$8Ai!y zoG6eq9TvOjE*klkRZ8Cz745jW>xwu$bYP*J3Y2D0)@pjQib4B+f3_5x4X^#mr4Uip zJH>M2%Z+Fg@v9YdTX5UxWLC8az}I2S?mfm+h4t|;xxehtMy?nEELxl2U+P|hz5xL$ zz*<79(y+?LlL&BPB!ClBN|E>j$`X88uUI&?6SvUp6|YYptRe!;phYh~gCqn>=Wt8< z%=#=NwpnBMtfhzMF;pVMm5$>|J<96zZaMG4V)-M^Qf%nfm9~PV+fF^1!}~r@)p+WK zn_QDqC)Lm1&0-ZoCt~+xP2lgY5Z6L;V3n&q&iGZpK_wrDP@w@G|3V!%*7(J78?(1pNY^~grqC6D!X2F7jol=W5D}r?D{KOOmbzN3p z(*uslJ7-A2w|z73(I7`=DdneZb|LCbppp@_FIhpM9G%KgWdpqo zol>g(jZP?9J_&iaHB=BrHmgt(RMJp0Y-`g8hc4v?ecQ>f=u6R<&X8?xL_!-LBSVME zO|u`@0ykJ?m6zh`W0kt^Tqs}bh&Pgm8dr~VbJUP&gf)o=Dvfy-tA9K5+IdywnO255& zLq>AL3KMMkX`v!xpk0q_%KhoBbr>Aslm-P`&C*T<(-tsaXt>zh)4cCf@zoX+^IrV8ptO8cjf32AR zXU+rBa%$yl!0*{$$dPiFE|k9tUY=rKfx@eIAgVDWV42!xm^_YoJOd#*aMnd5OUWPW zBOl8XZToRxuS;iOB*afoL=HZxZxKpqef|UT&FDzi&O&vuOe(M3YOA-tQ7zw*$-Ps% zqMBYUJL2M6>H63}%1Mr+KdE&y4xlbdc%3i}5`5z%^nhHKM5H_Y0hu;7Bt>$3CFfU+ zf<$ErtL@UQ4+lW=zAlM+5yfL=c|qAMt-NY9u~0HnJtF6?vrNO?>RgWK`43WvpIma! 
z8K!g&D_IS9erDZv_F0O>^tnM^z9eM+#n+M8y>@{KfZ$YqdvVq^fA@mrsYG{yn1Ffz z+Ie~NN2-N3Fg4)F!E_a~O*LK}Q*Od@t555{nv^im6<)EnG|;YuZCgLoS$h z;#Q@XxWk!$>pk4m*i`zS+6#+O>*Z@QhI=x9b(X^ve|ydGU^(g!fH)KK^%T3f|Zi=r+o~+w!zIda*tn3~4oPu53p#Y%% z!lh^wxUce;E1#pi>uF=2wF{Nm9+NJ`T@?ERk_lVBl%5vX4CK3ZTnIU}cAHxk-SFIM z$x-rQ*7KtCj1j*$N^jCsvC73?A?uEPwtPl;``0v*f2^>4(?$a6X-B(#EuQ!3NwWM8 zXc==BJG+;n*cOv<(G@euTCRy9bZqWZ!YY4F_DBTh4MyH^g%;V!aWu4e<+)%ZgrD?x zorI89Y$ntGmz2;0dQ;e4%;{T% z)3hR8+c-z<88dcqaT`139t!2&*rhkhrmEbo*)YL|>LTXHY5jKL16&-K_R@-)*ssm(gU0@m@V?RFAq25Da-Y0Gbvuku=mxP%45zh%1J={^V2E(P)3MI7u*cjoWNZ-y1 z+n(Pw-*HuY^j_})xoT?Bl7Z`d8;EWC#yEuldhgFao-dhaLBksuy5u!x2#2*!b1>?d zl8(3m67=(%s&RsfAl){;`;ODIq$3LK`>xKw6bggDTgYwIsG@W65g5@20?1IM9yoGP z;pu~?V!3|^MlpijAX!rZS?IZH_dYmz^Z2W+YFPyGPc!03mJP z2hSVZX|XcFK;TQ6FM59i=Cbg{5rjwDkvx$ZHoEovUg4GBxk_}-f@VL&t(ZnW^3%uilKSeQHhxwh%1ayDgO%UAGC}kgP>4q@~idk*+c%DG5awVeU(cf(N1V7X27oyiq5q&QMEq3i(#@kW!qZ zwX$+hl`gT8)zmkyLTC=NDR~atgXVtI(ZInl9{&WpgmH#r{zJ{J;D%24dD6Sn1@1yv zo~rE56yQDkmf3yeFjQ<;XcX{+7j4}#?#{!scFJ~Jith~raRYANkf@-odqYpO#s@zZ z^rPo_L`V|ZMVT(97{j$CgTm-q>N~hQ)9!FB;JOkf^fO34gf~yUByJ zsGPH04Gdo9m(r!nEHny$Ls0Z_MtpRg|3oFs0Qc|(VM?TG-SPEEs`aqsj^7hQ<5T~6 z;S=ACUf540o8IsoQ-y%y;Sc@mmh4StS4f`?YQsnDPQtfu%Z9yGDCbpWlXWC(aH6b< zTfVz@jyv{}uv6!K2L2;xLjIX#;1FT4PTlhVHRA}7Q9IMFpke%Czn9*5M~qwwI}g{uDIN zT|M7yM2b6L#wR||ym+uw>p+?mfy>L9#&8XqMLaTr0JCZ7TsMnMbF2K}Ug0h;Xs)np z00DNNc9W8&JHVyNh}TB(ewr2cD_4A3xLq+<73=B)(#&m9`2tF-GVdpK$t*S+?q+l} zu-4^5rZLI`cDS+3o87#aBa1|5SXf?*xSd(tX1o0{#om<3s&5*a^^{RwBm9AuQL|_G za;+)HSFRj4gryVt(iHAC_Dbn1IahDKOWM#{$C!mqqkDn!JwZ5g3e--YZh z$k6mxtDW%!D1M^0*dg3MS+bMP$PA0jVpxd#BZ^7jZP(r>+kT|(|5z+dBhou~z9+?nqH9oq1!KT+=p_p|}$-_NzP@FB6}${p$arsaw?m$($f8dlCRZYR+YDTdOs7QWnaUm-C#34 z%{54M95)V>D%AJcyiZ8{Y|ThC_iU65XZ=!P2H?D}uC8Xx#&22DY}r>;=KYlc@ah~q&Ws(h>t8ayA#I|*zTc=?A{NU;z8gc(*O0;TNLDOboR~Og zsbOf2;)+!QqbM+_e{>5d77$!!%FLIdPgdBqRKRDX5IR#Z&%%cmHV%};PRNs=w^w4{ zfJbA#ehn1Stf!;u+d^Y+m4Ha}mZK=P@C>@e$Y$pYCTn@(1y^~?x=xYk7iv1~GATI` 
zT2R&bgOR&1kWY9mn)1Hi$R`wR^-cPV6{oCM>&NL$D<&!l6{3b;I(pb#zGb>| zDU)pJI9|hATN=K|ZCA-@UGfpRb1SN3F<;DL4#GtU-D&z|==h~+VVfu*rJJi(g{@j| zXW(S9STSDIh(-wtc41MR32=+qRO(VD@$qlEGhwb7kkMX<)oEkR#V6}6y3bqU4Uoo4 zNsm*IhO?1Cj>&0aV1C(LTwEc0-T0W(XwM|ReDx#aV&*upTSKn#a%xs6a;QX8_fWB@ z$Nb&3VXT(SxMQ7guB};E1E#{83SB<$zT+FDa0PLc`>Ic@nvyr_v7fc6Xn3?f^Uo#y zK6E+NScx%uIp6@BF9gF4DTuN4{vz$c z@f)iW{v|uw(IrU&A1hdU83~9R1PUV666;6T6idF^R;s_uAEBWarR^s5#kJ^V7yh_S z9J{Q9mtLwRj4w(Pq_&J{#}a=2Y+gW?a4s8e5FLR-)I{uC@cIu3?Ji?YJzNCGeZ#R|NQsAYJ&YJxau+F?sxfhAd5!w%HAvPu zyS=dN`!3XHJVS7g<3t-w_j;V`++EjH#qT2bxzfw(h*!!^JuwxU-!--*8%2Af3gm45I73xq z{GP>UfWM*M?_R0J{2??Br~!?|!lWmD{B zY6?|Z7DTM2;2^K67k?!RYy6Q*^CJ}~Mzqa2rSH2TDwJWT)!JgAKf&XJG?+oI+!}ZD z(_ee!N8(gZ-jU@N^(5E2!9Un^EaghdnsY_cFkYm*gJ`qA45daQA&k0Xct}cR9E%42N_Ef~1u@Y8m+VHu$+f zHF8KxCI>u_+gVsj&aQr)S?opN{3@gP4(|)uqof`IY9o035L<*-$T@9kPBx8E7MaVR zx6ceWy0~Mpq9RNLl*ohnZ7w6)(z-*?%kI=Ko z;H%-81f^O}PD9*F&b>>FTrjU+Z}~X(Irh`Ho&ItnQjkR~$N|=w@JKCY=)v^i4oFhU z%A1DFQ$0IjYRmJYDNI+yEFvk3b&EZP4o;_~i-5#C5AKyExChACf1IdX$lrX&6It$w zXrm2kV8l1QO#?EY4|4HaY7%P8%D&b6ju$M0IC02+mwHe?_as;Jjgh^Fe>9(O#1CG# zkj#4|zpOMs<%`Sc&BYTjw{d$cor{^LNbYELRR83@k}A2PBGMj>^*Hbc^lP`mtoWr` zCfJv{m;xt|+Hv6Mm;zr|&R^(*{Zwlq(xE4!wV5;97)J?KA}(x-R&LW|=lU$Sd$B{r z?oVG~)d(9JHFygTg9`Qw%QL4G8QV4#Ao)N9 zKY3l^tJo{zp)@v1&$OGCY{{aaW?!ZlieYvwG>;?J2DSBe)5n|-nR98)bZ@Oz#@PeU zNbt$itcX`Q!P%?=F#=~A&l;4v>dC1rt;g6=?eX?{hz^!&(_*8uw(>W2Py4>k9Hom| z`e)d@7+B?mIU%aC+;kjRi*Llq98u$?jUAM~k7MqPat_vezO#5Q|9x_4fTQe4GCX=z zmPTvjfW!;Xd$XhB5O=EmdT2Suvciv34Y&B2AUVNf^Bk?2T3#9a9zHy68uI4pcNDEL z2nm=3bVgp?(+AqM1F>Jv8((Ybg*H1AvKRC@QNWTob+-Sl{+kcDAiPa_EEv zwc++E2X4B{*@T~v_EoY!D125{Z1~aVUSLtk{1KMSHs4XPS(>FJKUEstGIG}}A;zQX zMOjxqTC@8z^3yQQCN!y8>x5O((o&mrCGow%XV<;Ol6@5vUTEvXH-Ck97{PZRexLf( zjoDZUdD(FIGvB@NgG=~O{ z>^5XQGpmdQW?K%6OH=o~+9(-oHO}YHrPN(yb4@=GEEws~-NLk*TH2sa!;bXwXxHHq zC@beji>dWt5xpzey6Nwr>M2x6{OSM(LgQPV?p`ieb72>OY#ZRI1dI)-zIY{|$LOj5jt_VMb&&pG+GetVn!O^#lt`EK0s&d}KA1v|nu92ZZVx&PMyrs~GZ zw}uyjK-0AR&_PO4p@5Iej*SndJStYTu9C$hNnNN%2?Vdw5vq7;O)jjCjf}(T<%&NV 
z;&c+l`Ff-VkgQd~>?2S%(p6%w9l9bW3^H6fYM?B1{j{3i?q70}WtgIl^EB3fXu+wi_dPG}L>QOe0}-U)Z4r^_%~lX3cB1pf zG0;2+$#XO8O^oNjhJxK_(H{B)|9iAx#@*|B91L7 zw%?z=SYK@!>*c>|f`);*g3x$Ei&huE-xHh!RoSETIo(EM>3@gOZRWl_&?fdh&bJ1i zlE(|TTG^xbv42mt$iFn@6y#bQOb|RY_ zq5cr214Z28N@r+qzhm!{?1!#0X@ao2nG0T86bE&xc6SOBSHoeZL2#M4k%z`vBUDt8 z8sizm&)Y`xplM_zl*$gs3?x6eSV-tt{ z!(-@r3joJ2p^3+Ki8Ds~4&c+l>kt z7o}@RXIuxs7x;yVy&xT?Ra3(E>esDbI$WZ-5 z96MGGH?&JeJ~dvRI)%OE%X zWFN;;WOq5iRQSncC5-(BY$mGUuokk_;TnxmUxvShZMl_^EMlil4v+4b%Aj+{t5SsaX|=&L1VPtS)J%(#`7=f>K}CXu8`F zz>-_BW#rWCwTL-e2DiNO_RT3iIZz*MYBxgU$vBXdZwXlnx%JK3W@L#yQa-cM*tfEj zbQa!M^%o}kZRR|>DR#oTz;F_B`PH)NJk5T>b51!U(AJT_8aeG=FHU9PKrI!s4(SN` za$`WL`Ug0t6P6mAO>D(LH^B=O)Z%Ry9fr{#pa8h>Fh$`Jb7m zbvwl0xU&~{(_9`vZG4N-O)F*EXMC4l=4;F_<5OK&Y7j`9cEsEF!+o687_A|;Un5&~ z7x}*-c7}!wyDCr%21n;6ETSsGS(Hosx5w$VzaS zGt!&ZFmp1q%jn&TV^8N2`2&Y;(y}UM_-<UZfNdviDmH7nnw+&FFpOfg@(a_8n30MvtL$*w zsDEh9CTb>_E?K`5=0$So|6Sp8q#pD(1>X%_Z610y<`bvJs69jJnIO09@8G^KY~1U3 z6X&8ERY~qb*2bhYS9MM}7D)#&ZHKj!?-Q{okZ}AVvym86lcTb|wW-nBD>_QpOnRCZZa9i}Ki@j??!^BavcaEr8csjR1 z3x3FyVcpC-UQv%^kt2rfa*kmv;(*w|2`a6|exUBpEdl1oBU;=Ix*A|vVx6dfs} z1j`{y(RUsysv*0sbboq<+ag!vBCjs`D#EJ0@#iR)G$A zdE|m68Q+m04X-=Kdw$ubyInuw~7eS=IWes;&Z+iYK%6q2njc{duT7DovNd) z@M}2NB3avzs2~tLC^Gy=jV0XjF}u-LY*+LMq_R;;!QF)9C=l05=&d&u9PqxB<#JX!qHATHrQ+0zRO_~ThEjo&N0)MS0k7iXXz`bd7 zO1UJ8;t0gw7Cvu09a*YOm7kO=weKj{86gzT)yIo1gebZ)MgJ0@6$6{I{ys{oAho~l zm9&eg!3*EK!bzmwvhLAu>lI@?NgK3B_u(wlM&vH%MTysFbH1jkk0F0;n8=h-)p1p_ zyCqZ29o}Ic=cII?J`{7nexD-ZaD*Yvx)@&@d|W{ua)mad!Z3Pus|v4T&>mXoONa2* zRo8x^jb+h7FA^oPQ6!WFDSW0CMo$Xq& z|FU4%l``HTx*Pf^q8Ais0}TL8ZYSjoG$JDzFy<4Fm7cGwYhH?SM@zD~+@M7? 
zJBesP``CWG7q^_AT77J;zLv-QOH+NI+!yzkEW%>WH0!0l!|@aI?}FT4^<$aV7*PDX z>_oW)ZQlf*7o~m<9{_N*BB^hvKUKj-!AiU=+w;p8lx|-OlYs_uqjitjg|HgPPZvDv zz*+@SqFpEa!bFhMKJ5+E81obgGO}_vG=k3mzNrv#L7T0`w-%J7tMxheiyW#a{r-)% znu#9%E#W?F5D`uSCr%Ck)cyk&yf|+CSZ{sOIR?r?#clmSnfGWc_=b@9RRsVf;T~+r zsz>zAQ2=N9PP&%r4@i{<2*A@}=vDGucP`I7k=*+ON-Sm){@I5J9ucx24`CUkhG>N@ znzjBmica3)l{Caa9>J~=xpTq2qAVZHBkh?65w?(KNP|=&Hq8T0^tKw8Ir0N!XVl+H z^G7VO!pMk?`1VcK6d%6XEQgpJnfjScM#|BI3^z|~H~2ArbbF8QAA_#uqh6lGy7JZ7 zRK&pp*-?NXFK1(mDOrDMnVRRtd8H#bUo3zJ(SaqD5sT-BmtlwFa#QxU=J~ZNthh1f zBhJ=y3L_RvE8Rvi(RS-R6K%LPSz4yW;nlaaQ-2Pg%(#sk02ULA+#BEOfP7l)Lc}3_ z2$f&s>*QJTVBGmpd(?n49W}#d%&GUOwwQ^SU_W86cA{hgH%)~58_!3@Mq*y^zxOW1 zo>uO?we*Ao*;MCCS%x#d6%1? zvRkCSJTH(p(q3MP^>W|V8(|}TW{Zmt6_#`c;q<5V`*sPTZ!vmI1ux-(AjjnHI-ZK`5^?;fg?)ADI;h6RSlw>CoM z2JR3Z`*F)*bNlTiYe=J)c+HUu8fqjtl%gdej&>LWWwsh7SP$mZXmA-~w*{n}PhDij zxM9xn1I4Mb^FrwfuSJV+vycaQ_)HpNHW!3ovj@>La6DmRHKNU%nO}!yH>(gnRkuhSeeo8-sW8|7d^Hx-PS*(F3G~lb#GT~C&%rGW{qq36?liU z{ouF)6OFOALAA~6OvdzcVM0jkd%1+_Lijy-rmNWBD|^32bk9QYQ>Fi?3SpLQF7!Jz z`+?~!$(;&U&q5PyjHhLKyhV6| zn{;9P>U7I?uDvvfB>fik&Esa1JQJaO@Lx3qMSG{~Z1c=t7lznaWI7BOw@t^XLR0ay z%4sT4Kaye1gM+_E9J?EWDXjH+-PgB5EJBDN@gZgQSU8lRm%!duF5bc-BO!0`rHf;ee$qIk1tv`0+6#zw`~>owfJeV@)oa`6c^!ZPiEtMVij z$lSU27{IQActg10fb1WuH6RsfcusW5;N{`P3C3nm-u-e{3dtiLw?yO+{!s3P68diawaYF?4Zk>eG z@7wTKB)ujmo&3RG8-a`4>s1;=QX7G(Ugp+eiPF%CIW z``{Kj*EhAZ`8;WiuQk{hZsN=H(B190YNioX$;~73YmOtgpP7?`>>vi-+wnkw9r}rp z9*D0qWEk!KIXM7AB2IX$3auq;`(2(s7M#DNzM`A7{}c485cp`rzIEM^s!j-otGfWnv4qs zs<^I`6cBVTiXeaJMJ!#`1oM+f)Bo4F_bKicv9a;X$Em&OjG7Bk6 zS4Z=sE}vA!y0X0EJ0OoC04Us2Y&IOZNGZHo2tKde6GMS<=htPbIaB`wvJQJZx}S!X zeH1yuxlAS27paqHhz;dAS(8%ZuI(56!1JTP46^&xY0j$a70Yv?QyuXT)9c%-d@$O+P-=VIdfwhXNq zK1NVP8l>mq_G{#%j<2Qv0Wrz$$_@mm!xJ;;-M`)AjjZE499!M8O(m=6^18|_nr*?Y z!YL%?kn1k$Ly>W^{Ng}v2U#m_q5JZLWNHsTkfJ?2>j!{isMT)ZyyC70IT6UaF8ds< zV2>pr`cp>1o=N+&L;XCfJ6+aI0%>5w8M1;UeC9vhE3a`vTKTz3Yq&|Z5xS*|b;RKs z&>;1(3uo6?$iTLzIDVG&W)Vl3Y z+un3?ZW*dU>slSj^r>t_Anm-2jQiZ=z-Yfu0oP|gS~huJv44c5!7lqPUEh?Wxc0se 
zL#nS2?{ajFWEP|$U%w*q5BpIS@3?N#&|F`qtunLPs$+GIz7+5AR2pH&G84X4C*oyK zusC!`B^G&tf=@nr6OcF*eUaWly?s9VF`X^h*4rW5%Gl5*ST8~0o$1rfZVkhe<_9$) zYs!rFuZ-k==mu=j&s&kHnfZM3TFaKa*4i_@p9Z?b2l{#GU|O}k=2F5%zL2-mJ$C*9 z9d?c*@&#WN4lx=U>|3ZLq9B*`&I+Zsh8uF0X#xTElAh)SH0VEHH2TR)elfqsfnAJ{ z+x?U}955$zugZAqP!dxiel^M`w@Wdfi|uN5!bWU^%B1vqAgUo1sfw&YYi6}qW7|-M zd~1B$rB{9AJG%rB=CBqdYr6??+S9X(gZ7HziB4R8+mrKC$vG??D- z7NAeL7e01$Kg^n4&4fK`PcwSknayA}r$`neXq8Pxc%akbWJlZaE!ZEwCYU0oe~X&* zT%mhj`o$Q8vcA=BQ898HDk_!a)XX1J+V1cN Tuple[int, int]: + r""" + Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`. + + Args: + image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`): + The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it + should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch + tensor, it should have shape `[batch, channels, height, width]`. + height (`Optional[int]`, *optional*, defaults to `None`): + The height of the preprocessed image. If `None`, the height of the `image` input will be used. + width (`Optional[int]`, *optional*, defaults to `None`): + The width of the preprocessed image. If `None`, the width of the `image` input will be used. + + Returns: + `Tuple[int, int]`: + A tuple containing the height and width, both resized to the nearest integer multiple of + `vae_scale_factor`. 
+ """ + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + else: + height = image.shape[1] + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + else: + width = image.shape[2] + + if max_side_length is None: + max_side_length = self.max_side_length + + if max_pixels is None: + max_pixels = self.max_pixels + + ratio = 1.0 + if max_side_length is not None: + if height > width: + max_side_length_ratio = max_side_length / height + else: + max_side_length_ratio = max_side_length / width + + cur_pixels = height * width + max_pixels_ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(max_pixels_ratio, max_side_length_ratio, 1.0) # do not upscale input image + + new_height, new_width = int(height * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor, int(width * ratio) // self.config.vae_scale_factor * self.config.vae_scale_factor + return new_height, new_width + + def preprocess( + self, + image: PipelineImageInput, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: Optional[int] = None, + max_side_length: Optional[int] = None, + resize_mode: str = "default", # "default", "fill", "crop" + crops_coords: Optional[Tuple[int, int, int, int]] = None, + ) -> torch.Tensor: + """ + Preprocess the image input. + + Args: + image (`PipelineImageInput`): + The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of + supported formats. + height (`int`, *optional*): + The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default + height. + width (`int`, *optional*): + The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. 
+ resize_mode (`str`, *optional*, defaults to `default`): + The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within + the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will + resize the image to fit within the specified width and height, maintaining the aspect ratio, and then + center the image within the dimensions, filling empty with data from image. If `crop`, will resize the + image to fit within the specified width and height, maintaining the aspect ratio, and then center the + image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only + supported for PIL image input. + crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): + The crop coordinates for each image in the batch. If `None`, will not crop the image. + + Returns: + `torch.Tensor`: + The preprocessed image. + """ + supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) + + # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image + if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: + if isinstance(image, torch.Tensor): + # if image is a pytorch tensor could have 2 possible shapes: + # 1. batch x height x width: we should insert the channel dimension at position 1 + # 2. channel x height x width: we should insert batch dimension at position 0, + # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 + # for simplicity, we insert a dimension of size 1 at position 1 for both cases + image = image.unsqueeze(1) + else: + # if it is a numpy array, it could have 2 possible shapes: + # 1. batch x height x width: insert channel dimension on last position + # 2. 
height x width x channel: insert batch dimension on first position + if image.shape[-1] == 1: + image = np.expand_dims(image, axis=0) + else: + image = np.expand_dims(image, axis=-1) + + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + if not is_valid_image_imagelist(image): + raise ValueError( + f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}" + ) + if not isinstance(image, list): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + if crops_coords is not None: + image = [i.crop(crops_coords) for i in image] + if self.config.do_resize: + height, width = self.get_new_height_width(image[0], height, width, max_pixels, max_side_length) + image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image] + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] + elif self.config.do_convert_grayscale: + image = [self.convert_to_grayscale(i) for i in image] + image = self.pil_to_numpy(image) # to np + image = self.numpy_to_pt(image) # to pt + + elif isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0) + + image = self.numpy_to_pt(image) + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, 
width) + + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) + + if self.config.do_convert_grayscale and image.ndim == 3: + image = image.unsqueeze(1) + + channel = image.shape[1] + # don't need any preprocess if the image is latents + if channel == self.config.vae_latent_channels: + return image + + height, width = self.get_new_height_width(image, height, width, max_pixels, max_side_length) + if self.config.do_resize: + image = self.resize(image, height, width) + + # expected range [0,1], normalize to [-1,1] + do_normalize = self.config.do_normalize + if do_normalize and image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + if do_normalize: + image = self.normalize(image) + + if self.config.do_binarize: + image = self.binarize(image) + + return image diff --git a/modules/omnigen2/import_utils.py b/modules/omnigen2/import_utils.py new file mode 100644 index 000000000..148b88684 --- /dev/null +++ b/modules/omnigen2/import_utils.py @@ -0,0 +1,46 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.util +import sys + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + +def _is_package_available(pkg_name: str): + pkg_exists = importlib.util.find_spec(pkg_name) is not None + pkg_version = "N/A" + + if pkg_exists: + try: + pkg_version = importlib_metadata.version(pkg_name) + except (ImportError, importlib_metadata.PackageNotFoundError): + pkg_exists = False + + return pkg_exists, pkg_version + +_triton_available, _triton_version = _is_package_available("triton") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") + +def is_triton_available(): + return _triton_available + +def is_flash_attn_available(): + return _flash_attn_available diff --git a/modules/omnigen2/models/__init__.py b/modules/omnigen2/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/modules/omnigen2/models/attention_processor.py b/modules/omnigen2/models/attention_processor.py new file mode 100644 index 000000000..592d6b503 --- /dev/null +++ b/modules/omnigen2/models/attention_processor.py @@ -0,0 +1,357 @@ +""" +OmniGen2 Attention Processor Module + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import warnings +import math +from typing import Optional, Tuple, Dict, Any + +import torch +import torch.nn.functional as F +from einops import repeat + +from ..import_utils import is_flash_attn_available + +if is_flash_attn_available(): + from flash_attn import flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +else: + warnings.warn("Cannot import flash_attn, install flash_attn to use Flash2Varlen attention for better performance") + + +from diffusers.models.attention_processor import Attention +from .embeddings import apply_rotary_emb + + +class OmniGen2AttnProcessorFlash2Varlen: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. + + This processor implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not is_flash_attn_available(): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires flash_attn. " + "Please install flash_attn." + ) + + def _upad_input( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + num_heads: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: + """ + Unpad the input tensors for flash attention. 
+ + Args: + query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim) + key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim) + attention_mask: Attention mask tensor of shape (batch_size, seq_len) + query_length: Length of the query sequence + num_heads: Number of attention heads + + Returns: + Tuple containing: + - Unpadded query tensor + - Unpadded key tensor + - Unpadded value tensor + - Query indices + - Tuple of cumulative sequence lengths for query and key + - Tuple of maximum sequence lengths for query and key + """ + def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Helper function to get unpadding data from attention mask.""" + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return indices, cu_seqlens, max_seqlen_in_batch + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + # Unpad key and value layers + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + + # Handle different query length cases + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 
1, dtype=torch.int32, device=query_layer.device + ) + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # 
Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # Unpad input for flash attention + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + # Handle different number of heads + if kv_heads < attn.heads: + key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads) + + # Apply flash attention + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=0.0, + causal=False, + softmax_scale=softmax_scale, + ) + + # Pad output and apply final transformations + hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length) + hidden_states = hidden_states.flatten(-2) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states + + +class OmniGen2AttnProcessor: + """ + Processor for implementing scaled dot-product attention with flash attention and variable length sequences. 
+ + This processor is optimized for PyTorch 2.0 and implements: + - Flash attention with variable length sequences + - Rotary position embeddings (RoPE) + - Query-Key normalization + - Proportional attention scaling + + Args: + None + + Raises: + ImportError: If PyTorch version is less than 2.0 + """ + + def __init__(self) -> None: + """Initialize the attention processor.""" + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. " + "Please upgrade PyTorch to version 2.0 or later." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + """ + Process attention computation with flash attention. + + Args: + attn: Attention module + hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) + encoder_hidden_states: Encoder hidden states tensor + attention_mask: Optional attention mask tensor + image_rotary_emb: Optional rotary embeddings for image tokens + base_sequence_length: Optional base sequence length for proportional attention + + Returns: + torch.Tensor: Processed hidden states after attention computation + """ + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + # Reshape tensors for attention computation + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key normalization 
+ if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply Rotary Position Embeddings + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Calculate attention scale + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + if attention_mask is not None: + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.type_as(query) + + # Apply output projection + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states diff --git a/modules/omnigen2/models/embeddings.py b/modules/omnigen2/models/embeddings.py new file mode 100644 index 000000000..047526f69 --- /dev/null +++ b/modules/omnigen2/models/embeddings.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class TimestepEmbedding(nn.Module):
    """MLP that lifts sinusoidal timestep features into the conditioning space.

    Pipeline: (optional conditioning projection added to the input) ->
    linear_1 -> activation -> linear_2 -> (optional post-activation).

    Args:
        in_channels: width of the incoming sinusoidal features.
        time_embed_dim: hidden width of the MLP.
        act_fn: activation name resolved via diffusers' ``get_activation``.
        out_dim: output width; defaults to ``time_embed_dim`` when None.
        post_act_fn: optional activation applied after ``linear_2``.
        cond_proj_dim: when set, adds a bias-free projection of an extra
            conditioning input onto the timestep features.
        sample_proj_bias: whether the two linear layers carry a bias.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
        sample_proj_bias: bool = True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        """Normal(0, 0.02) weights; zero biases when biases exist.

        Fix: previously called ``nn.init.zeros_`` on the biases
        unconditionally, which crashed when ``sample_proj_bias=False``
        (``linear_*.bias`` is None in that case).
        """
        nn.init.normal_(self.linear_1.weight, std=0.02)
        if self.linear_1.bias is not None:
            nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        if self.linear_2.bias is not None:
            nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample, condition=None):
        """Embed ``sample``; ``condition`` requires ``cond_proj_dim`` set."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings (RoPE) to a query or key tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor, shape [B, H, S, D] for the real path; the
            complex path broadcasts freqs over dim 2 ([B, S, H, D] layout).
        freqs_cis (`Union[torch.Tensor, Tuple[torch.Tensor]]`):
            For ``use_real=True``, a (cos, sin) pair each of shape [S, D];
            otherwise a complex frequency tensor.
        use_real (`bool`): whether ``freqs_cis`` is a real (cos, sin) pair.
        use_real_unbind_dim (`int`): -1 for the interleaved pair layout
            (Flux, CogVideoX, Hunyuan-DiT), -2 for the split-half layout
            (Stable Audio, OmniGen, CogView4).

    Returns:
        torch.Tensor: the rotated tensor, same shape and dtype as ``x``.
        (Fix: previously annotated/documented as returning a tuple, but a
        single tensor is returned on every path.)

    Raises:
        ValueError: if ``use_real_unbind_dim`` is neither -1 nor -2.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Interleaved (..., D/2, 2) pairs — used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # First/second half split — used for Stable Audio, OmniGen and CogView4
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # Complex path (used for lumina): rotate via complex multiplication.
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
+ +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn + +from diffusers.models.embeddings import Timesteps +from ..embeddings import TimestepEmbedding +from ...import_utils import is_flash_attn_available, is_triton_available + +if is_triton_available(): + from ...triton_layer_norm import RMSNorm +else: + from torch.nn import RMSNorm + warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance") + +if is_flash_attn_available(): + from flash_attn.ops.activations import swiglu +else: + from .components import swiglu + warnings.warn("Cannot import flash_attn, install flash_attn to use fused SwiGLU for better performance") + +# try: +# from flash_attn.ops.activations import swiglu as fused_swiglu +# FUSEDSWIGLU_AVALIBLE = True +# except ImportError: + +# FUSEDSWIGLU_AVALIBLE = False +# warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") + +class LuminaRMSNormZero(nn.Module): + """ + Norm layer adaptive RMS normalization zero. + + Parameters: + embedding_dim (`int`): The size of each embedding vector. 
+ """ + + def __init__( + self, + embedding_dim: int, + norm_eps: float, + norm_elementwise_affine: bool, + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + + self.norm = RMSNorm(embedding_dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + return x, gate_msa, scale_mlp, gate_mlp + + +class LuminaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. 
+ elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + out_dim: Optional[int] = None, + ): + super().__init__() + + # AdaLN + self.silu = nn.SiLU() + self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) + + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + self.linear_2 = None + if out_dim is not None: + self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) + + def forward( + self, + x: torch.Tensor, + conditioning_embedding: torch.Tensor, + ) -> torch.Tensor: + # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) + emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) + scale = emb + x = self.norm(x) * (1 + scale)[:, None, :] + + if self.linear_2 is not None: + x = self.linear_2(x) + + return x + + +class LuminaFeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + hidden_size (`int`): + The dimensionality of the hidden layers in the model. This parameter determines the width of the model's + hidden representations. + intermediate_size (`int`): The intermediate dimension of the feedforward layer. + multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple + of this value. + ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden + dimension. Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + inner_dim: int, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + ): + super().__init__() + self.swiglu = swiglu + + # custom hidden_size factor multiplier + if ffn_dim_multiplier is not None: + inner_dim = int(ffn_dim_multiplier * inner_dim) + inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) + + self.linear_1 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + self.linear_2 = nn.Linear( + inner_dim, + dim, + bias=False, + ) + self.linear_3 = nn.Linear( + dim, + inner_dim, + bias=False, + ) + + def forward(self, x): + h1, h2 = self.linear_1(x), self.linear_3(x) + return self.linear_2(self.swiglu(h1, h2)) + + +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + text_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + timestep_scale: float = 1.0, + ) -> None: + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale + ) + + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) + + self.caption_embedder = nn.Sequential( + RMSNorm(text_feat_dim, eps=norm_eps), + nn.Linear(text_feat_dim, hidden_size, bias=True), + ) + + self._initialize_weights() + + def _initialize_weights(self): + nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) + nn.init.zeros_(self.caption_embedder[1].bias) + + def forward( + self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).to(dtype=dtype) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(text_hidden_states) + return time_embed, caption_embed diff --git 
def swiglu(x, y):
    # Vanilla SwiGLU gate (fallback for flash_attn's fused op): silu runs in
    # fp32 for numerical stability, result is cast back to x's dtype.
    return F.silu(x.float(), inplace=False).to(x.dtype) * y


class OmniGen2RotaryPosEmbed(nn.Module):
    # Three-axis rotary position embedding for OmniGen2: axis 0 indexes joint
    # sequence position, axes 1/2 index image rows/columns in patch units.
    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        super().__init__()
        self.theta = theta            # rotary base frequency
        self.axes_dim = axes_dim      # per-axis feature split (sums to head dim)
        self.axes_lens = axes_lens    # max precomputed positions per axis
        self.patch_size = patch_size  # latent patch size p (tokens are p x p patches)

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        # Precompute one 1D rotary table per axis; fp64 where supported,
        # fp32 on MPS (no float64 support there).
        freqs_cis = []
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
            emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
            freqs_cis.append(emb)
        return freqs_cis

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        # Gather per-token frequencies for each axis and concatenate along the
        # feature dimension. The gather runs on CPU for MPS inputs
        # (workaround), then the result is moved back to the original device.
        device = ids.device
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        result = []
        for i in range(len(self.axes_dim)):
            freqs = freqs_cis[i].to(ids.device)
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)
    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        # Build per-token 3-axis position ids for the combined
        # [caption, ref images..., noise image] sequence, then gather and split
        # the rotary frequencies per segment.
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        # Effective caption length per sample = number of unmasked tokens.
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()

        seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
        max_img_len = max(l_effective_img_len)

        # Create position IDs
        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Text tokens: all three axes carry the same sequential position.
            position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")

            # pe_shift: running axis-0 offset; pe_shift_len: running token offset.
            pe_shift = cap_seq_len
            pe_shift_len = cap_seq_len

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len
                    # Ref image tokens: axis 0 fixed to the segment offset,
                    # axes 1/2 carry row/column indices.

                    row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
                    col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
                    position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids

                    # Advance axis-0 by the larger image extent so segments
                    # don't overlap in rotary space.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            # Target (noise) image tokens, same scheme as ref images.
            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i]

            row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
            col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()

            assert pe_shift_len + l_effective_img_len[i] == seq_len
            position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len: seq_len, 1] = row_ids
            position_ids[i, pe_shift_len: seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)

        # create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
        )

        # Slice each sample's combined frequencies back into the three segments.
        for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )
class OmniGen2TransformerBlock(nn.Module):
    """
    Transformer block for the OmniGen2 model.

    Structure: attention (SDPA-based processor) + SwiGLU feed-forward, both
    wrapped in RMS norms, optionally modulated by a timestep embedding via
    AdaLN-zero-style scales/gates.

    Args:
        dim: Dimension of the input and output tensors
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads (grouped-query attention)
        multiple_of: The FFN hidden dimension is rounded up to a multiple of this
        ffn_dim_multiplier: Multiplier for the feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        modulation: Whether to use timestep modulation for conditional generation
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        num_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        modulation: bool = True,
    ) -> None:
        """Initialize the transformer block."""
        super().__init__()
        self.head_dim = dim // num_attention_heads
        self.modulation = modulation

        # Prefer the SDPA/varlen processor; fall back to the plain processor
        # when PyTorch 2.0 SDPA is unavailable.
        try:
            processor = OmniGen2AttnProcessorFlash2Varlen()
        except ImportError:
            processor = OmniGen2AttnProcessor()

        # Initialize attention layer
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            qk_norm="rms_norm",
            heads=num_attention_heads,
            kv_heads=num_kv_heads,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=processor,
        )

        # Initialize feed-forward network
        self.feed_forward = LuminaFeedForward(
            dim=dim,
            inner_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier
        )

        # Initialize normalization layers
        if modulation:
            self.norm1 = LuminaRMSNormZero(
                embedding_dim=dim,
                norm_eps=norm_eps,
                norm_elementwise_affine=True
            )
        else:
            self.norm1 = RMSNorm(dim, eps=norm_eps)

        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the transformer block.

        Xavier-uniform for all linear layers; the AdaLN modulation projection
        is zero-initialized so modulation starts as identity.
        """
        nn.init.xavier_uniform_(self.attn.to_q.weight)
        nn.init.xavier_uniform_(self.attn.to_k.weight)
        nn.init.xavier_uniform_(self.attn.to_v.weight)
        nn.init.xavier_uniform_(self.attn.to_out[0].weight)

        nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
        nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)

        if self.modulation:
            nn.init.zeros_(self.norm1.linear.weight)
            nn.init.zeros_(self.norm1.linear.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: torch.Tensor,
        temb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the transformer block.

        Args:
            hidden_states: Input hidden states tensor
            attention_mask: Attention mask tensor
            image_rotary_emb: Rotary embeddings for image tokens
            temb: Timestep embedding; required when modulation is enabled

        Returns:
            torch.Tensor: Output hidden states after attention and FFN

        Raises:
            ValueError: if ``modulation`` is enabled and ``temb`` is None
        """
        # Fix: removed a stray, unused `import time` that was left in this method.
        if self.modulation:
            if temb is None:
                raise ValueError("temb must be provided when modulation is enabled")

            norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            # Gated residual connections (tanh-bounded gates).
            hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
            hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
        else:
            norm_hidden_states = self.norm1(hidden_states)
            attn_output = self.attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=norm_hidden_states,
                attention_mask=attention_mask,
                image_rotary_emb=image_rotary_emb,
            )
            hidden_states = hidden_states + self.norm2(attn_output)
            mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
            hidden_states = hidden_states + self.ffn_norm2(mlp_output)

        return hidden_states
class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    OmniGen2 Transformer 2D Model.

    A transformer-based diffusion model for image generation with:
    - Patch-based image processing
    - Rotary position embeddings
    - Multi-head attention (grouped-query)
    - Conditional generation support (reference images + captions)

    Args:
        patch_size: Size of image patches
        in_channels: Number of input channels
        out_channels: Number of output channels (defaults to in_channels)
        hidden_size: Size of hidden layers
        num_layers: Number of transformer layers
        num_refiner_layers: Number of refiner layers
        num_attention_heads: Number of attention heads
        num_kv_heads: Number of key-value heads
        multiple_of: Multiple of which the FFN hidden dimension should be
        ffn_dim_multiplier: Multiplier for feed-forward network dimension
        norm_eps: Epsilon value for normalization layers
        axes_dim_rope: Dimensions for rotary position embeddings
        axes_lens: Lengths for rotary position embeddings
        text_feat_dim: Dimension of text features
        timestep_scale: Scale factor for timestep embeddings
    """

    _supports_gradient_checkpointing = True
    # Fix: previously listed "Omnigen2TransformerBlock" (wrong capitalization),
    # which never matches the actual class name, so accelerate's device-map
    # placement could split a transformer block across devices.
    _no_split_modules = ["OmniGen2TransformerBlock"]
    _skip_layerwise_casting_patterns = ["x_embedder", "norm"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 2304,
        num_layers: int = 26,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 24,
        num_kv_heads: int = 8,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        text_feat_dim: int = 1024,
        timestep_scale: float = 1.0
    ) -> None:
        """Initialize the OmniGen2 transformer model."""
        super().__init__()

        # Validate configuration: RoPE axis dims must partition the head dim.
        if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
            raise ValueError(
                f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
                f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
            )

        self.out_channels = out_channels or in_channels

        # 1. Rotary position embeddings and patch embedders.
        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.ref_image_patch_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=hidden_size,
        )

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=text_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale
        )

        # 2. Refiner blocks: noise latents and reference-image latents get
        # modulated refiners; the text context refiner is unmodulated.
        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(
                hidden_size,
                num_attention_heads,
                num_kv_heads,
                multiple_of,
                ffn_dim_multiplier,
                norm_eps,
                modulation=True
            )
            for _ in range(num_refiner_layers)
        ])

        self.context_refiner = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=False
                )
                for _ in range(num_refiner_layers)
            ]
        )

        # 3. Transformer blocks
        self.layers = nn.ModuleList(
            [
                OmniGen2TransformerBlock(
                    hidden_size,
                    num_attention_heads,
                    num_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    modulation=True
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            bias=True,
            out_dim=patch_size * patch_size * self.out_channels
        )

        # Add learnable embeddings to distinguish different images
        self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size))  # support max 5 ref images

        self.gradient_checkpointing = False

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """
        Initialize the weights of the model.

        Xavier-uniform for the patch embedders; the output norm's projections
        are zero-initialized; the image-index embedding is Normal(0, 0.02).
        """
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.constant_(self.x_embedder.bias, 0.0)

        nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
        nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)

        nn.init.zeros_(self.norm_out.linear_1.weight)
        nn.init.zeros_(self.norm_out.linear_1.bias)
        nn.init.zeros_(self.norm_out.linear_2.weight)
        nn.init.zeros_(self.norm_out.linear_2.bias)

        nn.init.normal_(self.image_index_embedding, std=0.02)
    def img_patch_embed_and_refine(
        self,
        hidden_states,
        ref_image_hidden_states,
        padded_img_mask,
        padded_ref_img_mask,
        noise_rotary_emb,
        ref_img_rotary_emb,
        l_effective_ref_img_len,
        l_effective_img_len,
        temb
    ):
        # Embed noise latents and reference-image latents, refine each with its
        # own refiner stack, then concatenate [ref tokens, noise tokens] per
        # sample into a zero-padded combined sequence.
        batch_size = len(hidden_states)
        max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])

        hidden_states = self.x_embedder(hidden_states)
        ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)

        # Add a learnable per-reference-image index embedding so the model can
        # tell multiple reference images apart (image_index_embedding holds 5).
        for i in range(batch_size):
            shift = 0
            for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
                ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
                shift += ref_img_len

        # Refine the noise latents.
        for layer in self.noise_refiner:
            hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)

        # Re-batch: each reference image (across all samples) becomes its own
        # row, padded to the longest single reference image.
        flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
        num_ref_images = len(flat_l_effective_ref_img_len)
        max_ref_img_len = max(flat_l_effective_ref_img_len)

        batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
        batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
        batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
        batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)

        # sequence of ref imgs to batch
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                batch_ref_img_mask[idx, :ref_img_len] = True
                batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
                batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
                batch_temb[idx] = temb[i]  # each ref image reuses its sample's timestep embedding
                shift += ref_img_len
                idx += 1

        # refine ref imgs separately
        for layer in self.ref_image_refiner:
            batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)

        # batch of ref imgs back to the per-sample concatenated sequence
        idx = 0
        for i in range(batch_size):
            shift = 0
            for ref_img_len in l_effective_ref_img_len[i]:
                ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
                shift += ref_img_len
                idx += 1

        # Per sample: [refined ref tokens | refined noise tokens], zero-padded
        # to the longest combined length in the batch.
        combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
        for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
            combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
            combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]

        return combined_img_hidden_states
-> (h w) (p1 p2 c)', p1=p, p2=p) + flat_hidden_states.append(img) + + padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) + padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + if ref_img_sizes[i] is not None: + padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i] + padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True + + padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) + padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): + padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i] + padded_img_mask[i, :l_effective_img_len[i]] = True + + return ( + padded_hidden_states, + padded_ref_img_hidden_states, + padded_img_mask, + padded_ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) + + def forward( + self, + hidden_states: Union[torch.Tensor, List[torch.Tensor]], + timestep: torch.Tensor, + text_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + text_attention_mask: torch.Tensor, + ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via 
`attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # 1. Condition, positional & patch embedding + batch_size = len(hidden_states) + is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) + + if is_hidden_states_tensor: + assert hidden_states.ndim == 4 + hidden_states = [_hidden_states for _hidden_states in hidden_states] + + device = hidden_states[0].device + + temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) + + ( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + ( + context_rotary_emb, + ref_img_rotary_emb, + noise_rotary_emb, + rotary_emb, + encoder_seq_lengths, + seq_lengths, + ) = self.rope_embedder( + freqs_cis, + text_attention_mask, + l_effective_ref_img_len, + l_effective_img_len, + ref_img_sizes, + img_sizes, + device, + ) + + # 2. Context refinement + for layer in self.context_refiner: + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) + + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, + ref_image_hidden_states, + img_mask, + ref_img_mask, + noise_rotary_emb, + ref_img_rotary_emb, + l_effective_ref_img_len, + l_effective_img_len, + temb, + ) + + # 3. 
Joint Transformer blocks + max_seq_len = max(seq_lengths) + + attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): + attention_mask[i, :seq_len] = True + joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len] + joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len] + + hidden_states = joint_hidden_states + + for layer_idx, layer in enumerate(self.layers): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, attention_mask, rotary_emb, temb + ) + else: + hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) + + # 4. Output norm & projection + hidden_states = self.norm_out(hidden_states, temb) + + p = self.config.patch_size + output = [] + for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)): + height, width = img_size + output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p)) + if is_hidden_states_tensor: + output = torch.stack(output, dim=0) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return output + return Transformer2DModelOutput(sample=output) diff --git a/modules/omnigen2/pipeline_omnigen2.py b/modules/omnigen2/pipeline_omnigen2.py new file mode 100644 index 000000000..a7d6c2ca3 --- /dev/null +++ b/modules/omnigen2/pipeline_omnigen2.py @@ -0,0 +1,718 @@ +""" +OmniGen2 Diffusion Pipeline + +Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +import inspect +import numpy as np +import torch +import torch.nn.functional as F +import PIL.Image + +from transformers import Qwen2_5_VLForConditionalGeneration +from diffusers.utils import BaseOutput +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + is_torch_xla_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + +from .models.transformers import OmniGen2Transformer2DModel +from .models.transformers.repo import OmniGen2RotaryPosEmbed +from .image_processor import OmniGen2ImageProcessor + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +@dataclass +class FMPipelineOutput(BaseOutput): + """ + Output class for OmniGen2 pipeline. + + Args: + images (Union[List[PIL.Image.Image], np.ndarray]): + List of denoised PIL images of length `batch_size` or numpy array of shape + `(batch_size, height, width, num_channels)`. Contains the generated images. 
+ """ + images: Union[List[PIL.Image.Image], np.ndarray] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OmniGen2Pipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using OmniGen2. + + This pipeline implements a text-to-image generation model that uses: + - Qwen2.5-VL for text encoding + - A custom transformer architecture for image generation + - VAE for image encoding/decoding + - FlowMatchEulerDiscreteScheduler for noise scheduling + + Args: + transformer (OmniGen2Transformer2DModel): The transformer model for image generation. + vae (AutoencoderKL): The VAE model for image encoding/decoding. + scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for noise scheduling. + mllm (Qwen2_5_VLForConditionalGeneration): The multimodal LLM used as the text encoder. + processor: The processor whose tokenizer is used for text processing. + """ + + model_cpu_offload_seq = "mllm->transformer->vae" + + def __init__( + self, + transformer: OmniGen2Transformer2DModel, + vae: AutoencoderKL, + scheduler: FlowMatchEulerDiscreteScheduler, + mllm: Qwen2_5_VLForConditionalGeneration, + processor, + ) -> None: + """ + Initialize the OmniGen2 pipeline. + + Args: + transformer: The transformer model for image generation. + vae: The VAE model for image encoding/decoding. + scheduler: The scheduler for noise scheduling. + mllm: The multimodal LLM used as the text encoder. + processor: The processor whose tokenizer is used for text processing.
+ """ + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + scheduler=scheduler, + mllm=mllm, + processor=processor + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = OmniGen2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2, do_resize=True) + self.default_sample_size = 128 + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator], + latents: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Prepare the initial latents for the diffusion process. + + Args: + batch_size: The number of images to generate. + num_channels_latents: The number of channels in the latent space. + height: The height of the generated image. + width: The width of the generated image. + dtype: The data type of the latents. + device: The device to place the latents on. + generator: The random number generator to use. + latents: Optional pre-computed latents to use instead of random initialization. + + Returns: + torch.FloatTensor: The prepared latents tensor. + """ + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_vae(self, img: torch.FloatTensor) -> torch.FloatTensor: + """ + Encode an image into the VAE latent space. + + Args: + img: The input image tensor to encode. + + Returns: + torch.FloatTensor: The encoded latent representation. 
+ """ + z0 = self.vae.encode(img.to(dtype=self.vae.dtype)).latent_dist.sample() + if self.vae.config.shift_factor is not None: + z0 = z0 - self.vae.config.shift_factor + if self.vae.config.scaling_factor is not None: + z0 = z0 * self.vae.config.scaling_factor + z0 = z0.to(dtype=self.vae.dtype) + return z0 + + def prepare_image( + self, + images: Union[List[PIL.Image.Image], PIL.Image.Image], + batch_size: int, + num_images_per_prompt: int, + max_pixels: int, + max_side_length: int, + device: torch.device, + dtype: torch.dtype, + ) -> List[Optional[torch.FloatTensor]]: + """ + Prepare input images for processing by encoding them into the VAE latent space. + + Args: + images: Single image or list of images to process. + batch_size: The number of images to generate per prompt. + num_images_per_prompt: The number of images to generate for each prompt. + device: The device to place the encoded latents on. + dtype: The data type of the encoded latents. + + Returns: + List[Optional[torch.FloatTensor]]: List of encoded latent representations for each image. + """ + if batch_size == 1: + images = [images] + latents = [] + for i, img in enumerate(images): + if img is not None and len(img) > 0: + ref_latents = [] + for j, img_j in enumerate(img): + img_j = self.image_processor.preprocess(img_j, max_pixels=max_pixels, max_side_length=max_side_length) + ref_latents.append(self.encode_vae(img_j.to(device=device)).squeeze(0)) + else: + ref_latents = None + for _ in range(num_images_per_prompt): + latents.append(ref_latents) + + return latents + + def _get_qwen2_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get prompt embeddings from the Qwen2 text encoder. + + Args: + prompt: The prompt or list of prompts to encode. + device: The device to place the embeddings on. If None, uses the pipeline's device. 
+ max_sequence_length: Maximum sequence length for tokenization. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - The prompt embeddings tensor + - The attention mask tensor + + Raises: + Warning: If the input text is truncated due to sequence length limitations. + """ + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + # text_inputs = self.processor.tokenizer( + # prompt, + # padding="max_length", + # max_length=max_sequence_length, + # truncation=True, + # return_tensors="pt", + # ) + text_inputs = self.processor.tokenizer( + prompt, + padding="longest", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.processor.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.processor.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.mllm( + text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-1] + + if self.mllm is not None: + dtype = self.mllm.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + def _apply_chat_template(self, prompt: str): + prompt = [ + { + "role": "system", + "content": "You are a helpful assistant that generates high-quality images based on user instructions.", + }, + {"role": "user", "content": prompt}, + ] + prompt = 
self.processor.tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=False) + return prompt + + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. 
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [self._apply_chat_template(_prompt) for _prompt in prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length + ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = [self._apply_chat_template(_negative_prompt) for _negative_prompt in negative_prompt] + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + negative_prompt_embeds, negative_prompt_attention_mask = self._get_qwen2_prompt_embeds( + prompt=negative_prompt, + device=device, + max_sequence_length=max_sequence_length, + ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def text_guidance_scale(self): + return self._text_guidance_scale + + @property + def image_guidance_scale(self): + return self._image_guidance_scale + + @property + def cfg_range(self): + return self._cfg_range + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.LongTensor] = None, + negative_prompt_attention_mask: Optional[torch.LongTensor] = None, + max_sequence_length: Optional[int] = None, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + input_images: Optional[List[PIL.Image.Image]] = None, + num_images_per_prompt: int = 1, + height: Optional[int] = None, + width: Optional[int] = None, + max_pixels: int = 2048 * 2048, + max_input_image_side_length: int = 2048, + align_res: bool = True, + num_inference_steps: int = 28, + text_guidance_scale: float = 4.0, + 
image_guidance_scale: float = 1.0, + cfg_range: Tuple[float, float] = (0.0, 1.0), + attention_kwargs: Optional[Dict[str, Any]] = None, + timesteps: List[int] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + verbose: bool = False, + step_func=None, + ): + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + self._text_guidance_scale = text_guidance_scale + self._image_guidance_scale = image_guidance_scale + self._cfg_range = cfg_range + self._attention_kwargs = attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.text_guidance_scale > 1.0, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + ) + + dtype = self.vae.dtype + # 3. 
Prepare control image + ref_latents = self.prepare_image( + images=input_images, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + max_pixels=max_pixels, + max_side_length=max_input_image_side_length, + device=device, + dtype=dtype, + ) + + if input_images is None: + input_images = [] + + if len(input_images) == 1 and align_res: + width, height = ref_latents[0][0].shape[-1] * self.vae_scale_factor, ref_latents[0][0].shape[-2] * self.vae_scale_factor + ori_width, ori_height = width, height + else: + ori_width, ori_height = width, height + + cur_pixels = height * width + ratio = (max_pixels / cur_pixels) ** 0.5 + ratio = min(ratio, 1.0) + + height, width = int(height * ratio) // 16 * 16, int(width * ratio) // 16 * 16 + + if len(input_images) == 0: + self._image_guidance_scale = 1 + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis( + self.transformer.config.axes_dim_rope, + self.transformer.config.axes_lens, + theta=10000, + ) + + image = self.processing( + latents=latents, + ref_latents=ref_latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_inference_steps=num_inference_steps, + timesteps=timesteps, + device=device, + dtype=dtype, + verbose=verbose, + step_func=step_func, + ) + + image = F.interpolate(image, size=(ori_height, ori_width), mode='bilinear') + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return image + else: + return FMPipelineOutput(images=image) + + def processing( + self, + latents, + 
ref_latents, + prompt_embeds, + freqs_cis, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + num_inference_steps, + timesteps, + device, + dtype, + verbose, + step_func=None + ): + batch_size = latents.shape[0] + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + num_tokens=latents.shape[-2] * latents.shape[-1] + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + model_pred = self.predict( + t=t, + latents=latents, + prompt_embeds=prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + text_guidance_scale = self.text_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + image_guidance_scale = self.image_guidance_scale if self.cfg_range[0] <= i / len(timesteps) <= self.cfg_range[1] else 1.0 + + if text_guidance_scale > 1.0 and image_guidance_scale > 1.0: + model_pred_ref = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=ref_latents, + ) + + if image_guidance_scale != 1: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + else: + model_pred_uncond = torch.zeros_like(model_pred) + + model_pred = model_pred_uncond + image_guidance_scale * (model_pred_ref - model_pred_uncond) + \ + text_guidance_scale * (model_pred - model_pred_ref) + elif text_guidance_scale > 1.0: + model_pred_uncond = self.predict( + t=t, + latents=latents, + prompt_embeds=negative_prompt_embeds, + 
freqs_cis=freqs_cis, + prompt_attention_mask=negative_prompt_attention_mask, + ref_image_hidden_states=None, + ) + model_pred = model_pred_uncond + text_guidance_scale * (model_pred - model_pred_uncond) + + latents = self.scheduler.step(model_pred, t, latents, return_dict=False)[0] + + latents = latents.to(dtype=dtype) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if step_func is not None: + step_func(i, self._num_timesteps) + + latents = latents.to(dtype=dtype) + if self.vae.config.scaling_factor is not None: + latents = latents / self.vae.config.scaling_factor + if self.vae.config.shift_factor is not None: + latents = latents + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + + return image + + def predict( + self, + t, + latents, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + ref_image_hidden_states, + ): + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + batch_size, num_channels_latents, height, width = latents.shape + + optional_kwargs = {} + if 'ref_image_hidden_states' in set(inspect.signature(self.transformer.forward).parameters.keys()): + optional_kwargs['ref_image_hidden_states'] = ref_image_hidden_states + + model_pred = self.transformer( + latents, + timestep, + prompt_embeds, + freqs_cis, + prompt_attention_mask, + **optional_kwargs + ) + return model_pred diff --git a/modules/omnigen2/pipeline_utils.py b/modules/omnigen2/pipeline_utils.py new file mode 100644 index 000000000..4efebc260 --- /dev/null +++ b/modules/omnigen2/pipeline_utils.py @@ -0,0 +1,62 @@ +import torch + + +def get_pipeline_embeds(pipeline, prompt, negative_prompt, device): + """ Get pipeline embeds for prompts bigger than the maxlength of the pipe + :param pipeline: + :param prompt: + :param negative_prompt: + :param device: + :return: + """ + max_length = 
pipeline.tokenizer.model_max_length + + # simple way to determine length of tokens + # count_prompt = len(prompt.split(" ")) + # count_negative_prompt = len(negative_prompt.split(" ")) + + # create the tensor based on which prompt is longer + # if count_prompt >= count_negative_prompt: + input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding='longest').input_ids.to(device) + # input_ids = pipeline.tokenizer(prompt, padding="max_length", + # max_length=pipeline.tokenizer.model_max_length, + # truncation=True, + # return_tensors="pt",).input_ids.to(device) + shape_max_length = input_ids.shape[-1] + + if negative_prompt is not None: + negative_ids = pipeline.tokenizer(negative_prompt, truncation=True, padding="max_length", + max_length=shape_max_length, return_tensors="pt").input_ids.to(device) + + # else: + # negative_ids = pipeline.tokenizer(negative_prompt, return_tensors="pt", truncation=False).input_ids.to(device) + # shape_max_length = negative_ids.shape[-1] + # input_ids = pipeline.tokenizer(prompt, return_tensors="pt", truncation=False, padding="max_length", + # max_length=shape_max_length).input_ids.to(device) + + concat_embeds = [] + neg_embeds = [] + for i in range(0, shape_max_length, max_length): + if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask: + attention_mask = input_ids[:, i: i + max_length].attention_mask.to(device) + else: + attention_mask = None + concat_embeds.append(pipeline.text_encoder(input_ids[:, i: i + max_length], + attention_mask=attention_mask)[0]) + + if negative_prompt is not None: + if hasattr(pipeline.text_encoder.config, "use_attention_mask") and pipeline.text_encoder.config.use_attention_mask: + attention_mask = negative_ids[:, i: i + max_length].attention_mask.to(device) + else: + attention_mask = None + neg_embeds.append(pipeline.text_encoder(negative_ids[:, i: i + max_length], + attention_mask=attention_mask)[0]) + + concat_embeds = 
torch.cat(concat_embeds, dim=1) + + if negative_prompt is not None: + neg_embeds = torch.cat(neg_embeds, dim=1) + else: + neg_embeds = None + + return concat_embeds, neg_embeds diff --git a/modules/omnigen2/triton_layer_norm.py b/modules/omnigen2/triton_layer_norm.py new file mode 100644 index 000000000..51a70b990 --- /dev/null +++ b/modules/omnigen2/triton_layer_norm.py @@ -0,0 +1,1257 @@ +# Copyright (c) 2024, Tri Dao. +# Implement dropout + residual + layer_norm / rms_norm. + +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. + +import math + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +from typing import Callable + + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated) + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs=[] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block=1024 + # Default to warp size 32 if not defined by device + warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) 
+ # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + warp_count=1 + while warp_count*warp_size <= max_threads_per_block: + configs.append(triton.Config({}, num_warps=warp_count)) + warp_count*=2 + return configs + +def layer_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( + dtype + ) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = F.layer_norm( + x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps + ).to(dtype) + return (out, out1) if not prenorm else (out, out1, x) + + +def rms_norm_ref( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + 
dropout_p=0.0, + rowscale=None, + prenorm=False, + zero_centered_weight=False, + dropout_mask=None, + dropout_mask1=None, + upcast=False, +): + dtype = x.dtype + if upcast: + x = x.float() + weight = weight.float() + bias = bias.float() if bias is not None else None + residual = residual.float() if residual is not None else residual + x1 = x1.float() if x1 is not None else None + weight1 = weight1.float() if weight1 is not None else None + bias1 = bias1.float() if bias1 is not None else None + if zero_centered_weight: + weight = weight + 1.0 + if weight1 is not None: + weight1 = weight1 + 1.0 + if x1 is not None: + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + if rowscale is not None: + x = x * rowscale[..., None] + if dropout_p > 0.0: + if dropout_mask is not None: + x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p) + else: + x = F.dropout(x, p=dropout_p) + if x1 is not None: + if dropout_mask1 is not None: + x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p) + else: + x1 = F.dropout(x1, p=dropout_p) + if x1 is not None: + x = x + x1 + if residual is not None: + x = (x + residual).to(x.dtype) + rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) + out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype) + if weight1 is None: + return out if not prenorm else (out, x) + else: + out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to( + dtype + ) + return (out, out1) if not prenorm else (out, out1, x) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) 
+@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if 
zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + residual=None, + x1=None, + weight1=None, + bias1=None, + dropout_p=0.0, + rowscale=None, + out_dtype=None, + residual_dtype=None, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + if residual is not None: + residual_dtype = residual.dtype + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + else: + assert out.shape == x.shape + assert out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + if ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + if 
residual_out is None: + residual_out = torch.empty( + M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + else: + residual_out = None + mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask = None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[(M,)]( + x, + out, + weight, + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if dropout_mask is not None and x1 is not None: + dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) + else: + dropout_mask1 = None + return ( + out, + y1, + mean, + rstd, + residual_out if residual_out is not 
None else x, + seeds, + dropout_mask, + dropout_mask1, + ) + + +@triton.autotune( + configs=triton_autotune_configs(), + key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], +) +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) +# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) +@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) +@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) +@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) +@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DRESIDUAL, + W1, + DY1, + DX1, + DW1, + DB1, + DRESIDUAL_IN, + ROWSCALE, + SEEDS, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dres_row, + stride_dy1_row, + stride_dx1_row, + stride_dres_in_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, + zero_centered_weight, + rows_per_program, + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_DRESIDUAL: tl.constexpr, + STORE_DRESIDUAL: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_DY1: tl.constexpr, + HAS_DX1: tl.constexpr, + HAS_B1: tl.constexpr, + RECOMPUTE_OUTPUT: 
tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + # Do not early exit if row_start >= M, because we need to write DW and DB + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += row_start * stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += row_start * stride_dres_in_row + DY += row_start * stride_dy_row + DX += row_start * stride_dx_row + if HAS_DY1: + DY1 += row_start * stride_dy1_row + if HAS_DX1: + DX1 += row_start * stride_dx1_row + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if RECOMPUTE_OUTPUT and HAS_BIAS: + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + if HAS_DY1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_DY1: + dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_B1: + db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if HAS_DY1: + dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_DY1: + wdy += w1 * dy1 + dw1 += dy1 * xhat + if HAS_B1: + db1 += dy1 + if not IS_RMS_NORM: + c1 = tl.sum(xhat * wdy, 
axis=0) / N + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + c1 = tl.sum(xhat * wdy, axis=0) / N + dx = (wdy - xhat * c1) * rstd + if HAS_DRESIDUAL: + dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) + dx += dres + # Write dx + if STORE_DRESIDUAL: + tl.store(DRESIDUAL_IN + cols, dx, mask=mask) + if HAS_DX1: + if HAS_DROPOUT: + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + else: + dx1 = dx + tl.store(DX1 + cols, dx1, mask=mask) + if HAS_DROPOUT: + keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + dx *= rowscale + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_DRESIDUAL: + DRESIDUAL += stride_dres_row + if STORE_DRESIDUAL: + DRESIDUAL_IN += stride_dres_in_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + if HAS_DY1: + DY1 += stride_dy1_row + if HAS_DX1: + DX1 += stride_dx1_row + tl.store(DW + row_block_id * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * N + cols, db, mask=mask) + if HAS_DY1: + tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) + if HAS_B1: + tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + dresidual=None, + dy1=None, + weight1=None, + bias1=None, + seeds=None, + dropout_p=0.0, + rowscale=None, + has_residual=False, + has_x1=False, + zero_centered_weight=False, + is_rms_norm=False, + x_dtype=None, + recompute_output=False, +): + M, N = x.shape + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if dresidual is not None: + assert dresidual.stride(-1) == 1 + assert dresidual.shape == (M, N) + assert weight.shape == (N,) + 
assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if dy1 is not None: + assert weight1 is not None + assert dy1.shape == dy.shape + assert dy1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if seeds is not None: + assert seeds.is_contiguous() + assert seeds.shape == (M if not has_x1 else M * 2,) + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + # allocate output + dx = ( + torch.empty_like(x) + if x_dtype is None + else torch.empty(M, N, dtype=x_dtype, device=x.device) + ) + dresidual_in = ( + torch.empty_like(x) + if has_residual + and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) + else None + ) + dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None + y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None + if recompute_output: + assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the + # latency of the gmem reads/writes, but will increase the time of summing up dw / db. 
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 + _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) + _db = ( + torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) + if bias is not None + else None + ) + _dw1 = torch.empty_like(_dw) if weight1 is not None else None + _db1 = torch.empty_like(_db) if bias1 is not None else None + rows_per_program = math.ceil(M / sm_count) + grid = (sm_count,) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + y, + dy, + dx, + _dw, + _db, + dresidual, + weight1, + dy1, + dx1, + _dw1, + _db1, + dresidual_in, + rowscale, + seeds, + mean, + rstd, + x.stride(0), + 0 if not recompute_output else y.stride(0), + dy.stride(0), + dx.stride(0), + dresidual.stride(0) if dresidual is not None else 0, + dy1.stride(0) if dy1 is not None else 0, + dx1.stride(0) if dx1 is not None else 0, + dresidual_in.stride(0) if dresidual_in is not None else 0, + M, + N, + eps, + dropout_p, + zero_centered_weight, + rows_per_program, + is_rms_norm, + BLOCK_N, + dresidual is not None, + dresidual_in is not None, + bias is not None, + dropout_p > 0.0, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None + db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None + # Don't need to compute dresidual_in separately in this case + if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: + dresidual_in = dx + if has_x1 and dropout_p == 0.0: + dx1 = dx + return ( + (dx, dw, db, dresidual_in, dx1, dw1, db1) + if not recompute_output + else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + 
prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None + ): + x_shape_og = x.shape + # Check for zero sequence length + if x.numel() == 0: + ctx.zero_seq_length = True + # Only save minimal required tensors for backward + # ctx.save_for_backward(weight, bias, weight1, bias1) + ctx.x_shape_og = x_shape_og + ctx.weight_shape = weight.shape + ctx.weight_dtype = weight.dtype + ctx.weight_device = weight.device + + ctx.has_bias = bias is not None + ctx.bias_shape = bias.shape if bias is not None else None + ctx.bias_dtype = bias.dtype if bias is not None else None + ctx.bias_device = bias.device if bias is not None else None + + ctx.has_weight1 = weight1 is not None + ctx.weight1_shape = weight1.shape if weight1 is not None else None + ctx.weight1_dtype = weight1.dtype if weight1 is not None else None + ctx.weight1_device = weight1.device if weight1 is not None else None + + ctx.has_bias1 = bias1 is not None + ctx.bias1_shape = bias1.shape if bias1 is not None else None + ctx.bias1_dtype = bias1.dtype if bias1 is not None else None + ctx.bias1_device = bias1.device if bias1 is not None else None + + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.dropout_p = dropout_p + + # Handle output tensors with correct dtype + y = x # Preserve input tensor properties + y1 = torch.empty_like(x) if x1 is not None else None + + # Only create residual_out if prenorm is True + residual_out = torch.empty(x.shape, + dtype=torch.float32 if residual_in_fp32 else x.dtype, + device=x.device) if prenorm else None + + # Handle dropout masks + dropout_mask = None + dropout_mask1 = None + if return_dropout_mask: + dropout_mask = torch.empty_like(x, dtype=torch.uint8) + if x1 is not None: + dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) + + # Return based on configuration + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, 
residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ((y, dropout_mask, dropout_mask1) if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1)) + else: + return ((y, y1, dropout_mask, dropout_mask1) if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1)) + + ctx.zero_seq_length = False + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = x1.reshape(-1, x1.shape[-1]) + if x1.stride(-1) != 1: + x1 = x1.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + if weight1 is not None: + weight1 = weight1.contiguous() + if bias1 is not None: + bias1 = bias1.contiguous() + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out + ) + ctx.save_for_backward( + residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd + ) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + 
ctx.dropout_p = dropout_p + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.has_x1 = x1 is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.zero_centered_weight = zero_centered_weight + y = y.reshape(x_shape_og) + y1 = y1.reshape(x_shape_og) if y1 is not None else None + residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None + dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None + dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None + if not return_dropout_mask: + if weight1 is None: + return y if not prenorm else (y, residual_out) + else: + return (y, y1) if not prenorm else (y, y1, residual_out) + else: + if weight1 is None: + return ( + (y, dropout_mask, dropout_mask1) + if not prenorm + else (y, residual_out, dropout_mask, dropout_mask1) + ) + else: + return ( + (y, y1, dropout_mask, dropout_mask1) + if not prenorm + else (y, y1, residual_out, dropout_mask, dropout_mask1) + ) + + @staticmethod + def backward(ctx, dy, *args): + if ctx.zero_seq_length: + return ( + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), + torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device), + torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, + torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None, + torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None, + torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors + dy = dy.reshape(-1, 
dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if weight1 is not None: + dy1, args = args[0], args[1:] + dy1 = dy1.reshape(-1, dy1.shape[-1]) + if dy1.stride(-1) != 1: + dy1 = dy1.contiguous() + assert dy1.shape == x.shape + else: + dy1 = None + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + + dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( + dy, + x, + weight, + bias, + ctx.eps, + mean, + rstd, + dresidual, + dy1, + weight1, + bias1, + seeds, + ctx.dropout_p, + rowscale, + ctx.has_residual, + ctx.has_x1, + ctx.zero_centered_weight, + ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, + dw1, + db1, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out, + residual_out + ) + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out=None, + residual_out=None +): + return LayerNormFn.apply( + x, + weight, + bias, + 
residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out, + residual_out + ) + + +class RMSNorm(torch.nn.Module): + + def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, + device=None, dtype=None): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if dropout_p > 0.0: + self.drop = torch.nn.Dropout(dropout_p) + else: + self.drop = None + self.zero_centered_weight = zero_centered_weight + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + if not self.zero_centered_weight: + torch.nn.init.ones_(self.weight) + else: + torch.nn.init.zeros_(self.weight) + + def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): + return rms_norm_fn( + x, + self.weight, + self.bias, + residual=residual, + eps=self.eps, + dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, + prenorm=prenorm, + residual_in_fp32=residual_in_fp32, + zero_centered_weight=self.zero_centered_weight, + ) + + +class LayerNormLinearFn(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward( + ctx, + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if residual is not None: + assert residual.shape == x_shape_og + residual = residual.reshape(-1, residual.shape[-1]) + if residual.stride(-1) != 1: + residual = residual.contiguous() + norm_weight = norm_weight.contiguous() + if norm_bias is not None: + norm_bias = norm_bias.contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else 
(torch.float32 if residual_in_fp32 else None) + ) + y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd( + x, + norm_weight, + norm_bias, + eps, + residual, + out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"), + residual_dtype=residual_dtype, + is_rms_norm=is_rms_norm, + ) + y = y.reshape(x_shape_og) + dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype + linear_weight = linear_weight.to(dtype) + linear_bias = linear_bias.to(dtype) if linear_bias is not None else None + out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) + # We don't store y, will be recomputed in the backward pass to save memory + ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.is_rms_norm = is_rms_norm + ctx.has_residual = residual is not None + ctx.prenorm = prenorm + ctx.x_dtype = x.dtype + ctx.linear_bias_is_none = linear_bias is None + return out if not prenorm else (out, residual_out.reshape(x_shape_og)) + + @staticmethod + @custom_bwd + def backward(ctx, dout, *args): + x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors + dout = dout.reshape(-1, dout.shape[-1]) + dy = F.linear(dout, linear_weight.t()) + dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + if ctx.prenorm: + dresidual = args[0] + dresidual = dresidual.reshape(-1, dresidual.shape[-1]) + if dresidual.stride(-1) != 1: + dresidual = dresidual.contiguous() + assert dresidual.shape == x.shape + else: + dresidual = None + dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd( + dy, + x, + norm_weight, + norm_bias, + ctx.eps, + mean, + rstd, + dresidual=dresidual, + has_residual=ctx.has_residual, + is_rms_norm=ctx.is_rms_norm, + x_dtype=ctx.x_dtype, + recompute_output=True, + ) + dlinear_weight = torch.einsum("bo,bi->oi", 
dout, y) + return ( + dx.reshape(ctx.x_shape_og), + dnorm_weight, + dnorm_bias, + dlinear_weight, + dlinear_bias, + dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, + None, + None, + None, + None, + ) + + +def layer_norm_linear_fn( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual=None, + eps=1e-6, + prenorm=False, + residual_in_fp32=False, + is_rms_norm=False, +): + return LayerNormLinearFn.apply( + x, + norm_weight, + norm_bias, + linear_weight, + linear_bias, + residual, + eps, + prenorm, + residual_in_fp32, + is_rms_norm, + ) diff --git a/modules/processing_args.py b/modules/processing_args.py index 154f5e272..7721fe371 100644 --- a/modules/processing_args.py +++ b/modules/processing_args.py @@ -61,7 +61,7 @@ def task_specific_kwargs(p, model): p.width, p.height = p.width // vae_scale_factor * vae_scale_factor, p.height // vae_scale_factor * vae_scale_factor task_args['max_area'] = max_area task_args['width'], task_args['height'] = p.width, p.height - if model.__class__.__name__ == 'OmniGenPipeline': + elif model.__class__.__name__ == 'OmniGenPipeline' or model.__class__.__name__ == 'OmniGen2Pipeline': p.width, p.height = 16 * math.ceil(p.init_images[0].width / 16), 16 * math.ceil(p.init_images[0].height / 16) task_args = { 'width': p.width, @@ -285,7 +285,7 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t kwargs['output_type'] = 'np' # only set latent if model has vae # model specific - if 'Kandinsky' in model.__class__.__name__: + if 'Kandinsky' in model.__class__.__name__ or 'Cosmos2' in model.__class__.__name__ or 'OmniGen2' in model.__class__.__name__: kwargs['output_type'] = 'np' # only set latent if model has vae if 'StableCascade' in model.__class__.__name__: kwargs.pop("guidance_scale") # remove @@ -305,8 +305,6 @@ def set_pipeline_args(p, model, prompts:list, negative_prompts:list, prompts_2:t args['control_strength'] = p.denoising_strength args['width'] = p.width 
args['height'] = p.height - if 'Cosmos2' in model.__class__.__name__: - kwargs['output_type'] = 'np' # cosmos uses wan-vae which is weird # set callbacks if 'prior_callback_steps' in possible: # Wuerstchen / Cascade args['prior_callback_steps'] = 1 diff --git a/modules/sd_detect.py b/modules/sd_detect.py index cb99f8d17..be57edb4b 100644 --- a/modules/sd_detect.py +++ b/modules/sd_detect.py @@ -88,6 +88,9 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): if 'omnigen' in f.lower(): guess = 'OmniGen' pipeline = 'custom' + if 'omnigen2' in f.lower(): + guess = 'OmniGen2' + pipeline = 'custom' if 'sd3' in f.lower(): guess = 'Stable Diffusion 3' if 'hidream' in f.lower(): diff --git a/modules/sd_models.py b/modules/sd_models.py index 97e36b2e6..3a74eac56 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -37,6 +37,7 @@ pipe_switch_task_exclude = [ 'InstantIRPipeline', 'LTXConditionPipeline', 'OmniGenPipeline', + 'OmniGen2Pipeline', 'PhotoMakerStableDiffusionXLPipeline', 'PixelSmithXLPipeline', 'StableDiffusion3ControlNetPipeline', @@ -364,6 +365,10 @@ def load_diffuser_force(model_type, checkpoint_info, diffusers_load_config, op=' from modules.model_meissonic import load_meissonic sd_model = load_meissonic(checkpoint_info, diffusers_load_config) allow_post_quant = True + elif model_type in ['OmniGen2']: # forced pipeline + from modules.model_omnigen2 import load_omnigen2 + sd_model = load_omnigen2(checkpoint_info, diffusers_load_config) + allow_post_quant = False elif model_type in ['OmniGen']: # forced pipeline from modules.model_omnigen import load_omnigen sd_model = load_omnigen(checkpoint_info, diffusers_load_config) diff --git a/modules/sd_offload.py b/modules/sd_offload.py index 188275f71..037f46ca2 100644 --- a/modules/sd_offload.py +++ b/modules/sd_offload.py @@ -12,7 +12,7 @@ from modules.timer import process as process_timer debug = os.environ.get('SD_MOVE_DEBUG', None) is not None debug_move = log.trace if debug else 
lambda *args, **kwargs: None -offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'cogview4', 'cosmos', 'chroma'] +offload_warn = ['sc', 'sd3', 'f1', 'h1', 'hunyuandit', 'auraflow', 'omnigen', 'omnigen2', 'cogview4', 'cosmos', 'chroma'] offload_post = ['h1'] offload_hook_instance = None balanced_offload_exclude = ['CogView4Pipeline'] diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 2d6c2858c..15e2cf1ad 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -9,7 +9,7 @@ from modules import shared, devices, processing, images, sd_vae_approx, sd_vae_t SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) approximation_indexes = { "Simple": 0, "Approximate": 1, "TAESD": 2, "Full VAE": 3 } -flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma'] +flow_models = ['f1', 'sd3', 'lumina', 'auraflow', 'sana', 'lumina2', 'cogview4', 'h1', 'cosmos', 'chroma', 'omnigen2'] warned = False queue_lock = threading.Lock() diff --git a/modules/shared.py b/modules/shared.py index cc224dc00..d87d5b252 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -522,11 +522,11 @@ options_templates.update(options_section(("quantization", "Quantization Settings "sdnq_quantize_weights_mode": OptionInfo("int8", "Quantization type", gr.Dropdown, {"choices": sdnq_quant_modes, "visible": native}), "sdnq_quantize_weights_mode_te": OptionInfo("default", "Quantization type for Text Encoders", gr.Dropdown, {"choices": ['default'] + sdnq_quant_modes, "visible": native}), "sdnq_quantize_weights_group_size": OptionInfo(0, "Group size", gr.Slider, {"minimum": -1, "maximum": 4096, "step": 1, "visible": native}), - "sdnq_quantize_conv_layers": OptionInfo(False, "Quantize the convolutional layers", gr.Checkbox, {"visible": native}), + "sdnq_quantize_conv_layers": OptionInfo(False, "Quantize convolutional layers", gr.Checkbox, {"visible": 
native}), "sdnq_dequantize_compile": OptionInfo(devices.has_triton(), "Dequantize using torch.compile", gr.Checkbox, {"visible": native}), - "sdnq_use_quantized_matmul": OptionInfo(False, "Use Quantized MatMul", gr.Checkbox, {"visible": native}), - "sdnq_use_quantized_matmul_conv": OptionInfo(False, "Use Quantized MatMul with convolutional layers", gr.Checkbox, {"visible": native}), - "sdnq_quantize_with_gpu": OptionInfo(True, "Quantize with the GPU", gr.Checkbox, {"visible": native}), + "sdnq_use_quantized_matmul": OptionInfo(False, "Use quantized MatMul", gr.Checkbox, {"visible": native}), + "sdnq_use_quantized_matmul_conv": OptionInfo(False, "Use quantized MatMul with convolutional layers", gr.Checkbox, {"visible": native}), + "sdnq_quantize_with_gpu": OptionInfo(True, "Quantize using GPU", gr.Checkbox, {"visible": native}), "sdnq_dequantize_fp32": OptionInfo(False, "Dequantize using full precision", gr.Checkbox, {"visible": native}), "sdnq_quantize_shuffle_weights": OptionInfo(False, "Shuffle weights in post mode", gr.Checkbox, {"visible": native}), diff --git a/wiki b/wiki index c1ae36ab4..711b61ebf 160000 --- a/wiki +++ b/wiki @@ -1 +1 @@ -Subproject commit c1ae36ab4c2197487306c5c9fe4e328a038d1367 +Subproject commit 711b61ebfdb8cb09b06b4f8c627ae97519c6f74c