AbstRaL: Augmenting LLMs' Reasoning by Reinforcing Abstract Thinking
Authors: Silin Gao, Antoine Bosselut, Samy Bengio, Emmanuel Abbe
Recent studies have shown that large language models (LLMs), especially smaller ones, often lack robustness in their reasoning. That is, they tend to experience performance drops when faced with distribution shifts, such as changes to numerical or nominal variables, or insertions of distracting clauses. A possible strategy to address this is to generate synthetic data that further "instantiates" reasoning problems on potential variations. In contrast, our approach focuses on "abstracting" reasoning problems. This not only helps counteract distribution shifts but also facilitates the connection to symbolic tools for deriving solutions. We find that this abstraction process is better acquired through reinforcement learning (RL) than through supervised fine-tuning alone, which often fails to produce faithful abstractions. Our method, AbstRaL---which promotes abstract reasoning in LLMs using RL on granular abstraction data---significantly mitigates performance degradation on recent GSM perturbation benchmarks.
Figure 1: Our AbstRaction Learning (AbstRaL) method effectively improves the reasoning robustness of LLMs, especially under variations of relevant input conditions and interference from distracting conditions. We report the average accuracy of all tested LLMs on different GSM-Plus testing sets: the original GSM8K testing set (Original Reasoning Problem); the testing sets with numerical variations (Vary Input Conditions), averaged across three portions (digit expansion, integer-decimal-fraction conversion, and numerical substitution); the testing set with problem rephrasing (Vary Problem Contexts); and the testing set with distractor insertion (Add Distracting Conditions).
Figure 2: Learning strategies to improve reasoning robustness with respect to distribution shifts. (a) Augmenting the amount of learning data by synthesizing more reasoning instances. (b) Directly learning to construct the underlying abstraction based on the input, including: (b1) condition recognition, (b2) abstract reasoning, (b3) abstraction retrieval and (b4) symbolic derivation.
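The abstraction pipeline sketched in Figure 2(b) can be illustrated with a toy GSM-style problem. The following is a minimal Python sketch, not the paper's implementation; all function names and the example problem are illustrative assumptions. It shows how condition recognition (b1) maps concrete numbers to symbols, abstract reasoning (b2) produces a symbolic answer expression, and symbolic derivation (b4) computes the final value, so that the same abstraction transfers across numeric perturbations of the input.

```python
# Illustrative sketch of abstraction-based reasoning (not the paper's code).
# Toy problem: "Alice has X apples and buys Y bags of Z apples each.
# How many apples does she have?"

def recognize_conditions(values):
    # (b1) Condition recognition: map concrete numbers to abstract variables.
    return dict(zip(["X", "Y", "Z"], values))

def abstract_reasoning():
    # (b2) Abstract reasoning: the solution is expressed over symbols,
    # not concrete numbers.
    return "X + Y * Z"

def symbolic_derivation(expression, conditions):
    # (b4) Symbolic derivation: a symbolic tool (here, restricted Python
    # evaluation) instantiates the abstraction with concrete values,
    # so numeric perturbations cannot derail the intermediate reasoning.
    return eval(expression, {"__builtins__": {}}, conditions)

# The same abstraction answers both the original problem and a numeric variant.
expr = abstract_reasoning()
print(symbolic_derivation(expr, recognize_conditions([3, 2, 4])))    # 11
print(symbolic_derivation(expr, recognize_conditions([30, 21, 45])))  # 975
```

The key point of the sketch is that the symbolic expression, once learned, is reusable: only the binding of conditions changes under distribution shift, which is why reinforcing faithful abstraction improves robustness.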