Discrete diffusion model

From Wikipedia, the free encyclopedia

In machine learning, discrete diffusion models are a class of diffusion models, which are themselves a class of latent variable generative models. Each discrete diffusion model consists of two major components: the forward jump diffusion process and the reverse jump diffusion process. The goal of diffusion modeling is, given a dataset and a forward process, to learn a model for the reverse process, such that the reverse process can generate new elements that are distributed similarly to the original dataset. A trained discrete diffusion model can be sampled in many ways, which trade off computational efficiency against sample quality. In general, higher-quality samples can be obtained, but at the price of higher computational cost.

In standard diffusion modeling, the diffusion process takes place over a continuous state space such as $\mathbb{R}^n$. In discrete diffusion modeling, it instead takes place over a discrete set $\mathcal{X}$. A discrete set is simply a set where one cannot speak of "infinitesimally close" points. Points can be more or less separated from each other, but the separation is always a finite number. In particular, the standard framework of continuous diffusion does not apply, since it uses Gaussian noise, which is continuous. Nevertheless, an analogous theory can be developed.

Discrete diffusion is usually used for language modeling.[1][2] In practice, the state space is not only discrete, but finite, so this is what we will assume from now on.

Continuous time Markov process

In the case of continuous state space, during the forward diffusion process, at each step $t \to t + dt$, we mix in an infinitesimal amount of Gaussian noise. This changes the probability density function: first by a convolution with the density of a Gaussian, then by a scaling.

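In the case of discrete state space, the Gaussian noise must be replaced by a noise that takes values over a finite set. For example, if the noise is the uniform distribution over the $N$ states of $\mathcal{X}$, then the probability distribution at time $t + dt$ satisfies

$$p_{t+dt}(x) = (1 - dt)\, p_t(x) + \frac{dt}{N} \sum_{y \in \mathcal{X}} p_t(y).$$

More succinctly, treating $p_t$ as a column vector in $\mathbb{R}^N$,

$$\frac{dp_t}{dt} = \left( \tfrac{1}{N} \mathbf{1}\mathbf{1}^\top - I \right) p_t.$$

In general, we do not need to convolve with a uniformly distributed noise, but can use an arbitrary noise process. That is, we use an arbitrary matrix $Q_t$ such that

$$\frac{dp_t}{dt} = Q_t\, p_t,$$

where $Q_t$ is called the rate matrix. Any matrix may be used as a rate matrix if it has non-negative off-diagonal entries and each column sums to 0:

$$Q_t(y, x) \geq 0 \ \text{ for } y \neq x, \qquad \sum_{y} Q_t(y, x) = 0.$$

A continuous time Markov chain (CTMC) is defined by a continuous function that maps any time $t$ to a rate matrix $Q_t$. Given the function $t \mapsto Q_t$, time evolution under the CTMC is done as follows: given state $x_t$ at time $t$ and an infinitesimal $dt$, the state at $t + dt$ is $x_{t+dt}$, such that

$$\Pr(x_{t+dt} = y \mid x_t = x) = \delta_{xy} + Q_t(y, x)\, dt.$$

This implies that the probability distribution evolves according to $\frac{dp_t}{dt} = Q_t\, p_t$, which is what we previously specified.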
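The rate-matrix conditions and the evolution equation can be checked numerically. The following is a minimal sketch, assuming the column convention $dp/dt = Q p$ and the uniform-noise rate matrix $\tfrac{1}{N}\mathbf{1}\mathbf{1}^\top - I$; the function names `is_rate_matrix`, `uniform_rate_matrix`, and `evolve` are illustrative and not taken from any library.

```python
import numpy as np

def is_rate_matrix(Q, tol=1e-9):
    """Check non-negative off-diagonal entries and zero column sums."""
    Q = np.asarray(Q, dtype=float)
    off_diag = Q - np.diag(np.diag(Q))
    return bool(np.all(off_diag >= -tol)
                and np.allclose(Q.sum(axis=0), 0.0, atol=tol))

def uniform_rate_matrix(n):
    """Uniform noising (assumed form): Q = ones / n - I."""
    return np.ones((n, n)) / n - np.eye(n)

def evolve(p0, Q, t, steps=10_000):
    """Integrate dp/dt = Q p with forward Euler steps of size t / steps."""
    p = np.asarray(p0, dtype=float).copy()
    dt = t / steps
    for _ in range(steps):
        p = p + dt * (Q @ p)
    return p

n = 4
Q = uniform_rate_matrix(n)
p0 = np.array([1.0, 0.0, 0.0, 0.0])   # all mass on state 0
p1 = evolve(p0, Q, t=1.0)             # distribution at time t = 1
```

Because every column of `Q` sums to zero, each Euler step preserves total probability, so `p1` remains a probability vector up to rounding.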

Backward process

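Similarly to the case of continuous diffusion, in discrete diffusion there exists a backward diffusion process with rate matrix $\bar{Q}_t$:

$$\frac{dp_{T-t}}{dt} = \bar{Q}_{T-t}\, p_{T-t}, \qquad \bar{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)}\, Q_t(x, y) \ \text{ for } y \neq x, \qquad \bar{Q}_t(x, x) = -\sum_{y \neq x} \bar{Q}_t(y, x),$$

where the ratio $p_t(y)/p_t(x)$ should be interpreted as the discrete score or concrete score, since, abusing notation a bit, the score function is $\nabla \log p_t = \nabla p_t / p_t$, and the finite-difference analog of $\nabla p_t(x) / p_t(x)$ is $\frac{p_t(y) - p_t(x)}{p_t(x)} = \frac{p_t(y)}{p_t(x)} - 1$.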

If we picture the distribution $p_t$ as a bunch of point-masses, one per state $x$, then the forward diffusion from time $t$ to $t + dt$ is performed by removing $Q_t(y, x)\, p_t(x)\, dt$ from the mass at $x$ and moving it to the mass at $y$, for each pair $x \neq y$. Thus, the process is reversed in detail by the CTMC defined by $\bar{Q}_t$, since $\bar{Q}_t(x, y)\, p_t(y) = Q_t(y, x)\, p_t(x)$.

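Given $\bar{Q}_t$, if we have a way to sample from $p_T$, then we can sample from $p_0$ by first sampling $x_T \sim p_T$, then sampling backward in time according to

$$\Pr(x_{t-dt} = y \mid x_t = x) = \delta_{xy} + \bar{Q}_t(y, x)\, dt.$$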
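The detailed-reversal property can be verified on a small example. This is a sketch under the assumption that the backward rate matrix has off-diagonal entries $\bar{Q}(y, x) = \frac{p(y)}{p(x)} Q(x, y)$ and a diagonal chosen so that columns sum to zero; the 3-state matrix `Q` and distribution `p` below are made-up illustrative values.

```python
import numpy as np

def backward_rate_matrix(Q, p):
    """Qbar[y, x] = p[y] / p[x] * Q[x, y] for y != x; columns sum to 0."""
    Q = np.asarray(Q, dtype=float)
    p = np.asarray(p, dtype=float)
    ratio = p[:, None] / p[None, :]          # ratio[y, x] = p[y] / p[x]
    Qbar = ratio * Q.T                       # (Q.T)[y, x] = Q[x, y]
    np.fill_diagonal(Qbar, 0.0)
    np.fill_diagonal(Qbar, -Qbar.sum(axis=0))
    return Qbar

# Illustrative forward rate matrix (columns sum to 0) and distribution.
Q = np.array([[-0.5,  0.2,  0.0],
              [ 0.3, -0.6,  0.1],
              [ 0.2,  0.4, -0.1]])
p = np.array([0.5, 0.3, 0.2])

Qbar = backward_rate_matrix(Q, p)
forward_flow = Q * p[None, :]      # [y, x]: mass moved from x to y per unit time
backward_flow = Qbar * p[None, :]  # [x, y]: mass moved from y to x per unit time
```

Off the diagonal, `backward_flow` is the transpose of `forward_flow`, and `Qbar @ p` equals `-(Q @ p)`, which is the statement that the backward chain undoes the forward evolution of the distribution.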

Overall plan of score-matching discrete diffusion modeling

Similar to score-matching continuous diffusion, score-matching discrete diffusion is a method for sampling from the initial distribution $p_0$.

If we have a function $s_\theta(x, t)_y$ that approximates the true score $p_t(y)/p_t(x)$, then it allows a corresponding backward rate matrix $\bar{Q}^\theta_t$ to be defined in the same way, with $s_\theta$ in place of the true score.

If we also have a base distribution $p_{\text{base}}$ that is easy to sample from and approximately equal to the true terminal distribution $p_T$, then we can perform the backward CTMC with $\bar{Q}^\theta_t$ and $x_T \sim p_{\text{base}}$.

When both approximations are good, the backward CTMC gives samples $x_0$ that are approximately distributed as $p_0$. This is the idea of score-matching discrete diffusion modeling.

If $p_0$ is sharp, in the sense that $p_0(x) = 0$ for some $x$, then the score $p_t(y)/p_t(x)$ diverges as $t \to 0$. To avoid this in practice, it is common to use early stopping, which is to stop the backward process at some time $\epsilon > 0$, and sample from $p_\epsilon$ instead of $p_0$.

Tractable forward processes

The theory of CTMC works for any continuous choice of rate matrices $Q_t$. However, most choices are computationally expensive and cannot be used in practice.

In the case of continuous diffusion, Gaussian noise is used for the simple reason that the sum of any number of Gaussians is still a Gaussian. This allows one to sample any $x_t$ by sampling a single $x_0$, followed by a single Gaussian noise $z$, and setting $x_t = a_t x_0 + b_t z$ for known scalars $a_t, b_t$, without needing $x_s$ for any intermediate time $0 < s < t$.

Similarly, the choice of rate matrices should also allow us to "skip forward" without needing any intermediate steps.

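The uniform noising process is defined by

$$Q = \tfrac{1}{N} \mathbf{1}\mathbf{1}^\top - I.$$

To see how to skip forward, note that the uniform noising process is equivalent to the following process: to evolve from time $t$ to time $t + dt$, either don't change anything with probability $1 - dt$, or sample a random state uniformly with probability $dt$. Since sampling uniform random states twice is the same as sampling once, the only question is whether we have ever sampled a random state. As time goes on, the probability of not having sampled a random state decays exponentially. Therefore,

$$p_{t|0}(y \mid x_0) = e^{-t}\, \delta_{y x_0} + (1 - e^{-t})\, \tfrac{1}{N}.$$

Time may be rescaled as desired. For example, a CTMC defined for $t \in [0, 1)$ is produced by the scaled-time uniform noising process

$$Q_t = \tfrac{1}{1 - t} \left( \tfrac{1}{N} \mathbf{1}\mathbf{1}^\top - I \right), \qquad p_{t|0}(y \mid x_0) = (1 - t)\, \delta_{y x_0} + \tfrac{t}{N}.$$

The construction works in general for an arbitrary time rescaling and an arbitrary noise distribution on $\mathcal{X}$. The idea is that if we have a fixed reference noise distribution $\pi$, then sampling from it twice is the same as sampling from it once. Therefore, the noising process $Q_t = \sigma(t) \left( \pi \mathbf{1}^\top - I \right)$ produces

$$p_{t|0}(\cdot \mid x_0) = e^{-\bar{\sigma}(t)}\, \delta_{x_0} + \left( 1 - e^{-\bar{\sigma}(t)} \right) \pi, \qquad \bar{\sigma}(t) = \int_0^t \sigma(s)\, ds.$$

Here, the time-rescaling function $\bar{\sigma}$ must be strictly monotonic.

More generally, if $Q_t = \sigma(t)\, Q$ for a fixed rate matrix $Q$, then $p_{t|0}(\cdot \mid x_0)$ is the $x_0$-th column of the matrix $e^{\bar{\sigma}(t) Q}$. If $x_0 \sim p_0$, then $p_t = e^{\bar{\sigma}(t) Q}\, p_0$.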
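The "skip forward" property can be illustrated by comparing a closed-form mixture against a column of a matrix exponential. The sketch below assumes a rate matrix of the form $\pi \mathbf{1}^\top - I$ for a reference noise distribution `pi`, and a cumulative noise level `sigma_bar`; both names, the numeric values, and the helper `expm_taylor` are illustrative assumptions.

```python
import numpy as np

def skip_forward_probs(x0, pi, sigma_bar):
    """Closed form: exp(-sigma_bar) * onehot(x0) + (1 - exp(-sigma_bar)) * pi."""
    pi = np.asarray(pi, dtype=float)
    keep = np.exp(-sigma_bar)        # probability that no resampling happened
    probs = (1.0 - keep) * pi
    probs[x0] += keep
    return probs

def expm_taylor(A, terms=40):
    """Matrix exponential by truncated Taylor series (fine for small matrices)."""
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms):
        term = term @ A / k
        result = result + term
    return result

pi = np.array([0.1, 0.2, 0.3, 0.4])           # illustrative noise distribution
Q = np.outer(pi, np.ones(4)) - np.eye(4)      # every column is pi, minus identity
sigma_bar = 0.7

closed_form = skip_forward_probs(2, pi, sigma_bar)
column = expm_taylor(sigma_bar * Q)[:, 2]     # x0-th column of exp(sigma_bar * Q)
```

The two vectors agree because this `Q` satisfies `Q @ Q == -Q`, so the exponential series collapses to a two-term mixture. In training, one would sample from `closed_form` directly with something like `rng.choice(len(pi), p=closed_form)` rather than simulate intermediate steps.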

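When $N$ is sufficiently large, only two processes are efficient enough for training in practice: the uniform process and the absorbing process, with rate matrices $Q_t = \sigma(t)\, Q^{\text{uniform}}$ and $Q_t = \sigma(t)\, Q^{\text{absorb}}$ respectively, where

$$Q^{\text{uniform}} = \tfrac{1}{N} \mathbf{1}\mathbf{1}^\top - I, \qquad Q^{\text{absorb}} = e_m \mathbf{1}^\top - I,$$

and $e_m$ is the indicator vector of the absorbing state $m$.

The uniform process is simply the uniform noising process with a rescalable time, converging to the uniform distribution on $\mathcal{X}$.

The absorbing process has a special absorbing state $m$, and converges to the point distribution $\delta_m$ on the absorbing state. During the time interval $[t, t + dt]$, if the state $x_t \neq m$, then it transitions into $m$ with probability $\sigma(t)\, dt$ and stays unchanged with probability $1 - \sigma(t)\, dt$, but if $x_t = m$, then $x_{t+dt} = m$ always. In diffusion language modeling, that special state is usually called [MASK], a name that originated in masked language modeling.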
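For the absorbing process, skipping forward amounts to masking each token independently. The following sketch assumes each token is replaced by the absorbing state with probability $1 - e^{-\bar{\sigma}}$, where `sigma_bar` is the cumulative noise level; the `MASK` id of `-1` and the token values are hypothetical.

```python
import numpy as np

MASK = -1  # hypothetical id for the absorbing [MASK] state

def mask_tokens(tokens, sigma_bar, rng):
    """Sample x_t from x_0 under the absorbing process in one step.

    Each token independently becomes MASK with probability
    1 - exp(-sigma_bar); otherwise it is left unchanged.
    """
    tokens = np.asarray(tokens)
    p_mask = 1.0 - np.exp(-sigma_bar)
    hit = rng.random(tokens.shape) < p_mask
    return np.where(hit, MASK, tokens)

rng = np.random.default_rng(0)
tokens = np.array([5, 17, 3, 42, 8])
noisy = mask_tokens(tokens, sigma_bar=0.5, rng=rng)
```

Once a token is masked it never changes again under the forward process, which is why a single Bernoulli draw per token is enough.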

Score matching

A discrete diffusion model is usually a score-matching network. That is, it is a neural network that takes as input a state $x$ and a time $t$, and approximately computes the discrete score function:

$$s_\theta(x, t)_y \approx \frac{p_t(y)}{p_t(x)},$$

where $\theta$ is the weights of the network. Once good weights are found, the score network can be used to define the backward diffusion process

$$\bar{Q}^\theta_t(y, x) = s_\theta(x, t)_y\, Q_t(x, y) \quad \text{for } y \neq x,$$

and produce samples that are approximately distributed as $p_0$. There are different algorithms for training a score-matching network.

The concrete score matching algorithm minimizes the L2 loss by stochastic gradient descent:

$$\mathcal{L}_{\text{CSM}} = \mathbb{E}_{t}\, \mathbb{E}_{x \sim p_t} \left[ \frac{1}{2} \sum_{y \neq x} \left( s_\theta(x, t)_y - \frac{p_t(y)}{p_t(x)} \right)^2 \right],$$

where the outer expectation means averaging over a randomly sampled time instance. For example, if we allow $t \in [0, T]$ in the definition of the forward CTMC process, then a common choice is to sample $t$ uniformly from $[0, T]$ during training.

The L2 loss has the problem that the score $s_\theta$ should never be negative, but the L2 loss does not prevent $s_\theta$ from becoming negative.

SEDD

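The Score Entropy Discrete Diffusion (SEDD) algorithm[3] minimizes a certain score entropy loss:

$$\mathcal{L}_{\text{SE}} = \mathbb{E}_{t}\, \mathbb{E}_{x \sim p_t} \left[ \sum_{y \neq x} w_{xy} \left( s_\theta(x, t)_y - \frac{p_t(y)}{p_t(x)} \log s_\theta(x, t)_y + K\!\left( \frac{p_t(y)}{p_t(x)} \right) \right) \right],$$

where $K(a) = a (\log a - 1)$ is just a function that does not depend on $\theta$, and $w_{xy}$ is an arbitrary array of positive numbers that can be adjusted as hyperparameters of the training algorithm. The section on variational inference shows that setting $w_{xy} = Q_t(x, y)$ relates this loss to a bound on the likelihood.

The expression within the parentheses is proportional to the Bregman divergence for $F = -\log$, the negative logarithmic function:

$$s - a \log s + K(a) = a\, D_F(s, a), \qquad D_F(s, a) = F(s) - F(a) - F'(a)(s - a) = \frac{s}{a} - 1 - \log \frac{s}{a}.$$

Since $\log$ is used in the definition of entropy, this explains why $\mathcal{L}_{\text{SE}}$ is called the "score entropy loss". Since the loss approaches infinity as $s_\theta$ approaches zero, the score entropy loss prevents negative values of $s_\theta$.

Since the Bregman divergence is zero only when the two terms are equal, the score entropy loss is minimized to a value of zero if and only if the score matching is perfect: $s_\theta(x, t)_y = p_t(y)/p_t(x)$.

There are two losses equivalent to the score entropy loss. The implicit score entropy (ISE) loss is

$$\mathcal{L}_{\text{ISE}} = \mathbb{E}_{t}\, \mathbb{E}_{x \sim p_t} \left[ \sum_{y \neq x} \left( w_{xy}\, s_\theta(x, t)_y - w_{yx} \log s_\theta(y, t)_x \right) \right],$$

which is equal to $\mathcal{L}_{\text{SE}} + C$, where $C$ is independent of $\theta$, and therefore optimization of $\mathcal{L}_{\text{ISE}}$ is equivalent to optimization of $\mathcal{L}_{\text{SE}}$. However, the ISE loss requires evaluating the score network at every neighbor $y$, that is, on the order of $N$ times per sample of $x$. This does not scale.

The denoising score entropy (DSE) loss is

$$\mathcal{L}_{\text{DSE}} = \mathbb{E}_{t}\, \mathbb{E}_{x_0 \sim p_0,\ x \sim p_{t|0}(\cdot \mid x_0)} \left[ \sum_{y \neq x} w_{xy} \left( s_\theta(x, t)_y - \frac{p_{t|0}(y \mid x_0)}{p_{t|0}(x \mid x_0)} \log s_\theta(x, t)_y \right) \right],$$

which is equal to $\mathcal{L}_{\text{SE}} + C'$ for another constant $C'$ independent of $\theta$. This can be derived by using the identity $p_t(x) = \sum_{x_0} p_{t|0}(x \mid x_0)\, p_0(x_0)$. Since it evaluates the score network only once per sample of $x$, this loss does scale.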
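The relationship between the score entropy term and the Bregman divergence can be checked numerically. This is a sketch assuming the per-pair term has the form $s - a \log s + K(a)$ with $K(a) = a(\log a - 1)$, where `s` stands for the model's score and `a` for the true ratio; the function names are illustrative.

```python
import numpy as np

def K(a):
    """Normalizing term K(a) = a * (log a - 1); independent of the model."""
    return a * (np.log(a) - 1.0)

def score_entropy_term(s, a):
    """Per-pair score entropy: s - a * log(s) + K(a), for s > 0 and a > 0."""
    return s - a * np.log(s) + K(a)

def bregman_neg_log(s, a):
    """Bregman divergence D_F(s, a) for F = -log."""
    # F(s) - F(a) - F'(a) * (s - a), with F'(a) = -1 / a
    return -np.log(s) + np.log(a) + (s - a) / a

s_grid = np.linspace(0.05, 5.0, 200)   # candidate model scores
a = 1.7                                # illustrative true ratio
values = score_entropy_term(s_grid, a)
```

On the grid, `values` is non-negative, equals `a * bregman_neg_log(s_grid, a)`, and grows without bound as `s` approaches zero, which is what keeps a trained score away from non-positive values.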

Variational inference

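The score entropy objectives can be cast into the variational inference form. Given a base distribution $p_{\text{base}}$ on $\mathcal{X}$ and a backward CTMC defined by $s_\theta$ and $Q_t$, let $p^\theta_0$ denote the resulting model distribution over $x_0$. For a fixed data point $x_0$, the diffusion weighted denoising score entropy (DWDSE) loss is defined as

$$\mathcal{L}_{\text{DWDSE}}(x_0) = \int_0^T \mathbb{E}_{x_t \sim p_{t|0}(\cdot \mid x_0)} \sum_{y \neq x_t} Q_t(x_t, y) \left( s_\theta(x_t, t)_y - \frac{p_{t|0}(y \mid x_0)}{p_{t|0}(x_t \mid x_0)} \log s_\theta(x_t, t)_y + K\!\left( \frac{p_{t|0}(y \mid x_0)}{p_{t|0}(x_t \mid x_0)} \right) \right) dt,$$

where $p_{t|0}$ is the forward CTMC kernel defined by $Q_t$, and $K(a) = a(\log a - 1)$ is the function used in the score entropy loss. In expectation over $x_0 \sim p_0$, it is minimized when $s_\theta(x, t)_y = p_t(y)/p_t(x)$ for all $x \neq y$ such that $Q_t(x, y) > 0$. If $Q_t$ is a sparse matrix, then the expression for $\mathcal{L}_{\text{DWDSE}}$ can be simplified accordingly, since most state transitions are impossible.

For the diffusion and forward probabilities defined above,

$$-\log p^\theta_0(x_0) \leq \mathcal{L}_{\text{DWDSE}}(x_0) + D_{\text{KL}}\!\left( p_{T|0}(\cdot \mid x_0) \,\|\, p_{\text{base}} \right),$$

where $D_{\text{KL}}$ is the Kullback–Leibler divergence. In particular, when $p_{T|0}(\cdot \mid x_0) \approx p_{\text{base}}$, minimizing the expectation of $\mathcal{L}_{\text{DWDSE}}(x_0)$ over $x_0 \sim p_0$ minimizes an upper bound on the expected negative log-likelihood $\mathbb{E}_{x_0 \sim p_0}\left[ -\log p^\theta_0(x_0) \right]$.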

Adaptation to sequence modeling

The most common application of discrete diffusion is sequence modeling. Here, the discrete state space usually has a particular structure that can and must be exploited. For example, in language modeling, there are only finitely many different tokens allowed. The set of allowed tokens is called the vocabulary $\mathcal{V}$, and its size is the vocabulary size $V = |\mathcal{V}|$, which is always finite. For a given sequence length $L$, the state space is the space $\mathcal{V}^L$ of all length-$L$ sequences of tokens, which is of size $V^L$.

Forward process

Since the size of the state space grows exponentially with sequence length, it is too large to be directly modeled. For example, if a sequence has 10 tokens, and each token can be chosen from a list of 100 valid tokens, then the full state space has size $100^{10} = 10^{20}$, which is intractable.

Because of this, the standard method is to consider only tokenwise forward processes, i.e. those that factor into independent forward processes over each token. Tokenwise forward processes do not need each token to undergo the same forward process, though in practice, often all tokens undergo the same forward process.

Let the sequence have $L$ tokens. Let $i$ index over the tokens in the sequence, so that $x = (x^1, \dots, x^L)$. Let the forward process for token $i$ be defined by the rate matrix $Q^i_t$. Then the rate matrix for the full sequence satisfies

$$Q_t(y, x) = \sum_{i=1}^{L} Q^i_t(y^i, x^i) \prod_{j \neq i} \delta_{y^j x^j}.$$

Intuitively, the rate matrices are simply added together, since the probability that two jumps occur during the same infinitesimal slice of time is on the order of $dt^2$, which is infinitesimal compared to the probability that one jump occurs, which is on the order of $dt$.

In language modeling, the vocabulary size $V$ is usually on the order of 100,000, which means that an arbitrary $V \times V$ matrix is too large to fit into memory. The only case in common use is therefore where all tokens use the exact same rate matrix $Q^{\text{tok}}$, which is equal to one of the aforementioned cases $Q^{\text{uniform}}$ or $Q^{\text{absorb}}$. That is, there exists some function $\sigma$ such that $Q^i_t = \sigma(t)\, Q^{\text{tok}}$ for all token indices $i$ and all times $t$.

Given this setup, the forward process factors tokenwise:

$$p_{t|0}(x_t \mid x_0) = \prod_{i=1}^{L} p^{\text{tok}}_{t|0}(x^i_t \mid x^i_0).$$

Note that the factorization is conditional on $x_0$. Without the conditioning, it fails, because the initial distribution does not factor tokenwise as $p_0(x_0) = \prod_i p^i_0(x^i_0)$. Thus, the following do not factorize: the backward process $\bar{Q}_t$, the marginalized distribution $p_t$, and the score $p_t(y)/p_t(x)$.
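The additive structure of the sequence rate matrix can be made concrete on a toy example. The sketch below assumes every token shares one per-token rate matrix `Q_tok` and builds the full matrix by summing single-token contributions; the function name `sequence_rate_matrix` and the toy sizes (vocabulary of 3, length 2) are illustrative, since real vocabularies are far too large for this explicit construction.

```python
import itertools
import numpy as np

def sequence_rate_matrix(Q_tok, length):
    """Full rate matrix over all sequences, built from one per-token matrix.

    Entry [y, x] is the sum over positions i of Q_tok[y_i, x_i], counted
    only when y agrees with x at every other position.
    """
    V = Q_tok.shape[0]
    states = list(itertools.product(range(V), repeat=length))
    index = {s: k for k, s in enumerate(states)}
    Q = np.zeros((len(states), len(states)))
    for x in states:
        for i in range(length):
            for a in range(V):
                y = x[:i] + (a,) + x[i + 1:]   # x with position i set to a
                Q[index[y], index[x]] += Q_tok[a, x[i]]
    return Q, states

Q_tok = np.ones((3, 3)) / 3 - np.eye(3)        # uniform noising per token
Q_seq, states = sequence_rate_matrix(Q_tok, length=2)
```

For length 2 the result coincides with the Kronecker sum `np.kron(Q_tok, I) + np.kron(I, Q_tok)`, and every entry linking sequences that differ at both positions is zero.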

Score function

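In general, the discrete score is not tokenwise, i.e. in general, there does not exist some function $s^i$ such that

$$\frac{p_t(y)}{p_t(x)} = \prod_{i=1}^{L} s^i(x^i, t)_{y^i}.$$

Nevertheless, in this case, variational inference allows a simplification. Specifically, since $Q_t(x, y) = 0$ when $x, y$ differ at more than 1 token, the summation in the definition of $\mathcal{L}_{\text{DWDSE}}$ need only include the cases where $x_t, y$ differ at exactly 1 token. That is, the training process need only minimize the following loss:

$$\int_0^T \mathbb{E}_{x_t \sim p_{t|0}(\cdot \mid x_0)} \sum_{i=1}^{L} \sum_{a \neq x^i_t} Q^{\text{tok}}_t(x^i_t, a) \left( s_\theta(x_t, t)_{i,a} - r_{i,a} \log s_\theta(x_t, t)_{i,a} + K(r_{i,a}) \right) dt, \qquad r_{i,a} = \frac{p_{t|0}(x_t^{i \to a} \mid x_0)}{p_{t|0}(x_t \mid x_0)},$$

where $Q^{\text{tok}}_t$ is the shared per-token rate matrix at time $t$, $K(a) = a(\log a - 1)$, and in the notation, $x^{i \to a}$ means the sequence obtained by replacing the $i$-th entry of $x$ by $a$. Consequently, the score-matching model needs to output only on the order of $L V$ scores, one for each 1-token modification, instead of on the order of $V^L$ scores, one for each full-sequence modification. If $Q^{\text{tok}}$ is sparse, then the expression can be simplified further, such as when $Q^{\text{tok}} = Q^{\text{absorb}}$, since most single-token transitions are impossible.

The theoretical minimum is achieved when $s_\theta(x, t)_{i,a} = p_t(x^{i \to a})/p_t(x)$ for all $x, i, a$ such that $a \neq x^i$ and $Q^{\text{tok}}_t(x^i, a) > 0$.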

Backward process

Ideally, if the model learns the score exactly, then it defines a backward diffusion process that exactly reverses the forward diffusion process. Its rate matrix is

$$\bar{Q}_t(x^{i \to a}, x) = \frac{p_t(x^{i \to a})}{p_t(x)}\, Q^{\text{tok}}_t(x^i, a) \quad \text{for } a \neq x^i,$$

with $\bar{Q}_t(y, x) = 0$ whenever $y$ and $x$ differ at more than one token. Here $x^{i \to a}$ is the sequence $x$ with its $i$-th token replaced by $a$, and $Q^{\text{tok}}_t$ is the per-token forward rate matrix. Intuitively, since the forward process only changes 0 or 1 tokens at any moment in time, so does the backward process.

However, if the score is not exactly matched, i.e. $s_\theta(x, t)_y \neq p_t(y)/p_t(x)$, then it produces a score-matching error.

Furthermore, the backward process in practice cannot be performed in continuous time, but only in discrete time. This produces a time-discretization error. The Gillespie algorithm cannot in general perform the backward process exactly, since for fixed $x$, the score $s_\theta(x, t)$ changes as $t$ changes. That is, the backward CTMC cannot be solved exactly as a backward discrete-time Markov chain. This contrasts with the forward process, where for $Q_t = \sigma(t)\, Q$, the forward CTMC is exactly solvable as a discrete-time Markov chain. This is similar to how in continuous diffusion, the forward diffusion is exactly computable at discrete time instances, but the backward process requires integration over continuous time.

Using the Gillespie algorithm, or other discrete-time algorithms, produces an Euler method approximation error. It can be reduced by using better stochastic integrators and using more integration steps in the backward process.

Similar to how the Gillespie algorithm can be accelerated by tau-leaping, the backward process can be accelerated by changing more than 1 token per discretized time-step.

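Given the noising process $Q_t = \sigma(t)\, Q$, we have $p_t = e^{\sigma^{\Delta t}_t Q}\, p_{t - \Delta t}$, where $\sigma^{\Delta t}_t = \bar{\sigma}(t) - \bar{\sigma}(t - \Delta t)$ and $\bar{\sigma}(t) = \int_0^t \sigma(s)\, ds$. By Bayes's theorem, it is inverted by the discrete Tweedie's formula:

$$p_{t - \Delta t \mid t}(y \mid x) = \left( e^{-\sigma^{\Delta t}_t Q}\, s_t(x) \right)_y \; e^{\sigma^{\Delta t}_t Q}(x, y), \qquad s_t(x)_z = \frac{p_t(z)}{p_t(x)}.$$

This gives the Tweedie tau-leaping algorithm, which, for each backward time-interval $[t - \Delta t, t]$, randomly and independently transitions each token $x^i_t$. That is, for each token individually, sample $x^i_{t - \Delta t}$ independently of the other tokens, with the probability of transitioning from $x^i_t$ to $a$ equal to

$$\left( e^{-\sigma^{\Delta t}_t Q^{\text{tok}}}\, s_\theta(x_t, t)_i \right)_a \; e^{\sigma^{\Delta t}_t Q^{\text{tok}}}(x^i_t, a),$$

where $s_\theta(x_t, t)_i$ is the vector of single-token scores at position $i$, with the entry for the current token $x^i_t$ equal to 1.

The lower accuracy of tau-leaping is due to the approximation by tokenwise independence. In general, given $x_t$, the score does not factor tokenwise: $\frac{p_t(y)}{p_t(x_t)} \neq \prod_i \frac{p_t(x_t^{i \to y^i})}{p_t(x_t)}$, where $x_t^{i \to y^i}$ is $x_t$ with its $i$-th token replaced by $y^i$. Concretely, suppose that in an exact backward simulation, exactly two tokens $i$ and $j$ change during a backward time-interval $[t - \Delta t, t]$. Then it matters whether token $i$ or token $j$ changed first, since that affects the rate of change at the other token.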
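For a single token with an exact score, the discrete Tweedie formula can be verified directly. This sketch assumes the uniform rate matrix $Q = \tfrac{1}{n}\mathbf{1}\mathbf{1}^\top - I$, for which $e^{cQ} = I + (1 - e^{-c}) Q$ because $Q^2 = -Q$; the names `exp_uniform` and `tweedie_step_probs`, the step size `delta`, and the distribution `p_prev` are illustrative assumptions.

```python
import numpy as np

def exp_uniform(c, n):
    """exp(c * Q) for Q = ones / n - I, using the identity Q @ Q = -Q."""
    Q = np.ones((n, n)) / n - np.eye(n)
    return np.eye(n) + (1.0 - np.exp(-c)) * Q

def tweedie_step_probs(score_x, x, delta, n):
    """Probabilities over y of x_{t - dt} = y given x_t = x.

    score_x[z] = p_t(z) / p_t(x), so score_x[x] = 1.
    delta is the noise accumulated over the step.
    """
    back = exp_uniform(-delta, n) @ score_x   # p_{t-dt}(y) / p_t(x)
    fwd = exp_uniform(delta, n)[x, :]         # p(x_t = x | x_{t-dt} = y)
    return back * fwd

n, delta = 4, 0.3
p_prev = np.array([0.7, 0.2, 0.1, 0.0])       # distribution at time t - dt
p_t = exp_uniform(delta, n) @ p_prev          # distribution at time t

# Row x holds the posterior over x_{t-dt} given x_t = x, using the exact score.
posteriors = np.stack(
    [tweedie_step_probs(p_t / p_t[x], x, delta, n) for x in range(n)]
)
recovered = posteriors.T @ p_t                # push p_t back through the step
```

Each row of `posteriors` sums to 1 and `recovered` matches `p_prev`, so with an exact score a single-token step is exact. The approximation in tau-leaping enters only when several tokens are updated independently within the same step.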

Conditional backward process

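The above framework considers unconditional generation, that is, sampling a full sequence approximately from the data distribution $p_0$. Certain tasks require generating part of a sequence while holding other parts of the sequence fixed. For example, in prompt engineering or few-shot learning, the first few sentences are fixed, and only the rest of the sequence can be changed.

In general, let $\Omega, \bar{\Omega}$ be a partition of the token indices $\{1, \dots, L\}$. Then the problem is to generate $x^{\Omega}$ conditional on $x^{\bar{\Omega}}$, that is, sampling from $p_0(x^{\Omega} \mid x^{\bar{\Omega}})$. By Bayes's theorem,

$$\frac{p_t(y^{\Omega} \mid x^{\bar{\Omega}})}{p_t(x^{\Omega} \mid x^{\bar{\Omega}})} = \frac{p_t(y)}{p_t(x)}$$

when $y^{\bar{\Omega}} = x^{\bar{\Omega}}$. Thus, a score-matching model for unconditional sequence generation is also a score-matching model for the conditional case:

$$s_\theta(x, t)_y \approx \frac{p_t(y^{\Omega} \mid x^{\bar{\Omega}})}{p_t(x^{\Omega} \mid x^{\bar{\Omega}})}$$

for any $y$ such that $y^{\bar{\Omega}} = x^{\bar{\Omega}}$. Thus, all the previous backward-process sampling algorithms still work in the conditional case, simply by fixing $x^{\bar{\Omega}}$.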

Error analysis

In general, the backward process of a discrete diffusion samples a probability distribution that differs from $p_0$. This creates the following sources of error:

  • The mixing error, due to beginning the backward diffusion process at $p_{\text{base}}$ rather than $p_T$.
  • The score-matching error, due to using the learned score $s_\theta$ rather than the true score to define the rate matrix of the backward diffusion process.
  • The time-discretization error, due to using discrete time, not continuous time, when integrating through the backward diffusion process. This is analogous to the error of the Euler algorithm.
  • The independence error, due to using tau-leaping, which erroneously allows more than one token to change at a time in a probabilistically independent way.
  • The early stopping error, due to ending the backward diffusion process at time $\epsilon > 0$ rather than at time $0$.

Each error can be decreased or eliminated, usually at the price of increased cost of compute.

  • Mixing error can be decreased by running the diffusion time for longer, or eliminated by the analog of the "zero SNR" fix for diffusion in continuous space.[4]
  • Score-matching error can be decreased by better training of the score-matching model.
  • Time-discretization error can be decreased by making smaller time-steps, or essentially eliminated by numerical integration plus the uniformization trick for CTMC.[5]
  • Independence error can be eliminated by only changing one token at a time during backward diffusion.
  • Early stopping error can be eliminated if the score does not diverge as $t \to 0$. If the score does diverge, then the stopping time can be decreased, but the time-steps would need to decrease in tandem.

References

  1. Gulrajani, Ishaan; Hashimoto, Tatsunori B. (2023-12-15). "Likelihood-Based Diffusion Language Models". Advances in Neural Information Processing Systems. 36: 16693–16715. arXiv:2305.18619.
  2. Campbell, Andrew; Benton, Joe; De Bortoli, Valentin; Rainforth, Thomas; Deligiannidis, George; Doucet, Arnaud (2022-12-06). "A Continuous Time Framework for Discrete Denoising Models". Advances in Neural Information Processing Systems. 35: 28266–28279.
  3. Lou, Aaron; Meng, Chenlin; Ermon, Stefano (2024-06-06). "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution". arXiv:2310.16834 [stat.ML].
  4. Lin, Shanchuan; Liu, Bingchen; Li, Jiashi; Yang, Xiao (2024). "Common Diffusion Noise Schedules and Sample Steps Are Flawed". IEEE/CVF Winter Conference on Applications of Computer Vision (WACV). pp. 5404–5411.
  5. Chen, Hongrui; Ying, Lexing (2024-02-14). "Convergence Analysis of Discrete Diffusion Model: Exact Implementation through Uniformization". arXiv:2402.08095. doi:10.48550/arXiv.2402.08095.