Recent studies have shown that large language models (LLMs), especially smaller ones, often lack robustness in their reasoning: they tend to suffer performance drops under distribution shifts, such as changes to numerical or nominal variables or the insertion of distracting clauses. One strategy for addressing this is to generate synthetic data that further "instantiates" reasoning problems with potential variations. In contrast, our approach focuses on "abstracting" reasoning problems. This not only counteracts distribution shifts but also facilitates connecting to symbolic tools for deriving solutions. We find that this abstraction skill is better acquired through reinforcement learning (RL) than through supervised fine-tuning alone, which often fails to produce faithful abstractions. Our method, AbstRaL, which promotes abstract reasoning in LLMs using RL on granular abstraction data, significantly mitigates performance degradation on recent GSM perturbation benchmarks.
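To make the idea of abstraction concrete, the sketch below replaces the specific numbers of a GSM-style problem with symbolic placeholders and derives the answer with a symbolic tool, so that one abstract solution covers any numeric variation. The example problem, the variable names, and the use of sympy are illustrative assumptions, not the paper's implementation.

```python
# Hypothetical sketch: abstracting a GSM-style word problem and solving it
# symbolically, so the same abstract solution covers any numeric variation.
import sympy as sp

# Concrete problem: "Ann has 3 apples and buys 4 more. How many does she have?"
# Abstract form:    "Ann has x apples and buys y more. How many does she have?"
x, y = sp.symbols("x y", positive=True)
abstract_answer = x + y  # the abstraction: the answer as a symbolic expression

# Symbolic derivation: instantiate the abstraction under varied input conditions.
for conditions in [{x: 3, y: 4}, {x: 1200, y: 4567}]:
    print(abstract_answer.subs(conditions))  # -> 7, then 5767
```

Because the derivation operates on the symbolic expression rather than on memorized surface forms, numerical perturbations of the problem leave the reasoning itself unchanged.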

Figure 1: Our AbstRaction Learning (AbstRaL) method effectively improves the reasoning robustness of LLMs, especially against variations of relevant input conditions and the interference of distracting conditions. We report the average accuracy of all tested LLMs on different GSM-Plus test sets: the original GSM8K test set (Original Reasoning Problem); the test sets with numerical variations (Vary Input Conditions), averaged over three subsets (digit expansion, integer-decimal-fraction conversion, and numerical substitution); the test set with problem rephrasing (Vary Problem Contexts); and the test set with distractor insertion (Add Distracting Conditions).

Figure 2: Learning strategies to improve reasoning robustness with respect to distribution shifts. (a) Augmenting the learning data by synthesizing more reasoning instances. (b) Directly learning to construct the underlying abstraction of the input, comprising: (b1) condition recognition, (b2) abstract reasoning, (b3) abstraction retrieval, and (b4) symbolic derivation.
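A minimal skeleton of the four-stage pipeline in Figure 2(b) might look as follows. The function names and the regex-based heuristics are hypothetical placeholders for components that, in the actual method, would be produced by an RL-trained LLM; only the overall stage structure is taken from the figure.

```python
# Hypothetical skeleton of the four stages in Figure 2(b). Function names
# and the regex heuristics are illustrative assumptions only.
import re
import sympy as sp

def recognize_conditions(problem: str) -> dict:
    """(b1) Condition recognition: map each number in the text to a slot."""
    return {f"c{i}": sp.Integer(int(n))
            for i, n in enumerate(re.findall(r"\d+", problem))}

def abstract_reasoning(problem: str, conditions: dict) -> str:
    """(b2) Abstract reasoning: an LLM would emit a symbolic chain of
    thought over the slots; here a hard-coded sum stands in for it."""
    return " + ".join(conditions)  # e.g. "c0 + c1"

def retrieve_abstraction(chain: str) -> sp.Expr:
    """(b3) Abstraction retrieval: extract the final symbolic expression."""
    return sp.sympify(chain)

def derive_answer(expr: sp.Expr, conditions: dict) -> sp.Integer:
    """(b4) Symbolic derivation: substitute the recognized conditions."""
    return expr.subs({sp.Symbol(k): v for k, v in conditions.items()})

problem = "Ann has 3 apples and buys 4 more. How many does she have?"
conds = recognize_conditions(problem)            # {'c0': 3, 'c1': 4}
expr = retrieve_abstraction(abstract_reasoning(problem, conds))
print(derive_answer(expr, conds))                # -> 7
```

The key design point is that only stage (b2) needs a learned model; stages (b1), (b3), and (b4) can hand off to deterministic parsing and a symbolic solver, which is what makes the final answer robust to perturbed inputs.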

Related readings and updates.

Recent advancements in Large Language Models (LLMs) have sparked interest in their formal reasoning capabilities, particularly in mathematics. The GSM8K benchmark is widely used to assess the mathematical reasoning of models on grade-school-level questions. While the performance of LLMs on GSM8K has significantly improved in recent years, it remains unclear whether their mathematical reasoning capabilities have genuinely advanced, raising…
We investigate the capabilities of transformer models on relational reasoning tasks. In these tasks, models are trained on a set of strings encoding abstract relations, and are then tested out-of-distribution on data that contains symbols that did not appear in the training dataset. We prove that for any relational reasoning task in a large family of tasks, transformers learn the abstract relations and generalize to the test set when trained by…