Summary. When generating outputs, Retro searches a database and retrieves chunks of tokens that are similar to its input in an embedding space. Retro encodes the retrieved text and then incorporates it into the intermediate representations of the language model via cross-attention. The result is a potential increase in the memory capacity of the language model without a significant increase in its parameter count.
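A minimal sketch of this data flow, assuming a toy in-memory database of precomputed chunk embeddings and token-level encodings; the module names, dimensions, and cosine-similarity lookup here are illustrative assumptions, not the paper's implementation:

```python
# Sketch of the Retro data flow: retrieve nearest-neighbor chunks for each input
# chunk from a toy "database", then fuse the retrieved encodings into the LM's
# intermediate states with cross-attention. All sizes are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

CHUNK_LEN, D_MODEL, K_NEIGHBORS = 64, 256, 2

# Toy retrieval database: N chunks, each with a precomputed (e.g. BERT-style)
# embedding and a token-level encoding of the chunk text.
N_DB = 1000
db_embeddings = F.normalize(torch.randn(N_DB, D_MODEL), dim=-1)
db_token_encodings = torch.randn(N_DB, CHUNK_LEN, D_MODEL)

def retrieve(chunk_embedding, k=K_NEIGHBORS):
    """Return token encodings of the k most similar database chunks (cosine)."""
    scores = db_embeddings @ F.normalize(chunk_embedding, dim=-1)
    topk = scores.topk(k).indices
    return db_token_encodings[topk]                           # (k, CHUNK_LEN, D)

cross_attn = nn.MultiheadAttention(D_MODEL, num_heads=4, batch_first=True)

def chunked_cross_attention(decoder_states, chunk_embeddings):
    """decoder_states: (num_chunks, CHUNK_LEN, D_MODEL) intermediate LM activations."""
    out = []
    for chunk_states, chunk_emb in zip(decoder_states, chunk_embeddings):
        neighbors = retrieve(chunk_emb)                       # (k, CHUNK_LEN, D)
        neighbors = neighbors.reshape(1, -1, D_MODEL)         # flatten the k neighbors
        fused, _ = cross_attn(chunk_states.unsqueeze(0),      # queries: LM states
                              neighbors, neighbors)           # keys/values: retrieved text
        out.append(chunk_states + fused.squeeze(0))           # residual connection
    return torch.stack(out)

# Example: 4 chunks of decoder activations and their query embeddings.
states = torch.randn(4, CHUNK_LEN, D_MODEL)
chunk_embs = torch.randn(4, D_MODEL)
print(chunked_cross_attention(states, chunk_embs).shape)      # torch.Size([4, 64, 256])
```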
- The authors claim the cost of retrieval is roughly 10 ms per chunk. How does this number scale as the batch size grows? In deep learning one can usually expect a sub-linear increase in compute time when increasing the batch size, especially when the batch size is small, but it seems like kNN search will not benefit from batching and its compute time will instead scale linearly with the batch size (a small benchmark sketch follows below).
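A small benchmark sketch of that question, assuming the faiss library with a brute-force `IndexFlatL2` index over a random database; Retro's database is far larger and presumably uses an approximate index, so absolute numbers will differ and only the scaling trend with batch size is of interest here:

```python
# Measure how kNN search time scales with the number of query chunks issued at once.
import time
import numpy as np
import faiss

d, n_db, k = 768, 200_000, 2
rng = np.random.default_rng(0)
index = faiss.IndexFlatL2(d)                                  # exact brute-force search
index.add(rng.standard_normal((n_db, d), dtype=np.float32))

for batch in (1, 4, 16, 64):
    queries = rng.standard_normal((batch, d), dtype=np.float32)
    start = time.perf_counter()
    index.search(queries, k)
    elapsed = time.perf_counter() - start
    print(f"batch={batch:3d}  total={elapsed*1e3:7.1f} ms  per-query={elapsed*1e3/batch:6.2f} ms")
```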
- An interesting question left unanswered is how much more efficient retrieval-based external memory is vs. internal memory that encodes information in the network's parameters. Concretely, to achieve the same performance, Retro allows the LM to scale down its parameter count and thus the amount of compute (e.g., in terms of floating-point operations), but adds compute for the kNN search. How does the saving compare with the extra cost? What's the ratio here? (A back-of-envelope sketch follows below.)
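A back-of-envelope sketch of that ratio for a single chunk; every number here (the parameter counts, the assumed GPU throughput, and the 2-FLOPs-per-parameter-per-token approximation) is an assumption for illustration, and only the ~10 ms retrieval figure comes from the paper:

```python
# Rough per-chunk comparison: compute saved by shrinking the LM vs. kNN search cost.
CHUNK_TOKENS = 64                 # tokens generated per retrieval
PARAMS_LARGE = 7.5e9              # hypothetical baseline LM
PARAMS_RETRO = 1.5e9              # hypothetical smaller retrieval-augmented LM
GPU_FLOPS = 100e12                # assumed sustained throughput (FLOP/s)
RETRIEVAL_S = 10e-3               # retrieval latency per chunk (figure quoted in the paper)

def forward_flops(params, tokens):
    # Standard approximation: ~2 FLOPs per parameter per token for a forward pass.
    return 2 * params * tokens

saved_flops = forward_flops(PARAMS_LARGE, CHUNK_TOKENS) - forward_flops(PARAMS_RETRO, CHUNK_TOKENS)
saved_gpu_seconds = saved_flops / GPU_FLOPS
print(f"FLOPs saved per chunk : {saved_flops:.2e}")
print(f"GPU-time saved        : {saved_gpu_seconds*1e3:.2f} ms")
print(f"Retrieval cost        : {RETRIEVAL_S*1e3:.2f} ms")
print(f"Savings / retrieval   : {saved_gpu_seconds / RETRIEVAL_S:.2f}x")
```

Under these particular assumptions the per-chunk compute saving comes out on the same order as the quoted retrieval latency, though the two run on different hardware (GPU vs. CPU index servers), so they are not directly additive.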
- Can we hide the latency of the kNN search somewhere? When generating text autoregressively, we can only start retrieving once the previous chunk of text is available, and the retrieved context is immediately used for generating the current chunk. The kNN search thus lies on the critical path (though its overhead may be low). One way to alleviate this overhead is to use slightly “stale” retrieved context: in Retro, chunk K is allowed to attend to the context retrieved for chunk K-1. We could restrict it to attend only to the context retrieved for chunk K-2, and overlap the search for chunk K-1's context with the generation of chunk K, which now only requires the context of K-2 (a sketch follows below). This may, however, hurt model quality.
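A sketch of that pipelining idea, with hypothetical placeholder functions `retrieve_neighbors` and `generate_chunk` standing in for the real kNN search and decoder; the two-chunk staleness is exactly what lets retrieval run off the critical path:

```python
# Overlap retrieval for chunk K-1 with the generation of chunk K,
# which only attends to context retrieved for chunk K-2.
import time
from concurrent.futures import ThreadPoolExecutor

def retrieve_neighbors(chunk):
    # Placeholder for the kNN search over the retrieval database.
    time.sleep(0.01)                      # stand-in for ~10 ms retrieval latency
    return f"neighbors({chunk})"

def generate_chunk(prev_chunks, retrieved_context):
    # Placeholder for one chunk of autoregressive decoding; the real model would
    # cross-attend to `retrieved_context`, which lags by two chunks here
    # (and is None before any retrieval has completed).
    return f"chunk{len(prev_chunks)}"

def generate(prompt_chunks, num_new_chunks):
    chunks = list(prompt_chunks)
    stale_context = None                  # retrieved context for chunk K-2
    with ThreadPoolExecutor(max_workers=1) as pool:
        for _ in range(num_new_chunks):
            # Start retrieving neighbors of the latest finished chunk (K-1)
            # in the background ...
            inflight = pool.submit(retrieve_neighbors, chunks[-1])
            # ... while decoding chunk K using only the context of chunk K-2.
            chunks.append(generate_chunk(chunks, stale_context))
            # The K-1 retrieval becomes the "K-2" context of the next iteration.
            stale_context = inflight.result()
    return chunks

print(generate(["chunk0"], 3))
```

With a single background worker, the `result()` call only blocks if retrieval takes longer than decoding one chunk.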
- Could retrieval-enhanced models improve the truthfulness and interpretability of large language models?
- Design questions: why is a transformer encoder needed to encode the retrieved neighbors? What happens if the BERT embeddings are used as-is for the chunked cross-attention (CCA) module?
- For the same transformer encoder used to encode the retrieved neighbors of chunk C, why does it attend over the activations of chunk C? What happens if the neighbor encodings are computed without any information from the current chunk? This step seems redundant because the two are combined in the cross-attention module anyway.
- TIL masked LMs are good for producing embeddings, while causal LMs are good for generating text after a prompt.