Improving Dense Retriever Training with Corrector Networks (2025)

Nicholas Monath  Will Grathwohl  Michael Boratko  Rob Fergus  Andrew McCallum  Manzil Zaheer

Abstract

In dense retrieval, deep encoders provide embeddings for both inputs and targets, and the softmax function is used to parameterize a distribution over a large number of candidate targets (e.g., textual passages for information retrieval). Significant challenges arise in training such encoders in the increasingly prevalent scenario of (1) a large number of targets, (2) a computationally expensive target encoder model, and (3) cached target embeddings that are out-of-date due to ongoing training of the target encoder parameters. This paper presents a simple and highly scalable response to these challenges: training a small parametric corrector network that adjusts stale cached target embeddings, enabling an accurate softmax approximation and thereby sampling of up-to-date, high-scoring “hard negatives.” We theoretically investigate the generalization properties of our proposed target corrector, relating the complexity of the network, the staleness of cached representations, and the amount of training data. We present experimental results on large benchmark dense retrieval datasets as well as on QA with retrieval augmented language models. Our approach matches state-of-the-art results even when no target embedding updates are made during training beyond an initial cache from the unsupervised pre-trained model, providing a 4-80x reduction in re-embedding computational cost.

Retrieval, Negative Mining

1 Introduction

The softmax function, paired with deep neural encoder models, is often the parameterization of choice for discrete distributions over many targets, such as in classification (Logeswaran et al., 2019; Yu et al., 2022), retrieval (Reddi et al., 2019; Xiong et al., 2020), or reinforcement learning (Dulac-Arnold et al., 2015; Gottipati et al., 2020). This approach, often called a “dual encoder,” employs two separate deep networks: one maps an input to a fixed-dimensional vector, the other maps targets to the same vector space. The softmax logits are then computed as the inner product between the input vector and each target vector (Gillick et al., 2019; Karpukhin et al., 2020; Xiong et al., 2020).

With the typical softmax cross-entropy loss, exact training of the parameters of these two encoder networks would involve using the current parameters to compute the logits for all targets, requiring running the target encoder on all targets at every step of training. Of course, this far-too-burdensome approach is not used in practice. Instead, various approximations have been developed (Reddi et al., 2019; Rawat et al., 2020; Lindgren et al., 2021; Xiong et al., 2020; Monath et al., 2023). The typical approximation computes a truncated softmax on a sampled subset of targets. These approaches store a cache of “stale” encoded representations of the targets and use the stale, cached representations to draw samples from the softmax-parameterized distribution during training (Lindgren et al., 2021; Izacard et al., 2022). Previous work has combined these stale representations with other approximations such as index structures (Xiong et al., 2020; Monath et al., 2023), kernel methods (Rawat et al., 2019), and focusing training on subsets of targets (Reddi et al., 2019). Inevitably, however, the staleness of the target embeddings causes training regret.

In this work, we present a simple, general-purpose method for addressing staleness in softmax-parameterized categorical distributions that is scalable enough to be updated at every step of training. Our approach improves upon an existing stale approximation using a learned target corrector network. The target corrector network, inspired by recent work on training continuous energy-based models (Han et al., 2020; Grathwohl et al., 2020, 2021), is a small parametric model that accounts for the discrepancy between the stale approximation and the unnormalized logits of the true distribution. By learning to improve upon the stale approximation, the target corrector network produces a more accurate approximation to the target distribution. We further extend the approach beyond training large output-space classifiers to latent-variable retrieval augmented language models.

In summary, the contributions of this paper are:

Methodological (§3) - We describe a novel training procedure for large output-space models, based on approximating softmax-parameterized categorical distributions with a parametric target corrector network that learns to improve stale approximations of the logits.

Theoretical (§4) - We analyze the generalization properties of corrector networks in terms of the discrepancy between the stale approximation and the true distribution, the complexity of the network, and the amount of training data.

Empirical (§5) - We evaluate our approach on training both dense retrieval models and latent-variable retrieval augmented language models. Our approach matches the performance of much more computationally intensive approaches at a fraction of the computational expense.

2 Background

Softmax  Given an input point $x$, a distribution over a set of $N$ targets, $\mathcal{Y}$, parameterized by the softmax function is:

$$P(y|x) = \frac{\exp(\beta s_{x,y})}{Z_x}, \quad Z_x \triangleq \sum_{y' \in \mathcal{Y}} \exp(\beta s_{x,y'}), \qquad (1)$$

where $\beta$ is the temperature. In this paper, we focus on applications in retrieval and latent variable models. For example, in Natural Questions (Kwiatkowski et al., 2019), $x$ refers to a question and the targets, $y$, correspond to Wikipedia passages.

Dual-Encoders  We compute the unnormalized logits, $s_{x,y}$, using a factorized representation. Deep parametric models, dual-encoders, map the input, $x$, and target, $y$, to $D$-dimensional vectors, denoted $f(x;\Theta)$ and $g(y;\Theta)$:

$$s_{x,y} = \langle f(x;\Theta), g(y;\Theta) \rangle. \qquad (2)$$

Training  For a task-specific loss, $\mathcal{L}$, such as cross-entropy, dual-encoder parameters are optimized by gradient descent (Rawat et al., 2019). However, exact computation of the normalizing constant, $Z_x$, is typically intractable during training, since it would require computing $g(y)$ for millions or billions of targets. Instead of $P(y|x)$ in $\mathcal{L}$, a tractable (yet biased) approximation is to optimize the truncated softmax, $\tilde{P}(y|x)$, which includes only a subset of targets $S(\mathcal{Y}) \subset \mathcal{Y}$:

$$\tilde{P}(y|x) = \frac{\exp(\beta s_{x,y})}{\sum_{y' \in S(\mathcal{Y})} \exp(\beta s_{x,y'})}. \qquad (3)$$

Uniform Sampling Approximation  A simple approach is to define $S(\mathcal{Y})$ to be a uniformly sampled subset of $\mathcal{Y}$ (Karpukhin et al., 2020). The method's bias decreases with more samples. However, since the samples are uniform, a large number of samples may be required.

Top-K / Similarity-based Sampling Approximations  We can instead use an informed strategy based on $g(y)$ that selects higher-probability targets, either by sampling with Gumbel-Max using the similarity scores (Lindgren et al., 2021) or by taking the top-$k$ targets by inner product (Xiong et al., 2020). Prior work has considered efficient approximations that find these top-$k$ targets without computing $g(y)$ for all $y \in \mathcal{Y}$ (Xiong et al., 2020; Monath et al., 2023).
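A minimal JAX sketch of these two selection strategies is given below. It assumes a vector `approx_logits` of inner products between $f(x)$ and the (possibly approximate) target embeddings; the function and variable names are illustrative and not taken from any released implementation.

```python
# Illustrative sketch of top-k and Gumbel-Max target selection, assuming
# `approx_logits` holds <f(x), g(y)> (or a stale approximation) for all targets.
import jax
import jax.numpy as jnp

def topk_targets(approx_logits, k):
    # Deterministic choice: the k targets with the largest (approximate) logits.
    _, idx = jax.lax.top_k(approx_logits, k)
    return idx

def gumbel_max_targets(rng, approx_logits, k, beta=1.0):
    # Gumbel-Max sampling: perturb the scaled logits with Gumbel noise and take
    # the top-k, drawing k distinct targets with probability proportional to
    # the softmax over the approximate logits.
    gumbel = -jnp.log(-jnp.log(jax.random.uniform(rng, approx_logits.shape)))
    _, idx = jax.lax.top_k(beta * approx_logits + gumbel, k)
    return idx
```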

Initialization  We initialize the parameters of the dual-encoders, $\Theta$, using pre-trained models, such as the pre-trained language models T5 and GTR (Devlin et al., 2019; Raffel et al., 2020; Ni et al., 2022).

Stale Cached Representations  While we are training the parameters, $\Theta$, each target's vector under $g(y;\Theta)$ changes at every step of training. A commonly used approach is therefore to define an approximation, $g'(y)$, that is a lookup into a cache of “stale” embeddings. A stale embedding comes from running the target encoder at a particular training step $t$ and caching the result, $g(y;\Theta_t)$, in a buffer $B \in \mathbb{R}^{|\mathcal{Y}| \times D}$, i.e., $g'(y_i) \triangleq B_{y_i}$. To find the top-$k$ targets for an input $x$, we compute approximate logits $B^\top f(x) \in \mathbb{R}^{|\mathcal{Y}|}$ and select the top-$k$ targets to define $S(\mathcal{Y})$. Even before training, we can use the pre-trained model to produce embeddings for all targets, $B_{y_i} = g(y_i;\Theta_0)$. While $B$ may seem large, this is considerably more efficient than exact computation and is feasible, on accelerators, for $|\mathcal{Y}|$ in the tens of millions.
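The following sketch illustrates this stale-buffer bookkeeping; the helper names (`encode_target`, `corpus`) are assumptions for exposition, and in practice the one-time encoding pass is sharded across accelerators.

```python
# Illustrative sketch of the stale buffer B: filled once with g(.; Theta_0),
# after which g'(y) is a row lookup and approximate logits are one matvec.
import jax.numpy as jnp

def build_buffer(encode_target, corpus):
    # One-time (shardable) pass over all targets with the current parameters.
    return jnp.stack([encode_target(y) for y in corpus])  # B: [|Y|, D]

def stale_embedding(B, y_idx):
    # g'(y) := B_y, a cache lookup with no encoder forward pass.
    return B[y_idx]

def approximate_logits(B, f_x):
    # B^T f(x): approximate scores for all |Y| targets at once.
    return B @ f_x
```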

The bias of this approach (and the resulting degradation in performance) depends on the staleness, or drift, of the embeddings, i.e., $\|B_{y_i} - g(y_i)\|$, which grows as we update the parameters of $g(y)$. This can be mitigated by recomputing $B$ periodically throughout training (at notable cost). Periodic recomputation has been used in prior work (Guu et al., 2020; Izacard et al., 2022; Monath et al., 2023), but there is still much room for improvement.

3 Improving Training with Target Correctors

Our proposed approach builds upon these stale buffer approximations by using an additional parametric model. The additional model aims to improve upon the stale $g'(y)$ to yield a better approximation of $g(y)$.

We refer to this additional parametric model as a target corrector network, $h(\cdot;\Psi)$, or simply $h(\cdot)$ when the parameters $\Psi$ are not pertinent. The target corrector network takes as input the existing stale vector embedding, $g'(y)$, and yields the following approximation of the softmax function:

$$P_h(y|x) \propto \exp(\beta \langle f(x), h \circ g'(y) \rangle). \qquad (4)$$

With significantly fewer parameters than a typical dual-encoder, i.e., $|\Psi| \ll |\Theta|$, this small parametric model is efficient enough to provide approximately fresh representations of every target at every training step. The target corrector network raises interesting research questions: whether the network can obviate the need for re-embedding, what kinds of staleness or drift it can effectively model, and how much training data it requires.
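Concretely, the corrector can be as small as a two-layer MLP over stale embeddings. The sketch below is illustrative only; the hidden width, initialization scale, and ReLU nonlinearity are assumptions rather than a prescribed configuration.

```python
# Illustrative sketch of a corrector network h(.; Psi) as a small two-layer MLP.
import jax
import jax.numpy as jnp

def init_corrector(rng, dim, hidden=None):
    # Psi: weights of a two-layer MLP mapping R^D -> R^D.
    hidden = hidden or dim
    k1, k2 = jax.random.split(rng)
    return {
        "W1": 0.02 * jax.random.normal(k1, (dim, hidden)),
        "b1": jnp.zeros(hidden),
        "W2": 0.02 * jax.random.normal(k2, (hidden, dim)),
        "b2": jnp.zeros(dim),
    }

def corrector(psi, stale_emb):
    # Maps stale embeddings g'(y) (shape [..., D]) to corrected embeddings.
    hidden = jax.nn.relu(stale_emb @ psi["W1"] + psi["b1"])
    return hidden @ psi["W2"] + psi["b2"]
```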

[Figure 1]

Warmup: Training the corrector network in isolation  We begin by considering how to train only the parameters of the target corrector network, independently of the dual-encoders $f(x)$ and $g(y)$; afterwards, we present an algorithm for jointly training the target corrector network and the dual-encoders. To train the parameters $\Psi$ of the corrector network, $h(\cdot;\Psi)$, we collect training examples consisting of input embeddings $f(x_i)$, exact target embeddings $g(y)$, and stale embeddings $g'(y)$ for a subset of targets $S(\mathcal{Y})_i \subset \mathcal{Y}$, i.e., $\{(f(x_i), g(y_b), g'(y_b)) \mid y_b \in S(\mathcal{Y})_i\}$.

We consider two loss functions for training $h$: the mean-squared error between the representations given by $g(y)$ and the corrected representations $h \circ g'(y)$ (Eq. 5), and the cross-entropy between the truncated softmax using $g(y)$ and the truncated softmax using $h \circ g'(y)$ (Eq. 6):

$$\ell_{\mathrm{MSE}}(y_i) = \| g(y_i) - h \circ g'(y_i;\Psi) \|_2 \qquad (5)$$
$$\ell_{\mathrm{d}}(y_i) = \log \tilde{P}(y|x) - \log \tilde{P}_h(y|x) \qquad (6)$$
$$\tilde{P}_h(y|x) = \frac{\exp(\beta \langle f(x), h \circ g'(y;\Psi) \rangle)}{\sum_{y' \in S(\mathcal{Y})_i} \exp(\beta \langle f(x), h \circ g'(y';\Psi) \rangle)}.$$

where $\tilde{P}(y|x)$ is the truncated softmax using $g(y)$ (Eq. 3). The mean-squared error loss directly tries to match the target encoder's embeddings. The cross-entropy loss down-weights the importance of targets $y$ that do not contribute substantial probability to $P(y|x)$ and allows for greater use of model capacity. The parameters of the target corrector network are optimized using gradient descent. Empirically, we find the cross-entropy objective to perform slightly better (Table 1) and focus the presentation on cross-entropy.
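The sketch below illustrates both losses for a single input and its shortlisted targets; the array names and shapes are assumptions for exposition.

```python
# Illustrative sketch of the corrector losses (Eqs. 5-6) for one input x.
# `fresh` holds g(y), `corrected` holds h(g'(y)) for the shortlist (shape [k, D]),
# and `f_x` holds f(x) (shape [D]).
import jax
import jax.numpy as jnp

def corrector_mse(corrected, fresh):
    # Eq. 5: directly match the fresh target embeddings.
    return jnp.mean(jnp.linalg.norm(fresh - corrected, axis=-1))

def corrector_xent(f_x, corrected, fresh, beta=1.0):
    # Eq. 6: cross-entropy between the truncated softmax under the fresh
    # embeddings (target distribution, no gradient) and the truncated softmax
    # under the corrected embeddings.
    p_true = jax.lax.stop_gradient(jax.nn.softmax(beta * fresh @ f_x))
    logp_h = jax.nn.log_softmax(beta * corrected @ f_x)
    return -jnp.sum(p_true * logp_h)
```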

Jointly Training Corrector Networks & Dual-Encoders  We present a method (Algorithm 1) for simultaneously training dual-encoders for a given task (e.g., retrieval, or equivalently large output-space classification) and the target corrector network. The training algorithm optimizes the parameters of the target corrector network and additionally uses the corrector network to approximate the softmax. Each step consists of: (1) using the corrector network to provide an approximately updated representation of every target, (2) picking a subset of targets for the truncated softmax using the output of the corrector network, (3) computing a task loss for the dual-encoder models and a loss for the corrector network, and (4) updating, according to their respective losses, the parameters of both the dual-encoders and the corrector network using gradient descent.

In more detail, we are given task training data, $X = \{(x_1,y_1),\dots,(x_m,y_m)\}$, a task loss function $\mathcal{L}$, and a corrector network loss $\ell$. The dual-encoder models are $f(x)$ and $g(y)$, with initial parameters $\Theta_0$. Prior to the first training step, we instantiate a buffer of target representations, $B_y = g'(y) = g(y;\Theta_0)$. We avert the need for the expensive updating of this buffer, i.e., re-embedding targets with the target encoder. In each step, we sample a training point and label pair $(x_i, y_i)$ from $X$. We apply the target corrector network to all of the stale representations in the buffer to obtain $h \circ g'(y)$ for all $y \in \mathcal{Y}$. This computation does not require running a dual-encoder; we use the cached buffer representation of each target as input to the corrector network. The corrector network is typically a two-layer MLP and hence efficient enough to be used in this way. With these representations from $h(\cdot)$, we sample (or select the exact top-$k$) targets according to $P_h(y|x)$ (Eq. 4) to form a subset of targets $S_{x_i}(\mathcal{Y})$ for the truncated softmax.

Given this subset, we compute the task and correction losses and update their respective model parameters. First, we compute the task loss, which is cross-entropy. The task loss is used only to update the parameters of the dual-encoders, $\Theta$, not the parameters of the target corrector network. We compute the truncated softmax $\tilde{P}(y|x) \propto \exp(\beta \langle f(x), g(y) \rangle)$ (Eq. 3). We define a one-hot distribution $P^\star$ according to the training label $y_i$. We compute the task-specific loss $\mathcal{L}$ as a function of $\tilde{P}$ and $P^\star$, and update the dual-encoder parameters via gradient descent: $\Theta \leftarrow \Theta - \eta \nabla_\Theta \mathcal{L}$.

Next, we use the same sample of targets $S_{x_i}(\mathcal{Y})$ to compute the target corrector network's loss and parameter update. Importantly, this update affects only the parameters of the target corrector network, $\Psi$, not the parameters of the dual-encoders. Here we describe the cross-entropy loss; an analogous update could be used for other loss functions. We compute the truncated softmax according to the target corrector network's output, $\tilde{P}_h(y|x) \propto \exp(\beta \langle f(x), h \circ g'(y) \rangle)$. We then compute the target corrector network loss, $\ell$ (cross-entropy), which aligns the two truncated distributions $\tilde{P}_h$ and $\tilde{P}$. The target corrector network's parameters are updated by gradient descent: $\Psi \leftarrow \Psi - \eta \nabla_\Psi \ell$.

Training the target corrector network, which has only a small number of parameters, is much less computationally intensive than training the dual-encoder model. Furthermore, the representations $g(y)$ come “for free,” since they are already computed for $\tilde{P}$ in the task loss; they can easily be re-used for training the corrector.

The training procedure is summarized in Algorithm 1. At prediction time, the corrector network is not used; instead, the trained dual-encoder $g(y;\Theta)$ is used.

Data: Training data $X$, targets $\mathcal{Y}$, input encoder $f(\cdot)$, target encoder $g(\cdot)$, approximate target encoder $g'(\cdot)$ (buffer $B$), target corrector network $h(\cdot)$, temperature $\beta$, task loss $\mathcal{L}$, target corrector network loss $\ell$, learning rate $\eta$, number of truncated samples $k$

while Training do
  Sample training data $(x_i, y_i) \sim X$
  Compute $h \circ g'(y)$ for all $y \in \mathcal{Y}$ using the buffer $B$
  Set $S_{x_i}(\mathcal{Y})$ using $\exp(\beta f(x_i)^\top h \circ g'(y))$ via top-$k$
  Include supervised label: $S_{x_i}(\mathcal{Y}) \leftarrow S_{x_i}(\mathcal{Y}) \cup \{y_i\}$
  Define $\tilde{P}(y|x_i) = \frac{\exp(\beta f(x_i)^\top g(y))}{\sum_{y' \in S_{x_i}(\mathcal{Y})} \exp(\beta f(x_i)^\top g(y'))}$
  Define $\tilde{P}_h(y|x_i) = \frac{\exp(\beta f(x_i)^\top h \circ g'(y))}{\sum_{y' \in S_{x_i}(\mathcal{Y})} \exp(\beta f(x_i)^\top h \circ g'(y'))}$
  Define $P^\star$ to be a one-hot vector for $y_i$
  Compute task loss $\mathcal{L}$ using $\tilde{P}$ and $P^\star$
  Compute correction loss $\ell$ using $\tilde{P}$ and $\tilde{P}_h$
  Update dual-encoder parameters $\Theta \leftarrow \Theta - \eta \nabla_\Theta \mathcal{L}$
  Update corrector network parameters $\Psi \leftarrow \Psi - \eta \nabla_\Psi \ell$
end while
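For concreteness, the following is a minimal JAX sketch of one step of Algorithm 1. Here `encode_query`, `encode_targets`, and `h_apply` are assumed helpers standing in for $f$, $g$ (restricted to the shortlisted indices), and $h$, and plain SGD is written out to mirror the two separate updates; it is an illustration under these assumptions, not the released implementation.

```python
# Illustrative sketch of one joint training step (Algorithm 1) in JAX.
import jax
import jax.numpy as jnp

def train_step(theta, psi, B, x_i, y_i, encode_query, encode_targets, h_apply,
               beta=1.0, k=64, lr=1e-4):
    f_x = encode_query(theta, x_i)                           # f(x_i), [D]
    corrected_all = h_apply(psi, B)                          # h(g'(y)) for all y, [|Y|, D]
    _, shortlist = jax.lax.top_k(corrected_all @ f_x, k)     # S_{x_i}(Y) via top-k
    shortlist = jnp.concatenate([shortlist, jnp.array([y_i])])  # keep the gold label

    def task_loss(theta_):
        f_ = encode_query(theta_, x_i)
        fresh = encode_targets(theta_, shortlist)            # g(y) for y in S, [k+1, D]
        # Cross-entropy of the truncated softmax against the one-hot label
        # (the gold label is the last entry of the shortlist).
        return -jax.nn.log_softmax(beta * fresh @ f_)[-1]

    def corrector_loss(psi_):
        fresh = encode_targets(theta, shortlist)             # no gradient to Theta here
        p_true = jax.nn.softmax(beta * fresh @ f_x)
        logp_h = jax.nn.log_softmax(beta * h_apply(psi_, B[shortlist]) @ f_x)
        return -jnp.sum(p_true * logp_h)

    theta = jax.tree_util.tree_map(lambda p, g: p - lr * g,
                                   theta, jax.grad(task_loss)(theta))
    psi = jax.tree_util.tree_map(lambda p, g: p - lr * g,
                                 psi, jax.grad(corrector_loss)(psi))
    return theta, psi
```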

3.1 Latent Variables in Retrieval Augmented Models

Retrieval augmented language models (RLMs) typically consist of two major architectural components: a retriever model (e.g., a dual-encoder) and a generative language model, or reader model (Guu et al., 2020; Izacard & Grave, 2021; Izacard et al., 2022). The input to a retrieval augmented language model is a natural language text sequence, $x$. This input text is encoded using a dual-encoder retrieval model, $f(x)$. Retrieval is performed over a corpus of targets, $\mathcal{Y}$, returning $k$ targets relevant to $x$, denoted $S_x(\mathcal{Y})$. The reader model takes as input the retrieved targets, $S_x(\mathcal{Y})$, and the text $x$, and generates text.

Concretely, in our experiments, the input text $x$ is a question. The retrieval corpus contains targets $y$ corresponding to passages in Wikipedia. The reader model takes as input the question and the retrieved passages and generates a short answer to the question. We present the remainder of the section with this question-answering task in mind.

RLMs can be formalized as latent variable models. The softmax function is used to parameterize the distribution over a discrete latent variable, which corresponds to the retrieved targets. We use $a$ to refer to the generated sequence of text, i.e., the generated answer:

$$P(a|x) = \sum_{y \in S_x(\mathcal{Y})} P(a|y,x)\, P(y|x). \qquad (7)$$

$P(a|y,x)$ is an autoregressive language model. $P(y|x)$ is computed by the softmax with logits from Eq. 2, using the encoder models $f(x)$ and $g(y)$.

When training RLMs, we receive supervision in the form of question-answer pairs, $(x_i, a_i) \sim X$. We do not receive supervision on which targets $S_x(\mathcal{Y})$ should be retrieved. We learn the parameters of both the reader and the retriever using these supervised question/answer pairs.

To train the reader and retriever, we use perplexity distillation (Izacard et al., 2022) for the retriever loss and negative log-likelihood for the reader loss. Perplexity distillation is computed as the cross-entropy between two truncated distributions: the retriever's $\tilde{P}(y|x)$ (Eq. 3), and a distribution that uses the reader model to provide a soft relevance label for each target in $S_x(\mathcal{Y})$:

$$P_a(y|x) = \frac{P(a|y,x)}{\sum_{y' \in S_x(\mathcal{Y})} P(a|y',x)}. \qquad (8)$$

In words, $P_a(y|x)$ normalizes the likelihood of the reader model generating the correct answer text when conditioned on each retrieved target $y$. The reader's loss, negative log-likelihood, is simply computed using the supervised answer text. The two losses are averaged and the parameters optimized with gradient descent.

To facilitate efficient training, we use our proposed target corrector network to select the subset of retrieved targets $S_x(\mathcal{Y})$ used at training time. This is done in the same way as in Algorithm 1: we pick a subset of $k$ targets $S_x(\mathcal{Y})$ for $x$ according to $\exp(\beta f(x)^\top h \circ g'(y))$ via top-$k$ or Gumbel-Max sampling. Simple modifications to Algorithm 1, presented in Algorithm 2, train the RLM: we compute two task-specific losses (perplexity distillation and negative log-likelihood) and optimize both the reader and retriever parameters. We use cross-entropy to train the corrector, which is again only used at training time; at prediction time, the trained retriever model is used.
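A minimal sketch of the perplexity distillation loss for the retriever follows; it assumes the reader's log-likelihoods of the gold answer and the retriever's logits have already been computed for the shortlisted passages (names are illustrative).

```python
# Illustrative sketch of perplexity distillation (Eq. 8 as soft labels),
# assuming reader_logprobs[j] = log P(a | y_j, x) and
# retriever_logits[j] = beta * <f(x), g(y_j)> for the shortlisted passages.
import jax
import jax.numpy as jnp

def perplexity_distillation_loss(reader_logprobs, retriever_logits):
    # Soft relevance labels from the reader: normalize the reader's likelihood
    # of the gold answer across the retrieved passages (Eq. 8).
    p_a = jax.lax.stop_gradient(jax.nn.softmax(reader_logprobs))
    # Cross-entropy between the reader-derived labels and the retriever's
    # truncated softmax; gradients flow only to the retriever logits.
    return -jnp.sum(p_a * jax.nn.log_softmax(retriever_logits))
```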

Data: Training data $X$, targets $\mathcal{Y}$, retriever and reader parameters $\Theta$, correction model parameters $\Psi$, input encoder $f(\cdot)$, target encoder $g(\cdot)$, approximate target encoder $g'(\cdot)$ (buffer $B$), corrector model $h(\cdot)$, temperature $\beta$, retriever loss $\mathcal{L}$, reader loss $\mathcal{L}'$, corrector model loss $\ell$, learning rate $\eta$, number of truncated samples $k$

while Training do
  Sample training data $(x_i, a) \sim X$
  Compute $h \circ g'(y)$ for all $y \in \mathcal{Y}$ using the buffer $B$
  Set $S_{x_i}(\mathcal{Y})$ using $\exp(\beta f(x_i)^\top h \circ g'(y))$ via top-$k$
  Define $\tilde{P}(y|x_i) = \frac{\exp(\beta f(x_i)^\top g(y))}{\sum_{y' \in S_{x_i}(\mathcal{Y})} \exp(\beta f(x_i)^\top g(y'))}$
  Define $\tilde{P}_h(y|x_i) = \frac{\exp(\beta f(x_i)^\top h \circ g'(y))}{\sum_{y' \in S_{x_i}(\mathcal{Y})} \exp(\beta f(x_i)^\top h \circ g'(y'))}$
  Define $P_a(y|x) = \frac{P(a|y,x)}{\sum_{y' \in S_x(\mathcal{Y})} P(a|y',x)}$
  Define $P_{\text{LM}}(a|x) = \sum_{y \in S_x(\mathcal{Y})} P(a|y,x)\, P(y|x)$
  Compute reader loss $\mathcal{L}'$ using $P_{\text{LM}}(a|x)$
  Compute retriever loss $\mathcal{L}$ using $\tilde{P}(y|x_i)$ and $P_a(y|x)$
  Compute correction loss $\ell$ using $\tilde{P}$ and $\tilde{P}_h$
  Update retriever & reader params $\Theta \leftarrow \Theta - \eta \nabla_\Theta \frac{\mathcal{L} + \mathcal{L}'}{2}$
  Update corrector network params $\Psi \leftarrow \Psi - \eta \nabla_\Psi \ell$
end while
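The joint update at the end of Algorithm 2 can be sketched as follows, assuming the retriever and reader losses are exposed as functions of the shared parameters $\Theta$; this is an illustration under that assumption, not the exact implementation.

```python
# Illustrative sketch of the joint retriever/reader update in Algorithm 2:
# Theta <- Theta - eta * grad of the averaged losses (L + L') / 2.
import jax
import jax.tree_util as jtu

def rlm_update(theta, retriever_loss_fn, reader_loss_fn, lr=1e-4):
    def combined(theta_):
        # Algorithm 2 averages the retriever loss L and the reader loss L'.
        return 0.5 * (retriever_loss_fn(theta_) + reader_loss_fn(theta_))
    grads = jax.grad(combined)(theta)
    return jtu.tree_map(lambda p, g: p - lr * g, theta, grads)
```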

4 Analysis

We explore the generalization of the proposed target corrector network to unseen targets for a particular input data point, and show the relationship between the generalization error, the complexity of the target corrector network $h$, and the discrepancy between the stale representations, $g'$, and the true representations, $g$. All proofs are in Appendix A.

Let $\ell : \mathbb{R} \times \mathbb{R} \to \mathbb{R}$ be a loss function for the target corrector network (Eqs. 5 & 6).

For any point $x$, consider the distribution given by the softmax using the stale approximation $g'$:

$$\tilde{\mathscr{D}}_Y \triangleq P_{g'}(y|x) = \frac{\exp(\beta \langle f(x), g'(y) \rangle)}{\sum_{y' \in \mathcal{Y}} \exp(\beta \langle f(x), g'(y') \rangle)}, \qquad (9)$$

and similarly define $\mathscr{D}_Y \triangleq P_g(y|x)$ as the true distribution, using $g$ (Eq. 1).

We begin by defining three kinds of risk.

Empirical Risk  On a set of $n$ targets $\tilde{\mathscr{S}}_n = \{y_1, \dots, y_n\}$ sampled from $\tilde{\mathscr{D}}_Y$, we minimize the empirical risk:

$$R_{\ell,\phi}(\tilde{\mathscr{S}}_n) = \frac{1}{n} \sum_{i=1}^{n} \ell(\phi(y_i), g(y_i)), \qquad (10)$$

over a function class $\phi \in \mathcal{F}$.

True Population Risk  For generalization error, we are interested in how large the true population risk can become over a function class $\phi \in \mathcal{F}$:

$$R_{\ell,\phi}(\mathscr{D}_Y) = \mathbb{E}_{Y \sim \mathscr{D}_Y}\left[\ell(\phi(Y), g(Y))\right]. \qquad (11)$$

We consider the above quantity because we want to ensure good alignment between $g(y)$ and $\phi(y)$ wherever there is non-trivial probability mass under the true distribution.

Stale Population Risk  The stale population risk is defined analogously to the true population risk, with $\tilde{\mathscr{D}}_Y$ as the distribution, over a function class $\phi \in \mathcal{F}$:

$$R_{\ell,\phi}(\tilde{\mathscr{D}}_Y) = \mathbb{E}_{Y \sim \tilde{\mathscr{D}}_Y}\left[\ell(\phi(Y), g(Y))\right]. \qquad (12)$$

Function Classes  The function class $\phi \in \mathcal{F}$ is large. We relate this large function class to a restricted class of functions of the form $h \circ g'$ by leveraging the approximate stale representations, $g'$. In other words, we restrict $\mathcal{F}$ to $\mathcal{F}^{g'} = \{h \circ g' : h \in \mathcal{H}\}$, where $\mathcal{H}$ is a simpler class of functions mapping $\mathbb{R}^d \to \mathbb{R}^d$ that can express the discrepancy between the stale $g'$ and the current $g$.

First, we provide a bound on the gap between the true population risk and the stale population risk, formalized in the following lemma. For ease of notation, we define $\mathcal{G}_{\ell,\mathcal{F}}$ as the induced function class: $\mathcal{G}_{\ell,\mathcal{F}} = \{ y \mapsto \ell(\phi(y), g(y)) : \phi \in \mathcal{F} \}$.

Lemma 4.1.

Given a target encoder $g$ and its stale approximation $g'$, the gap between the true population risk and the stale population risk is bounded in the following way:

$$R_{\ell,\phi}(\mathscr{D}_Y) - R_{\ell,\phi}(\tilde{\mathscr{D}}_Y) \leq \mathcal{W}(\mathscr{D}_Y, \tilde{\mathscr{D}}_Y) \leq \|g - g'\|_1 \qquad (13)$$

where $\mathcal{W}$ is the Wasserstein distance. Furthermore, if the approximation $g'$ comes from the same neural model as $g$ with parameters perturbed by $u$, as in the aforementioned stale approximation, we have $\|g - g'\|_1 \leq L\|u\|$, with $L$ the Lipschitz constant.

Next, we connect stale population risk to the empirical risk.

Lemma 4.2.

Given a target encoder $g$, its stale approximation $g'$, and a set of $n$ targets $\tilde{\mathscr{S}}_n = \{y_1, \dots, y_n\}$ sampled from $\tilde{\mathscr{D}}_Y$,

$$R_{\ell,\tilde{\phi}_n}(\tilde{\mathscr{D}}_Y) \leq R_{\ell,\tilde{\phi}_n}(\tilde{\mathscr{S}}_n) + \mathfrak{R}_{\tilde{\mathscr{S}}_n}(\mathcal{G}_{\ell,\mathcal{F}}), \qquad (14)$$

where $\mathfrak{R}_{\tilde{\mathscr{S}}_n}(\mathcal{G}_{\ell,\mathcal{F}})$ is the Rademacher complexity of $\mathcal{G}_{\ell,\mathcal{F}}$.

Now, we can relate the complexity of the function class $\mathcal{F}^{g'}$, the number of samples $n$, and the discrepancy between the true encoder $g$ and the stale approximate encoder $g'$:

Theorem 4.3.

For a target encoder, $g$, its stale approximation, $g'$, and the Rademacher complexity $\tilde{\mathfrak{R}}_n(\mathcal{G}_{\ell,\mathcal{F}^{g'}})$, the true population risk $R_{\ell,\phi}(\mathscr{D}_Y)$ is bounded by the following with probability at least $1-\delta$:

$$R_{\ell,\phi}(\mathscr{D}_Y) \leq T_1 + T_2 + T_3 \qquad (15)$$
$$T_1 = R_{\ell,\tilde{\phi}_n}(\tilde{\mathscr{S}}_n)$$
$$T_2 = \mathcal{W}(\mathscr{D}_Y, \tilde{\mathscr{D}}_Y) \leq L\|u\|$$
$$T_3 = \tilde{\mathfrak{R}}_n(\mathcal{G}_{\ell,\mathcal{F}^{g'}}) + \mathcal{O}\!\left(\sqrt{\frac{\log(1/\delta)}{n}}\right)$$

Note the following implications of these theoretical results:

1. If the corrector network $h$ is too complicated or there are not enough samples $n$, then $h$ overfits and $T_{3}$ will dominate.

2. If $g$ and $g'$ are very different, then term $T_{2}$ will dominate.

3. If $h(\cdot)$ is too simple and cannot fit the sampled data well, then $T_{1}$ will dominate.

We empirically explore some of these trade-offs in §5.3.

[Figure 2]

5 Experiments

We evaluate training using target corrector networks in two settings: supervised dense retrieval and retrieval augmented language models. We further investigate the properties of target corrector networks in synthetic experiments. In summary, the experiments ask whether this training strategy can obviate the need to keep cached buffers of target embeddings up-to-date by re-embedding during training. We answer this question affirmatively with the following highlights:

No Re-embedding Needed Training with target corrector networks matches the task performance of exhaustively re-embedding all targets every 500 steps throughout training, in both dense retrieval (Table 1) and retrieval augmented language models (Table 3). Target correctors achieve this without ever re-embedding targets during training, yielding significant computational savings (Fig. 2).

Best no re-embedding method Compared to frozen approaches, stale approaches, and Dynnibal without re-embedding, target corrector networks achieve over 10-point improvements on retrieval augmented language model (RLM) tasks and 4-point improvements across multiple recall measures in retrieval.

Simpler and Less Computation Target correctors perform as well as or better than Stochastic Negative Mining (SNM) (Reddi et al., 2019) despite SNM doing more re-embedding. Similarly, target corrector networks nearly match Dynnibal (Monath et al., 2023) even when Dynnibal uses much more computation (Table 1), and Dynnibal is a considerably more complicated and difficult-to-implement method.

5.1 Supervised Dense Retrieval

Setting & Metrics  We evaluate training methods for supervised dense retrieval models. Each method is provided the same supervised data. All methods use a stale buffer of target representations and use this buffer to form the subset of targets, $S(\mathcal{Y})$, used in computing the truncated softmax. All methods use the same loss (cross-entropy) and optimize the parameters of the dual-encoders using gradient descent. The methods differ in how they maintain the buffer and, as such, in the computational cost of maintaining it. We measure the computational requirements in terms of how many targets are re-embedded during training (our JAX (Bradbury et al., 2018) implementation on Cloud TPUv3 re-embeds roughly 2184 targets per second on each core). We report the number of targets encoded to indicate the computational expense (even if wall-clock time is mitigated by a complicated asynchronous implementation). Re-embedding every target even one additional time during training can be problematic if the number of targets is large. Furthermore, the initial buffer, created using the initial parameters of the dual-encoder (e.g., a pre-trained language model), can be computed once and reused for subsequent training jobs.
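To make the shared setup concrete, the following is a minimal sketch (not the authors' code) of a truncated softmax cross-entropy computed only over a sampled subset of targets; the subset embeddings may come from a stale, corrected, or freshly re-encoded buffer depending on the method, and the function and argument names are illustrative.

```python
import jax
import jax.numpy as jnp

def truncated_softmax_loss(query_vecs, subset_target_vecs, positive_idx):
    # query_vecs: [B, D] fresh query embeddings f(x).
    # subset_target_vecs: [B, K, D] embeddings of the sampled subset S(Y) for each query
    #   (stale, corrected, or freshly re-encoded, depending on the training method).
    # positive_idx: [B] index of the positive target within each row's subset.
    logits = jnp.einsum('bd,bkd->bk', query_vecs, subset_target_vecs)   # [B, K]
    log_probs = jax.nn.log_softmax(logits, axis=-1)                     # truncated softmax
    nll = -jnp.take_along_axis(log_probs, positive_idx[:, None], axis=1)
    return jnp.mean(nll)
```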

Data  We evaluate on Natural Questions (Kwiatkowski et al., 2019), with over 21M targets (Wikipedia passages), about 60K training examples (question, passage pairs), and about 3K examples each in dev/test, and on MSMARCO (Bajaj et al., 2016), with 8.8M targets (web passages) and about 500K training examples.

Models  We initialize the dual encoder models with two publicly available pre-trained language models, GTR (Ni et al., 2022) and T5 (Raffel et al., 2020). GTR is an encoder model initialized from T5 and further pre-trained for dense retrieval on a large collection of question/answer pairs. For MSMARCO, we only use T5, since MSMARCO is included in GTR's training data. We use the base-size models, $D=768$, and train separate parameters for $f(x)$ and $g(y)$. For the target corrector, we use a two-layer MLP with 8192 hidden units, a ReLU non-linearity, and a residual connection.
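A rough sketch of a corrector of this shape is below, written in plain JAX; the initialization scheme and parameter layout are our own assumptions for illustration, not the authors' implementation.

```python
import jax
import jax.numpy as jnp

def init_corrector(key, d_model=768, d_hidden=8192):
    # Illustrative initialization; the paper does not specify the init scheme.
    k1, k2 = jax.random.split(key)
    return {
        'w1': jax.random.normal(k1, (d_model, d_hidden)) / jnp.sqrt(d_model),
        'b1': jnp.zeros(d_hidden),
        'w2': jax.random.normal(k2, (d_hidden, d_model)) / jnp.sqrt(d_hidden),
        'b2': jnp.zeros(d_model),
    }

def corrector(params, stale_emb):
    # stale_emb: [N, D] cached g'(y); returns corrected embeddings h(g'(y)) of shape [N, D].
    hidden = jax.nn.relu(stale_emb @ params['w1'] + params['b1'])
    return stale_emb + hidden @ params['w2'] + params['b2']   # residual connection
```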

Table 1: Recall on NQ dev and test for each training method, with the number of targets re-embedded during training.

| Init | Method | Re-embed Num. (↓) | Dev @1 | Dev @5 | Dev @10 | Dev @20 | Dev @100 | Test @1 | Test @5 | Test @10 | Test @20 | Test @100 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| GTR-base | In-batch | 0 | 17.14 | 46.77 | 58.71 | 69.45 | 85.54 | 37.92 | 64.76 | 72.54 | 78.28 | 87.00 |
| GTR-base | Stale | 0 | 33.11 | 62.04 | 70.31 | 78.13 | 89.32 | 46.76 | 68.64 | 75.21 | 80.66 | 87.48 |
| GTR-base | Dynnibal | 0 | 28.73 | 59.66 | 70.08 | 78.14 | 90.18 | 44.40 | 67.53 | 74.93 | 80.22 | 87.23 |
| GTR-base | Corrector (ℓ_mse, ours) | 0 | 34.98 | 65.03 | 74.01 | 80.77 | 90.82 | 49.61 | 70.72 | 77.04 | 82.33 | 88.28 |
| GTR-base | Corrector (ours) | 0 | 35.78 | 66.74 | 75.06 | 81.52 | 91.37 | 50.61 | 71.00 | 77.73 | 82.66 | 88.39 |
| GTR-base | Dynnibal | 42M | 35.86 | 66.54 | 75.04 | 81.40 | 91.27 | 50.55 | 71.69 | 78.25 | 83.35 | 88.73 |
| GTR-base | SNM | 80M | 32.03 | 64.01 | 73.72 | 81.37 | 91.47 | 49.14 | 69.89 | 77.12 | 82.19 | 87.95 |
| GTR-base | Exhaustive | 1.68B | 36.29 | 67.08 | 75.55 | 82.07 | 91.73 | 50.30 | 71.55 | 78.12 | 82.83 | 88.59 |
| T5-base | In-batch | 0 | 9.93 | 28.07 | 37.17 | 45.54 | 64.06 | 23.40 | 47.50 | 56.39 | 65.34 | 77.97 |
| T5-base | Stale | 0 | 16.79 | 36.85 | 44.82 | 51.79 | 67.35 | 27.65 | 50.19 | 59.28 | 66.98 | 78.95 |
| T5-base | Dynnibal | 0 | 17.42 | 39.65 | 48.75 | 57.36 | 73.03 | 29.72 | 53.99 | 63.38 | 70.61 | 80.94 |
| T5-base | Corrector (ours) | 0 | 23.64 | 47.69 | 56.68 | 64.65 | 79.03 | 36.65 | 59.25 | 68.06 | 73.71 | 83.13 |
| T5-base | Dynnibal | 42M | 23.71 | 46.63 | 55.75 | 63.88 | 79.46 | 36.65 | 59.31 | 67.65 | 74.46 | 83.13 |
| T5-base | Dynnibal | 80M | 24.76 | 47.69 | 56.82 | 64.90 | 80.15 | 36.90 | 59.97 | 68.23 | 74.54 | 83.35 |
| T5-base | SNM | 80M | 22.55 | 46.86 | 55.72 | 64.19 | 80.40 | 35.93 | 59.06 | 67.48 | 73.66 | 82.85 |
| T5-base | Exhaustive | 1.68B | 24.70 | 48.21 | 57.18 | 65.39 | 79.94 | 37.34 | 60.42 | 68.70 | 74.76 | 83.41 |

We compare the following approaches: Target Corrector Networks (this paper): At the first training step, we initialize the buffer with vector representations of every target. At every subsequent step, we use the target corrector network to produce new representations of the targets without running the target encoder, simply by running our small MLP corrector on the stale representations. The stale buffer representations are never updated during training. Stale: We initialize the buffer of targets at the first step of training and do not update it throughout training. We experimented with both freezing the target encoder parameters $g(y)$ and allowing them to be updated despite the stale buffer; we found updating the parameters to be slightly better and report those results. Exhaustive: We exhaustively re-embed all of the targets in the buffer every 500 steps of training. Stochastic Negative Mining (SNM; Reddi et al. 2019): Instead of storing every target in the buffer, we store a subset of targets sampled uniformly at random. We re-sample and re-embed this buffer every 500 steps, using a buffer size of 1M targets. Dynnibal (Monath et al., 2023): This more complicated approach maintains the buffer using a low-rank regression model as part of a tree index structure. The regression model is updated every 500 steps on a sub-sample of targets, unlike our corrector, which is trained jointly. Furthermore, to get good performance, Dynnibal periodically performs costly full-buffer re-embedding throughout training; we needed to perform two such re-embeddings, and Dynnibal with fewer refreshes does not perform as well.
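The sketch below illustrates, under our reading of the method, how hard negatives could be mined from the corrected buffer at each step; the helper name corrector_fn is hypothetical, and the value k=64 is taken from the training details in Appendix B.1.

```python
import jax
import jax.numpy as jnp

def mine_hard_negatives(query_vecs, stale_buffer, corrector_fn, k=64):
    # query_vecs: [B, D] fresh f(x); stale_buffer: [N, D] cached g'(y) for all targets.
    corrected = corrector_fn(stale_buffer)       # [N, D]; no target-encoder calls needed
    scores = query_vecs @ corrected.T            # [B, N] approximate logits
    _, top_idx = jax.lax.top_k(scores, k)        # [B, k] indices of high-scoring targets
    return top_idx
```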

Table 2: Training speed (steps/sec) and NQ dev recall for exhaustive re-encoding versus corrector networks of varying hidden size.

| Method | Hidden Units | Steps/sec | R@1 | R@20 | R@100 |
|---|---|---|---|---|---|
| Exhaustive | - | 0.43 | 36.29 | 82.07 | 91.70 |
| Corrector | 8192 | 0.83 | 35.78 | 81.52 | 91.37 |
| Corrector | 4096 | 1.10 | 35.53 | 81.63 | 91.07 |
| Corrector | 2048 | 1.83 | 35.55 | 81.07 | 91.08 |

In Table 1, our target corrector network approach greatly improves upon the stale approach, especially at Recall@1, @5, and @10. We observe a nearly 5-point improvement at R@10 on the dev set and a 4-point improvement at R@1 on the test set over the stale approach. Our approach nearly matches the performance of the computationally intensive exhaustive approach. Furthermore, we perform comparably to the more expensive SNM and Dynnibal methods, and better than Dynnibal for the same amount of re-embedding. While doubling the number of index refreshes may appear negligible, having to re-embed the buffer during training can be computationally burdensome, especially as the number of targets grows. Using a buffer created from the initial parameters of the dual-encoder, as in our approach, allows the buffer to be constructed once ahead of time and reused across training runs and tasks. Dynnibal, in contrast, requires hand tuning to get the re-embedding schedule correct.

Table 1 also compares dual-encoder initializations. GTR is pre-trained for retrieval and hence achieves better results. T5 is not pre-trained for retrieval and requires more adaptation for the retrieval task. We observe that SNM struggles more to match the performance of Exhaustive with T5. Furthermore, Dynnibal requires more full index refreshes to reach competitive results. Our method achieves nearly as good results as the Exhaustive approach and Dynnibal (with re-embedding) despite never needing to re-embed.

We also report timing comparisons, in terms of steps per second, between corrector networks of varying sizes and exhaustive re-encoding of the targets (Table 2). Both small and large corrector networks yield large speed gains over exhaustive re-encoding with minimal loss in recall, indicating that corrector networks offer practical training-time efficiency gains over exhaustive methods.

See Appendix B.2 for additional results (MSMARCO, other ablations) and further discussion.

Table 3: Exact match accuracy on held-out validation sets for retrieval augmented language models, with the number of targets re-embedded during training.

| Method | Re-embed Num. (↓) | Retriever Init | NQ | TQA | HPQA |
|---|---|---|---|---|---|
| No Retr. | 0 | - | 25.4 | 26.1 | 14.5 |
| Frozen Retr. | 0 | GTR | 48.4 | 55.1 | 28.0 |
| Corrector | 0 | GTR | 52.3 | 66.4 | 36.7 |
| Exhaustive | 1.1B | GTR | 52.4 | 66.5 | 33.8 |
| Frozen Retr. | 0 | T5 | 13.34 | 12.15 | 13.37 |
| Corrector | 0 | T5 | 48.1 | 63.73 | 21.97 |
| Exhaustive | 1.1B | T5 | 48.3 | 66.03 | 25.45 |

5.2 Retrieval Augmented Language Models

Setting & Metrics  We evaluate the latent variable use case of training the retriever in a retrieval-augmented language model (RLM), as described in Section 3.1. We compare training approaches in terms of their re-embedding costs.

Datasets  We evaluate on three question answering datasets: TriviaQA (Joshi et al., 2017), NQOpen (Kwiatkowski et al., 2019), and HotPotQA (Yang et al., 2018). We use 256-token passages from a 2018 Wikipedia snapshot as the collection of targets, $\mathcal{Y}$, with 28M targets.

Models  We initialize the retriever with GTR-base or T5-base and use T5-base as the reader in Fusion-In-Decoder (Izacard & Grave, 2021). We use 32 retrieved documents in all experiments. The target corrector is a two-layer MLP.

We compare the following approaches: Target Corrector Network: The target corrector is used to retrieve $S(\mathcal{Y})$ at training time. We embed the targets at the beginning of training and never update the buffer. No Retrieval (Roberts et al., 2020): The retriever is not used; the reader model is trained on the dataset and uses only its parameters to answer the questions. Frozen Retrieval (Izacard & Grave, 2021): Every target is embedded once at the beginning of training, and only the parameters of the reader model are trained (updating the retriever parameters did not improve performance). Exhaustive: Jointly training the retriever and reader, we exhaustively re-embed all 28M targets every 500 steps.

In Table 3, we report exact match accuracy on the held-out validation sets. Our proposed target corrector matches or nearly matches the performance of the exhaustive re-embedding approach without ever having to re-embed the buffer. This is a dramatic reduction in computational cost, as the exhaustive approach ends up embedding all 28M passages 40 times (1.1B re-embeddings). Target correctors greatly outperform the approaches that do not use retrieval (by more than 20 points) and the frozen retriever approach (by at least 4 points and by up to 10 points).

[Figure 3]

5.3 Synthetic Experiments

In these experiments, we measure the ability of the proposed corrector network to approximate categorical distributions parameterized by the softmax, training the corrector network $h$ in isolation without training the parameters of the dual-encoder.

Setting & Metrics  We train only the parameters of the corrector network, $\Psi$, without training the parameters of the dual-encoder for a particular task. We measure the quality of the approximation using the KL-divergence between the true categorical distribution $P(y|x)$ (Equation 1) and the approximate distribution given by the corrector network, $P_h(y|x)$ (Equation 4). We measure the complexity of the corrector network by its parameter count, $|\Psi|$. We measure staleness, i.e., the difficulty of correcting a set of stale representations, by the KL-divergence between the true categorical distribution $P(y|x)$ and the stale distribution $P_{g'}(y|x)\propto\exp(\beta\langle f(x),g'(y)\rangle)$.
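A small sketch of this KL metric is below, assuming direct access to the embeddings for a single query; the names and the inverse-temperature argument are illustrative.

```python
import jax
import jax.numpy as jnp

def softmax_kl(query, true_targets, approx_targets, beta=1.0):
    # query: [D]; true_targets, approx_targets: [N, D] embeddings of the same targets.
    # Returns KL(P || P_h) between the softmax over true logits and approximate logits.
    log_p = jax.nn.log_softmax(beta * (true_targets @ query))
    log_q = jax.nn.log_softmax(beta * (approx_targets @ query))
    return jnp.sum(jnp.exp(log_p) * (log_p - log_q))
```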

[Figure 4]

Data Generation  We directly generate vector representations corresponding to data points and targets. That is, rather than having a dual-encoder model provide the vector representation of a data point or target, we directly generate synthetic data corresponding to $f(x)$, $g(y)$, and $g'(y)$. We generate 4096 targets in $D=8$ dimensions from a mixture of 20 Gaussians to represent $g'(y)$. To generate $g(y)$, we transform $g'(y)$ by feeding the points into randomly initialized MLPs with up to 2 hidden layers of size $D$, $2D$, $4D$, or $8D$, with ReLU activations and residual connections. We vary the complexity of the MLP and the variance of its initialization to create embeddings $g(y)$ that model a variety of settings of the extent of the staleness ($\mathcal{W}(\mathscr{D}_Y,\tilde{\mathscr{D}}_Y)$).
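A rough sketch of this generation process is below; the mixture spread, hidden width, and initialization scale are illustrative placeholders, and only a single hidden layer of the drift MLP is shown.

```python
import jax
import jax.numpy as jnp

def generate_targets(key, n=4096, d=8, n_components=20, hidden=32, init_scale=0.5):
    k_means, k_assign, k_noise, k_w1, k_w2 = jax.random.split(key, 5)
    means = 3.0 * jax.random.normal(k_means, (n_components, d))        # mixture centers
    assign = jax.random.randint(k_assign, (n,), 0, n_components)       # component per target
    stale = means[assign] + jax.random.normal(k_noise, (n, d))         # stale targets g'(y)
    w1 = init_scale * jax.random.normal(k_w1, (d, hidden))
    w2 = init_scale * jax.random.normal(k_w2, (hidden, d))
    true = stale + jax.nn.relu(stale @ w1) @ w2                        # "true" targets g(y) via random residual MLP
    return stale, true
```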

Corrector Network  In these experiments, we vary the parameter count of the corrector network $h$ and the number of hidden layers, using between 0 and 2 hidden layers with hidden dimension $D$, $2D$, $4D$, or $8D$. We use a ReLU nonlinearity with residual connections. We optimize the parameters of the corrector network using Adam with learning rate 0.03 and stop when the loss has not improved for 100 epochs or we reach 1000 epochs of training.

Varying $|S(\mathcal{Y})|$, the number of targets used for training  In Figure 3, we explore trade-offs between the complexity of $h$ in terms of its parameter count $|\Psi|$ (x-axis); the approximation error $\operatorname{KL}(P\|P_h)$ after applying the trained correction model (y-axis); and the fraction of samples used for training $h$. We report the complexity of the transformation from $g'$ to $g$ in terms of $\operatorname{KL}(P\|P_{g'})$ above each pane. A higher fraction of training samples is needed when there is more staleness. When the drift is more significant (right-hand pane), we observe that using more parameters with a smaller fraction of samples does lead to overfitting. In this setting, sampling 10% of the targets is generally sufficient.

Varying Complexity of the Target Corrector Network  In Figure 4, we explore how the KL divergence of our approximation changes with respect to the staleness of the embeddings $g'$. The correction model obtains a significant reduction in KL divergence (y-axis) across a wide variety of drifts (as measured by $\operatorname{KL}(P\|P_{g'})$). Increasing the parameter count is always effective, but it yields greater benefit when approximating a distribution with greater divergence.

6 Related work

Energy-based Models  The idea of training small parametric models to aid the training of other models has been widely studied for energy-based models, such as CoopNets (Xie et al., 2018), VERA (Grathwohl et al., 2021), and others (Grathwohl et al., 2020). In that setting, auxiliary models are trained to sidestep intractable computations required to train the main model.

Amortized Inference  Many approaches speed up sampling by fitting parametric models such as feed-forward neural networks (Marino et al., 2018; Naderiparizi et al., 2022).

Softmax Approximations  Previous work has considered approximating the softmax via kernel methods (Blanc & Rendle, 2018; Rawat et al., 2019) when there are trainable parameters for every target (rather than an encoder). Sampling-based approaches are widely used as well (Vembu et al., 2009; Zaheer et al., 2017; Monath et al., 2023).

Adapters  Adapter methods, which train small parametric components of larger networks (Houlsby et al., 2019), bear resemblance to our approach. However, our approach is distinct in that it operates only on the output layer of the neural models, not intermediate layers.

7 Conclusion

We present target corrector networks for approximating the softmax function during the training of dual encoder models and retrieval augmented language models. The target corrector networks learn to update a stale buffer of target representations. We investigate the generalization properties of the corrector models theoretically. We furthermore show empirically how our corrector network approach can be used to train models (both supervised retrievers and retrieval augmented language models) matching the accuracy of models that use 4x-80x the computational budget during training.

Impact Statement

Our work proposes new, more efficient ways of training retrieval models. Retrieval models, both in their own right and in combination with language models, have wide-ranging applications. The techniques of this paper are about improving training efficiency. As such, better models could be produced faster, bringing to bear all the responsibilities of model creators in terms of understanding the successes, limitations, and biases of the model. Future work could consider how different training strategies affect the broader impact of retrieval models. Of particular interest to this paper would be the role that staleness in the truncated softmax plays in such a study.

References

  • Bajaj, P., Campos, D., Craswell, N., Deng, L., Gao, J., Liu, X., Majumder, R., McNamara, A., Mitra, B., Nguyen, T., et al. MS MARCO: A human generated machine reading comprehension dataset. arXiv preprint arXiv:1611.09268, 2016.
  • Blanc, G. and Rendle, S. Adaptive sampled softmax with kernel based sampling. International Conference on Machine Learning (ICML), 2018.
  • Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., and Zhang, Q. JAX: composable transformations of Python+NumPy programs. 2018.
  • Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. Proceedings of the Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), 2019.
  • Devroye, L., Györfi, L., and Lugosi, G. A Probabilistic Theory of Pattern Recognition, volume 31. Springer Science & Business Media, 2013.
  • Dulac-Arnold, G., Evans, R., van Hasselt, H., Sunehag, P., Lillicrap, T., Hunt, J., Mann, T., Weber, T., Degris, T., and Coppin, B. Deep reinforcement learning in large discrete action spaces. arXiv preprint arXiv:1512.07679, 2015.
  • Gillick, D., Kulkarni, S., Lansing, L., Presta, A., Baldridge, J., Ie, E., and Garcia-Olano, D. Learning dense representations for entity retrieval. Conference on Computational Natural Language Learning (CoNLL), 2019.
  • Gottipati, S. K., Sattarov, B., Niu, S., Pathak, Y., Wei, H., Liu, S., Blackburn, S., Thomas, K., Coley, C., Tang, J., et al. Learning to navigate the synthetically accessible chemical space using reinforcement learning. International Conference on Machine Learning, 2020.
  • Grathwohl, W., Wang, K.-C., Jacobsen, J.-H., Duvenaud, D., and Zemel, R. Learning the Stein discrepancy for training and evaluating energy-based models without sampling. International Conference on Machine Learning, 2020.
  • Grathwohl, W., Kelly, J., Hashemi, M., Norouzi, M., Swersky, K., and Duvenaud, D. No MCMC for me: Amortized sampling for fast and stable training of energy-based models. ICLR, 2021.
  • Guu, K., Lee, K., Tung, Z., Pasupat, P., and Chang, M. Retrieval augmented language model pre-training. International Conference on Machine Learning (ICML), 2020.
  • Han, T., Nijkamp, E., Zhou, L., Pang, B., Zhu, S.-C., and Wu, Y. N. Joint training of variational auto-encoder and latent energy-based model. Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2020.
  • Houlsby, N., Giurgiu, A., Jastrzebski, S., Morrone, B., De Laroussilhe, Q., Gesmundo, A., Attariyan, M., and Gelly, S. Parameter-efficient transfer learning for NLP. Proceedings of the 36th International Conference on Machine Learning, 2019.
  • Izacard, G. and Grave, E. Leveraging passage retrieval with generative models for open domain question answering, 2021.
  • Izacard, G., Lewis, P., Lomeli, M., Hosseini, L., Petroni, F., Schick, T., Dwivedi-Yu, J., Joulin, A., Riedel, S., and Grave, E. Few-shot learning with retrieval augmented language models. arXiv preprint arXiv:2208.03299, 2022.
  • Joshi, M., Choi, E., Weld, D. S., and Zettlemoyer, L. TriviaQA: A large scale distantly supervised challenge dataset for reading comprehension. Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 2017.
  • Karpukhin, V., Oğuz, B., Min, S., Lewis, P., Wu, L., Edunov, S., Chen, D., and Yih, W.-t. Dense passage retrieval for open-domain question answering. arXiv preprint arXiv:2004.04906, 2020.
  • Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kwiatkowski, T., Palomaki, J., Redfield, O., Collins, M., Parikh, A., Alberti, C., Epstein, D., Polosukhin, I., Devlin, J., Lee, K., Toutanova, K., Jones, L., Kelcey, M., Chang, M.-W., Dai, A. M., Uszkoreit, J., Le, Q., and Petrov, S. Natural Questions: A benchmark for question answering research. Transactions of the Association for Computational Linguistics (TACL), 2019.
  • Lindgren, E., Reddi, S. J., Guo, R., and Kumar, S. Efficient training of retrieval models using negative cache. Advances in Neural Information Processing Systems (NeurIPS), 2021.
  • Logeswaran, L., Chang, M.-W., Lee, K., Toutanova, K., Devlin, J., and Lee, H. Zero-shot entity linking by reading entity descriptions. Association for Computational Linguistics (ACL), 2019.
  • Marino, J., Yue, Y., and Mandt, S. Iterative amortized inference. Proceedings of the 35th International Conference on Machine Learning, pp. 3403–3412, 2018.
  • Mohri, M., Rostamizadeh, A., and Talwalkar, A. Foundations of Machine Learning. MIT Press, 2018.
  • Monath, N., Zaheer, M., Allen, K., and McCallum, A. Improving dual-encoder training through dynamic indexes for negative mining. AISTATS, 2023.
  • Naderiparizi, S., Scibior, A., Munk, A., Ghadiri, M., Baydin, A. G., Gram-Hansen, B. J., de Witt, C. A. S., Zinkov, R., Torr, P., Rainforth, T., et al. Amortized rejection sampling in universal probabilistic programming. International Conference on Artificial Intelligence and Statistics, 2022.
  • Ni, J., Qu, C., Lu, J., Dai, Z., Ábrego, G. H., Ma, J., Zhao, V. Y., Luan, Y., Hall, K. B., Chang, M.-W., et al. Large dual encoders are generalizable retrievers. arXiv preprint arXiv:2112.07899, 2021.
  • Ni, J., Qu, C., Lu, J., Dai, Z., Ábrego, G. H., Ma, J., Zhao, V., Luan, Y., Hall, K., Chang, M.-W., et al. Large dual encoders are generalizable retrievers. Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, 2022.
  • Qu, Y., Ding, Y., Liu, J., Liu, K., Ren, R., Zhao, W. X., Dong, D., Wu, H., and Wang, H. RocketQA: An optimized training approach to dense passage retrieval for open-domain question answering. Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT), 2021.
  • Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., Zhou, Y., Li, W., and Liu, P. J. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21(1), 2020.
  • Rawat, A. S., Chen, J., Yu, F., Suresh, A. T., and Kumar, S. Sampled softmax with random Fourier features. Advances in Neural Information Processing Systems (NeurIPS), 2019.
  • Rawat, A. S., Menon, A. K., Veit, A., Yu, F., Reddi, S. J., and Kumar, S. Doubly-stochastic mining for heterogeneous retrieval. arXiv preprint arXiv:2004.10915, 2020.
  • Reddi, S. J., Kale, S., Yu, F., Holtmann-Rice, D., Chen, J., and Kumar, S. Stochastic negative mining for learning with large output spaces. Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics, 2019.
  • Roberts, A., Raffel, C., and Shazeer, N. How much knowledge can you pack into the parameters of a language model? arXiv preprint arXiv:2002.08910, 2020.
  • Thakur, N., Reimers, N., Rücklé, A., Srivastava, A., and Gurevych, I. BEIR: A heterogenous benchmark for zero-shot evaluation of information retrieval models. arXiv preprint arXiv:2104.08663, 2021.
  • Vembu, S., Gärtner, T., and Boley, M. Probabilistic structured predictors. Uncertainty in Artificial Intelligence (UAI), 2009.
  • Villani, C. Optimal Transport: Old and New. Springer, 2008.
  • Wachsmuth, H., Syed, S., and Stein, B. Retrieval of the best counterargument without prior topic knowledge. Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), 2018.
  • Xie, J., Lu, Y., Gao, R., and Wu, Y. N. Cooperative learning of energy-based model and latent variable model via MCMC teaching. Proceedings of the AAAI Conference on Artificial Intelligence, 32, 2018.
  • Xiong, L., Xiong, C., Li, Y., Tang, K.-F., Liu, J., Bennett, P. N., Ahmed, J., and Overwijk, A. Approximate nearest neighbor negative contrastive learning for dense text retrieval. International Conference on Learning Representations (ICLR), 2020.
  • Yang, Z., Qi, P., Zhang, S., Bengio, Y., Cohen, W., Salakhutdinov, R., and Manning, C. D. HotpotQA: A dataset for diverse, explainable multi-hop question answering. Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, 2018.
  • Yu, H.-F., Zhong, K., Zhang, J., Chang, W.-C., and Dhillon, I. S. PECOS: Prediction for enormous and correlated output spaces. Journal of Machine Learning Research, 2022.
  • Zaheer, M., Kottur, S., Ahmed, A., Moura, J., and Smola, A. Canopy fast sampling with cover trees. International Conference on Machine Learning (ICML), 2017.

Appendix A Analysis: Proofs

See 4.1.

Proof. We bound the gap between the true population risk and the stale population risk. Recall that $\mathcal{G}_{\ell,\mathcal{F}}$ is the induced function class: $\mathcal{G}_{\ell,\mathcal{F}}=\{y\mapsto\ell(f(y),g(y)):f\in\mathcal{F}\}$. Now note that

$$
\begin{aligned}
R_{\ell,f}(\mathscr{D}_{Y})-R_{\ell,f}(\tilde{\mathscr{D}}_{Y})
&=\mathbb{E}_{Y\sim\mathscr{D}_{Y}}\big[\ell(f(Y),g(Y))\big]-\mathbb{E}_{Y'\sim\tilde{\mathscr{D}}_{Y}}\big[\ell(f(Y'),g(Y'))\big] \qquad (16)\\
&\leq\sup_{\lambda\in\mathcal{G}_{\ell,\mathcal{F}}}\Big(\mathbb{E}_{Y\sim\mathscr{D}_{Y}}[\lambda(Y)]-\mathbb{E}_{Y'\sim\tilde{\mathscr{D}}_{Y}}[\lambda(Y')]\Big)\\
&\overset{(i)}{=}L\cdot\sup_{\lambda\in\mathcal{G}_{\ell,\mathcal{F}}}\Big(\mathbb{E}_{Y\sim\mathscr{D}_{Y}}\Big[\tfrac{\lambda(Y)}{L}\Big]-\mathbb{E}_{Y'\sim\tilde{\mathscr{D}}_{Y}}\Big[\tfrac{\lambda(Y')}{L}\Big]\Big)\\
&\overset{(ii)}{\leq}L\cdot\sup_{\lambda\in{\rm Lip}_{1}(\rho)}\Big(\mathbb{E}_{Y\sim\mathscr{D}_{Y}}[\lambda(Y)]-\mathbb{E}_{Y'\sim\tilde{\mathscr{D}}_{Y}}[\lambda(Y')]\Big)\\
&\overset{(iii)}{=}\mathcal{W}(\mathscr{D}_{Y},\tilde{\mathscr{D}}_{Y}) \qquad (17)\\
&\overset{(iv)}{\leq}\|\mathscr{D}_{Y}-\tilde{\mathscr{D}}_{Y}\|_{TV} \qquad (18)\\
&\overset{(v)}{=}\tfrac{1}{2}\sum_{y\in\mathcal{Y}}\big|\mathsf{softmax}(g(y))-\mathsf{softmax}(g'(y))\big| \qquad (19)\\
&\overset{(vi)}{\leq}\tfrac{1}{2}\|g-g'\|_{1} \qquad (20)
\end{aligned}
$$

where $(i)$ follows by dividing and multiplying by $L$; $(ii)$ follows since, for any $\lambda\in\mathcal{G}_{\ell,\mathcal{F}}$, $\frac{\lambda}{L}$ is $1$-Lipschitz; $(iii)$ follows from Kantorovich-Rubinstein duality (Villani, 2008); $(iv)$ follows from Corollary 6.14 in Villani (2008); $(v)$ follows from the definition; and $(vi)$ follows from the softmax having Lipschitz constant 1. As $g$ and $g'$ are outputs of the same neural network with parameters perturbed by $u$, it follows that $\|g-g'\|_{1}\leq L\|u\|$.

See 4.2.

Proof. We need to connect the stale population risk to the empirical risk we are actually minimizing:

$$
\begin{aligned}
R_{\ell,\tilde{f}_{n}}(\tilde{\mathscr{D}}_{Y})
&=\mathbb{E}_{\tilde{\mathscr{D}}_{Y}}\big[\ell(\tilde{f}_{n}(Y),g(Y))\big]\\
&\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}}\big[\ell(\tilde{f}_{n}(Y),g(Y))\big]+\sup_{f\in\mathcal{F}}\Big|\mathbb{E}_{\tilde{\mathscr{D}}_{Y}}\big[\ell(f(Y),g(Y))\big]-\mathbb{E}_{\tilde{\mathscr{S}}_{n}}\big[\ell(f(Y),g(Y))\big]\Big|\\
&\overset{(i)}{=}\mathbb{E}_{\tilde{\mathscr{S}}_{n}}\big[\ell(\tilde{f}_{n}(Y),g(Y))\big]+\sup_{\lambda\in\mathcal{G}_{\ell,\mathcal{F}}}\Big|\mathbb{E}_{\tilde{\mathscr{D}}_{Y}}[\lambda(Y)]-\mathbb{E}_{\tilde{\mathscr{S}}_{n}}[\lambda(Y)]\Big|\\
&\overset{(ii)}{\leq}\mathbb{E}_{\tilde{\mathscr{S}}_{n}}\big[\ell(\tilde{f}_{n}(Y),g(Y))\big]+{\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}})\\
&=R_{\ell,\tilde{f}_{n}}(\tilde{\mathscr{S}}_{n})+{\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}}), \qquad (21)
\end{aligned}
$$

where $(i)$ follows from the definition of $\mathcal{G}_{\ell,\mathcal{F}}$ and $(ii)$ from the standard symmetrization argument (Devroye et al., 2013; Mohri et al., 2018) for Rademacher complexity.

See 4.3.

Proof. As mentioned in the text, $\mathcal{F}$ might be too large a function class, and we would like to use the restricted function class $\mathcal{F}^{g'}$. The previous derivation goes through with this restricted class, yielding the Rademacher complexity ${\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}^{g'}})$ instead. To compare the two Rademacher complexities, observe that

$$
\begin{aligned}
{\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}^{g'}})
&=\frac{1}{n}\,\mathbb{E}_{\boldsymbol{\sigma}}\Big|\sup_{\lambda\in\mathcal{G}_{\ell,\mathcal{F}^{g'}}}\sum_{i}\sigma_{i}\lambda(y_{i})\Big|\\
&\overset{(i)}{=}\frac{1}{n}\,\mathbb{E}_{\boldsymbol{\sigma}}\Big|\sup_{h\in\mathcal{H}}\sum_{i}\sigma_{i}\,\ell(h\circ g'(y_{i}),g(y_{i}))\Big|\\
&=\frac{1}{n}\,\mathbb{E}_{\boldsymbol{\sigma}}\Big|\sup_{f\in\mathcal{F}^{g'}}\sum_{i}\sigma_{i}\,\ell(f(y_{i}),g(y_{i}))\Big|\\
&\overset{(ii)}{\leq}\frac{1}{n}\,\mathbb{E}_{\boldsymbol{\sigma}}\Big|\sup_{f\in\mathcal{F}}\sum_{i}\sigma_{i}\,\ell(f(y_{i}),g(y_{i}))\Big|\\
&=\frac{1}{n}\,\mathbb{E}_{\boldsymbol{\sigma}}\Big|\sup_{\lambda\in\mathcal{G}_{\ell,\mathcal{F}}}\sum_{i}\sigma_{i}\lambda(y_{i})\Big|\\
&={\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}}), \qquad (22)
\end{aligned}
$$

where $(i)$ follows from the definitions of $\mathcal{G}_{\ell,\mathcal{F}^{g'}}$ and $\mathcal{F}^{g'}$, and $(ii)$ holds because $\mathcal{F}^{g'}\subset\mathcal{F}$.

Now, standard concentration results for the empirical Rademacher complexity imply that, with probability at least $1-\delta$,

$$
\begin{aligned}
{\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}^{g'}})
&\leq\mathbb{E}_{\tilde{\mathscr{S}}_{n}\sim\tilde{\mathscr{D}}^{\otimes n}_{Y}}\big[{\mathfrak{R}}_{\tilde{\mathscr{S}}_{n}}(\mathcal{G}_{\ell,\mathcal{F}^{g'}})\big]+\mathcal{O}\Big(\sqrt{\tfrac{\log(1/\delta)}{n}}\Big) \qquad (23)\\
&=\tilde{\mathfrak{R}}_{n}(\mathcal{G}_{\ell,\mathcal{F}^{g'}})+\mathcal{O}\Big(\sqrt{\tfrac{\log(1/\delta)}{n}}\Big). \qquad (24)
\end{aligned}
$$

Combining the results of Eqs. (16), (21), and (23), we obtain that, with probability at least $1-\delta$,

$$
R_{\ell,f}(\mathscr{D}_{Y})\leq R_{\ell,\tilde{f}_{n}}(\tilde{\mathscr{S}}_{n})+\underbrace{\mathcal{W}(\mathscr{D}_{Y},\tilde{\mathscr{D}}_{Y})}_{\leq L\|u\|}+\tilde{\mathfrak{R}}_{n}(\mathcal{G}_{\ell,\mathcal{F}^{g'}})+\mathcal{O}\Big(\sqrt{\tfrac{\log(1/\delta)}{n}}\Big). \qquad (25)
$$

[Figure 5]

Appendix B Experiments

B.1 Experimental Details

Training Details  We train all models, the dual-encoders and the corrector model, jointly using Adam (Kingma & Ba, 2014). We implement the training procedure using stop-gradients so that the corrector model loss only changes the corrector model parameters and the dual-encoder loss only changes the dual-encoder parameters. We form the subset of targets for the truncated softmax, $S(\mathcal{Y})_x$, using the top-64 closest targets to the given query according to a particular training procedure's buffer, plus 64 targets chosen uniformly at random. We use a minibatch size of 128 examples and share the truncated softmax targets across all examples in the minibatch $mb$, i.e., $\bigcup_{x\in mb}S(\mathcal{Y})_x$. We use 40K steps for retrieval training and 20K steps for RLM training. We combine the task losses and the corrector network loss together; we experimented with a weight parameter applied to the corrector loss and use a weight value of 10.0.
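A minimal sketch of how the two losses might be combined with stop-gradients is shown below; the exact form of the corrector loss (a mean-squared-error regression onto fresh in-batch embeddings) and the argument layout are assumptions for illustration, not the authors' code. The `corrector` function is the MLP sketched earlier.

```python
import jax
import jax.numpy as jnp

def combined_loss(corr_params, fresh_q, fresh_t, stale_t, labels, corrector_weight=10.0):
    # fresh_q: [B, D] f(x) from the query tower; fresh_t: [B, D] g(y) for in-batch positives;
    # stale_t: [K, D] cached g'(y) for the shared candidate subset; labels: [B] positive index.

    # Task loss: score queries against corrected candidates, but block gradients
    # from the task loss into the corrector parameters.
    corrected = corrector(corr_params, stale_t)                       # [K, D]
    logits = fresh_q @ jax.lax.stop_gradient(corrected).T             # [B, K]
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    task_loss = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=1))

    # Corrector loss (assumed regression form): map corrected stale embeddings of the
    # positives onto the fresh in-batch embeddings, without moving the dual-encoder.
    corrected_pos = corrector(corr_params, stale_t[labels])           # [B, D]
    corr_loss = jnp.mean((corrected_pos - jax.lax.stop_gradient(fresh_t)) ** 2)

    return task_loss + corrector_weight * corr_loss
```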

B.2 Additional Dense Retrieval Results

In Table 4, we report performance on MSMARCO using T5-base as the encoder. Here, with fewer targets, Stochastic Negative Mining provides a better approximation, as a larger fraction of targets is re-encoded. Our method is still able to nearly match the performance of the exhaustive approach, and does so without ever having to re-embed the buffer.

Using Accelerator Memory to Store the Buffer  In these experiments, we store the buffer of targets on the accelerator, making our training approach extremely easy to implement. However, it could be the case that not all targets fit into a buffer in accelerator memory. In such settings, our approach could still be used in the following ways: (1) subsample targets uniformly at random (perhaps changing the subset periodically) to fit in device memory, akin to a combination of our corrector approach and stochastic negative mining, which would require no re-encoding of targets, or (2) use our approach to re-rank stale representations initially retrieved from CPU memory.

Table 4: Recall on MSMARCO using T5-base, with the number of targets re-embedded during training.

| Method | Num. Re-Embed | R@1 | R@5 | R@10 | R@100 |
|---|---|---|---|---|---|
| Stale | 0 | 10.11 | 27.70 | 36.33 | 63.69 |
| SNM | 20M | 18.18 | 43.48 | 54.68 | 82.18 |
| Dynnibal | 8M | 18.23 | 43.15 | 54.56 | 82.24 |
| Target Corrector | 0 | 17.07 | 40.78 | 51.56 | 79.29 |
| Exhaustive | 352M | 18.18 | 44.97 | 55.58 | 83.69 |

Comparisons to 2-Round Training

Several recent works, such as Qu et al. (2021), address the difficulties of training dense retrieval models by training in two stages. First, all targets are encoded (using a randomly initialized or pre-trained model) and the model is trained for half of the desired iterations. The new model's parameters are then used to re-encode the targets a single time, and the model is trained for the remaining steps using these re-encoded targets. We compare this approach with corrector networks in Table 5. When using GTR-base, the performance of all methods is quite similar (with corrector networks and exhaustive re-encoding slightly outperforming). When T5-base is used, however, we find that corrector networks and exhaustive re-encoding notably outperform the two-round procedure. We attribute this to GTR being a better initialization for the model: in this case we would expect its parameters (and therefore its target embeddings) to change less from pre-training to fine-tuning, meaning there is less embedding drift and therefore less bias when using the two-round procedure.

Table 5: Comparison of two-round training, corrector networks, and exhaustive re-encoding (NQ test recall).

| Method | Base | R@1 | R@5 | R@10 | R@20 | R@100 |
|---|---|---|---|---|---|---|
| Two Round | T5 | 29.50 | 53.40 | 62.49 | 70.64 | 80.94 |
| Corrector | T5 | 36.65 | 59.25 | 68.06 | 73.71 | 83.13 |
| Exhaustive | T5 | 37.34 | 60.42 | 68.70 | 74.76 | 83.41 |
| Two Round | GTR | 49.06 | 70.06 | 76.76 | 81.17 | 87.95 |
| Corrector | GTR | 49.61 | 70.72 | 77.04 | 82.33 | 88.28 |
| Exhaustive | GTR | 50.30 | 71.55 | 78.12 | 82.83 | 88.59 |

Comparisons with and without uniform negatives

In our main experiments (as stated in Appendix B.1), we train with both hard negatives and uniform negatives. Initial experiments showed that adding uniform negatives led to improved performance in some settings. We provide additional results ablating this choice using exhaustive re-encoding in Table 6. This choice provides negligible improvement on the reported benchmarks (although we believe it is worth trying in other settings).

Table 6: Ablation of adding uniformly sampled negatives (exhaustive re-encoding).

| Method | Base | R@1 | R@100 |
|---|---|---|---|
| With Uniform | T5 | 24.70 | 79.94 |
| Without Uniform | T5 | 24.87 | 79.82 |
| With Uniform | GTR | 36.29 | 91.73 |
| Without Uniform | GTR | 36.53 | 91.70 |

B.3 Retrieve and Read

Note that in this setting we do not share the subset of targets $S(\mathcal{Y})$ across the examples in the batch, nor do we use targets sampled uniformly at random.

B.4 Beyond Stale Representations: Approximating Large Models with Small Models

In this experiment, we focus on sampling in isolation. We sample a batch of input points and measure the ability of our method to approximate one dual encoder model with another. In particular, we study the case where we approximate a large dual encoder with a small model: we approximate the GTR-large model (Ni et al., 2021) (i.e., $g(\cdot)$) with the GTR-small model (i.e., $g'(\cdot)$). In Table 7, we report nearest neighbor precision, i.e., the overlap with the large model's top-K neighbors at K = 10, 20, and 100, on the Arguana dataset (Wachsmuth et al., 2018) from the BEIR benchmark (Thakur et al., 2021). We use 32 samples per query to train the correction model. We find that the corrector improves the overlap most for smaller K.
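A small sketch of this overlap metric, assuming raw embeddings for a single query, is below; the names are illustrative.

```python
import jax.numpy as jnp

def precision_at_k(query, large_targets, approx_targets, k=10):
    # query: [D]; large_targets, approx_targets: [N, D] embeddings of the same targets.
    # Fraction of the large model's top-k neighbors recovered by the approximation.
    true_top = set(jnp.argsort(-(large_targets @ query))[:k].tolist())
    approx_top = set(jnp.argsort(-(approx_targets @ query))[:k].tolist())
    return len(true_top & approx_top) / k
```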

Table 7: Overlap with GTR-large's top-K neighbors on Arguana.

| Embeddings | P@10 | P@20 | P@100 |
|---|---|---|---|
| $g'(\cdot)$ | 67.57 | 67.73 | 57.57 |
| $h\circ g'(\cdot)$ | 76.87 | 73.32 | 53.55 |