arxiv:2409.18523
2024-10-16
Use a small learnable network g_{θ}, dubbed the Cache Predictor (Lou et al., 2024), to predict the importance of the tokens, w^{t}_{l} = g_{θ}(l, t) \in \mathbb{R}^{n}, and prune tokens based on their relative importance.
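A minimal sketch of such a predictor, assuming an embedding-plus-MLP design with a sigmoid output and top-k pruning (`CachePredictor`, `prune_mask`, and all hyperparameters below are illustrative, not the paper's exact architecture):

```python
import torch
import torch.nn as nn

class CachePredictor(nn.Module):
    """Hypothetical g_theta(l, t) -> per-token importance w^t_l in (0, 1)^n.

    The paper only specifies that a small learnable network maps the block
    index l and timestep t to n token weights; this layout is an assumption.
    """

    def __init__(self, num_blocks: int, num_timesteps: int,
                 num_tokens: int, hidden: int = 128):
        super().__init__()
        self.block_emb = nn.Embedding(num_blocks, hidden)
        self.time_emb = nn.Embedding(num_timesteps, hidden)
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, num_tokens),
        )

    def forward(self, l: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        h = torch.cat([self.block_emb(l), self.time_emb(t)], dim=-1)
        return torch.sigmoid(self.mlp(h))  # soft weights w^t_l

def prune_mask(w: torch.Tensor, keep_ratio: float) -> torch.Tensor:
    """Keep the top-k tokens by relative importance; the rest reuse the cache."""
    k = max(1, int(keep_ratio * w.shape[-1]))
    idx = w.topk(k, dim=-1).indices
    return torch.zeros_like(w).scatter(-1, idx, 1.0)
```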
How to learn g_{θ}?
\hat{f}(z^{t}_{l}) = \textcolor{#800080}{w^{t}_{l} \odot f_{l}(z^{t}_{l})} + \textcolor{#008080}{(1 - w^{t}_{l}) \odot f_{l}(z^{t+1}_{l})}, where the diffusion transformer contains L network blocks in total, inference runs over T timesteps, and z^{t}_{l} \in \mathbb{R}^{n \times d} denotes the input to block f_{l} at timestep t. Since sampling proceeds from t = T down to t = 1, f_{l}(z^{t+1}_{l}) is the cached block output from the previous denoising step.
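As a sketch, one caching step could then look as follows (the cache-dict interface and the choice to store the mixed output as the next step's cache are assumptions; at inference w^{t}_{l} would be a hard 0/1 mask from `prune_mask` so that only the selected tokens are actually recomputed):

```python
import torch

def cached_block_forward(f_l, z_t, w, cache, key):
    """Mix freshly computed and cached outputs of block f_l.

    z_t:   block input z^t_l, shape (n, d)
    w:     importance weights w^t_l, shape (n,), broadcast over d
    cache: dict holding this block's output from timestep t+1
    """
    fresh = f_l(z_t)               # f_l(z^t_l): recomputed tokens
    stale = cache.get(key, fresh)  # f_l(z^{t+1}_l); falls back to fresh at t = T
    out = w.unsqueeze(-1) * fresh + (1.0 - w).unsqueeze(-1) * stale
    cache[key] = out.detach()      # serves as the cached output at timestep t-1
    return out
```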
🔍 Optimization Objective
\mathcal{L}_{\text{mse}} = \mathbb{E}_{t, z^{t}_{L+1}, \hat{z}^{t}_{L+1}} [\| z^{t}_{L+1} - \hat{z}^{t}_{L+1} \|^{2}_{2}], where z^{t}_{L+1} is the final-block output computed without caching and \hat{z}^{t}_{L+1} is the corresponding output under token caching.
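A hedged sketch of how g_{θ} could be trained against this objective (helper names are placeholders; the reference pass runs without caching, and presumably the transformer blocks are frozen so that only g_{θ} is updated):

```python
import torch

def cache_predictor_loss(blocks, predictor, z_in, t, cache):
    """L_mse = E[ || z^t_{L+1} - zhat^t_{L+1} ||_2^2 ] for one sample.

    blocks:    transformer blocks f_1 ... f_L (assumed frozen)
    predictor: the cache predictor g_theta
    z_in:      input z^t_1 to the first block at timestep t, shape (n, d)
    cache:     list of per-block outputs from timestep t+1
    """
    # Reference pass: full computation, no caching.
    with torch.no_grad():
        z_ref = z_in
        for f_l in blocks:
            z_ref = f_l(z_ref)

    # Cached pass: mix fresh and cached outputs with predicted soft weights.
    z_hat = z_in
    for l, f_l in enumerate(blocks):
        w = predictor(torch.tensor(l), torch.tensor(int(t)))  # w^t_l
        z_hat = w.unsqueeze(-1) * f_l(z_hat) + (1 - w).unsqueeze(-1) * cache[l]

    return ((z_ref - z_hat) ** 2).sum(-1).mean()  # squared L2 gap per token
```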
Token caching inherently adds errors into the sampling trajectory.
How to mitigate errors?
How to choose K?
TPRR (two-phase round-robin scheduling)