Teaching Large Language Models to Reason with Reinforcement Learning with Alex Havrilla

Released Tuesday, 16th April 2024
Episode Transcript

0:06

All right, everyone, welcome to another episode of the

0:08

TWIML AI Podcast. I am your host, Sam

0:10

Charrington, and today I'm excited to

0:12

be here with Alex Havrilla. Alex is a

0:14

PhD student at Georgia Tech. Before we get

0:16

going, be sure to take a moment to

0:19

hit that subscribe button wherever you're listening to

0:21

today. So, Alex, welcome to the

0:23

podcast. Thanks for having me. Excited to be here!

0:25

I'm excited to dig into our

0:28

conversation as well. We're gonna be

0:30

talking about your research, which centers

0:32

around the idea of applying reinforcement

0:34

learning to reasoning and some of

0:36

the downstream problems you ran into

0:38

in the process. But before we

0:40

dive into your specific research, I wanted you

0:42

to introduce yourself a little bit

0:44

to our audience, share how you

0:46

came to work in the field,

0:49

and we'll start there. Yeah.

0:51

So like you mentioned, I am a

0:53

third-year PhD student at Georgia Tech.

0:55

I'm in the mathematics department;

0:57

my advisor does neural network learning theory,

0:59

but, like, two years ago I got

1:01

into open-source research through

1:03

the open-source community, and since then

1:05

have been very interested in RL fine

1:07

tuning for LLMs, and specifically with

1:09

a focus on trying to improve their

1:11

reasoning capability. Tell us a

1:13

little bit about how you think about

1:15

your research broadly. Broadly speaking,

1:17

I am kind of dividing my

1:19

attention into roughly two parts.

1:21

So the first part is work

1:23

on this more theoretical aspect, like

1:25

I mentioned, so neural network learning

1:27

theory: trying to, you know, make

1:29

theoretical statements about, given a

1:31

certain number of training examples, how

1:34

quickly can your model attain some

1:36

minimum generalization error? How large does

1:38

your network need to be in theory

1:40

to approximate some target function? And

1:43

then on the more applied LLM side,

1:45

my research is kind of a little

1:47

more broad. I'm focused on

1:49

RL fine tuning. I was

1:52

responsible, or, like, responsible

1:54

with some other very talented people, for the development

1:56

of the trlX library for open-source

1:58

RLHF, and that allowed us to do

2:00

large-scale RL fine tuning experiments for

2:02

RLHF, which kind of paved the

2:04

way for doing RL experiments with reasoning,

2:07

which was always my real motivation. And

2:09

when you talk about and

2:11

think about RL for reasoning,

2:13

how much of that is RLHF

2:15

versus the more broad application

2:18

of RL-related methods to

2:20

improving the reasoning capabilities of language

2:22

models? There's no human

2:25

feedback, so this is entirely automated. You

2:27

can, like, get whatever signal you're interested in

2:29

from some ground truth label, so maybe

2:31

distant, very distant human supervision, but, you

2:33

know, in the training there's no supervision.

2:35

But certainly I think human feedback

2:37

is also a very valuable way of, you

2:39

know, improving LLM reasoning with RL. So one

2:41

great example of that is this OpenAI

2:44

"Let's Verify Step by Step" paper, which

2:46

trains this process-based reward model to

2:49

give very granular step-level feedback to

2:51

LLMs solving the math problems. Lingering

2:54

on RLHF for a

2:56

moment, you can think about that

2:58

as kind of the composition of

3:00

RL and HF. I

3:02

often think that we get wrapped up

3:05

in RL and really the heavy lifting is

3:07

done by the HF. Agree? Disagree? What do

3:09

you think? I think it's for

3:11

sure true that the majority of the improvements that

3:13

you see in today's LLMs are due to

3:15

the human feedback part, and I think

3:18

that was evident even if you look as far

3:20

back as the InstructGPT paper. Like, they

3:22

had these comparisons where they have this,

3:24

like, intermediate model which they

3:26

called the FeedME model. This is

3:29

text-davinci-002, and that model is already,

3:31

like, miles better than GPT-3. I

3:33

mean, it was just much more helpful.

3:35

And then, like, once you

3:38

do the RLHF on top of that,

3:40

like the whole PPO pipeline, things improve somewhat.

3:42

Like, there was a ten percent preference win rate

3:44

improvement over that, you know, FeedME

3:46

SFT baseline. But this is, you know, nothing

3:48

compared to the improvement you saw over the original

3:50

base model. So I think, yeah, it's

3:52

definitely true that the human feedback part is

3:55

doing a lot of the lifting. But on the

3:57

other hand, I think, you know, the

3:59

hope that many people have for RL,

4:01

myself included, is that, you know,

4:03

if we really want to start

4:05

training these superhuman-level systems,

4:08

the, you know, the amount that we're going

4:11

to be able to provide high quality supervision is

4:13

gonna naturally decrease as the complexity of the tasks we

4:15

attempt to solve with these systems gets bigger and

4:17

bigger. Continuing along the

4:19

RL vector, one

4:21

of the historical challenges

4:24

with reinforcement learning is

4:26

its data inefficiency. Do

4:28

you see us

4:31

solving that by

4:33

leveraging kind of unique

4:37

properties of language models to

4:39

make RL more data

4:41

efficient and useful in this

4:43

particular context, or is it more

4:45

about, you know, tracking improvements in general

4:47

RL research and kind of

4:49

pulling that into application

4:52

with language models?

4:54

So like you said, in more classical RL,

4:57

yeah, it's extremely sample inefficient, like, you might

4:59

have to wait around for, like,

5:01

you know, hundreds of thousands of rollouts before

5:03

you see, like, any emergent behavior at

5:05

all, before your agent learns anything. That's not

5:07

the case with RL for LLMs.

5:09

So, like, when you start, you know,

5:11

fine tuning these pretrained LLMs,

5:13

they start reliably improving, like, basically immediately.

5:16

And this is, I think, because of, like,

5:18

there are a couple of differences here, but the main

5:20

difference is, you know, when you're fine tuning this

5:22

LLM, you're starting either from some

5:25

supervised fine-tuned checkpoint or from, you know, some

5:27

pretrained checkpoint, which already has a pretty good sense

5:29

of what it's trying to do, right? Like, there's

5:31

a very strong, like, warm-start bias, which you

5:33

don't have in classical RL, and that's, like, really

5:36

good for sample efficiency, because you

5:38

don't have to wait around for the model to

5:40

kind of, like, you know, make random

5:42

moves and see if those random moves

5:45

are useful. Like, you can just, you know,

5:47

immediately start generating. In the context of reasoning,

5:49

you know, you can immediately start generating

5:51

potential candidate solutions, some

5:53

percentage of which will be correct, and, you know,

5:56

you can get a pretty good reward signal

5:58

from that immediately. And so

6:00

that means sample efficiency is pretty good too. Like actually, in

6:03

most of the experiments we did, we didn't require

6:05

more than 100,000 samples most of the time to

6:07

like, you know, really fully converge to the max

6:09

performance that we see. Yeah,

6:12

maybe with that in mind kind of

6:14

frame out, you

6:16

know, where we are with RL applied

6:18

to LLMs, what

6:21

are the, you're applying it to

6:23

reasoning. Are there other ways that

6:25

RL is promising in the application

6:28

to LLMs? And

6:31

kind of where are we broadly speaking?

6:35

Yeah, so of course, there's RLHF, like

6:37

we talked about, super useful for,

6:39

you know, aligning the interactability

6:41

of the model with how humans expect

6:43

to interact with it. So

6:47

there's that. I think also there is a

6:49

lot of, you know, work that is

6:52

being done and will continue to be done

6:54

integrating LLMs with, you know, various tool usage

6:56

and allowing them to, you know, act as

6:58

web agents, for example, browsing the internet, these

7:01

kinds of things. And there will definitely be,

7:03

you know, RL components there, which

7:06

like, you know, it gets, that's a very interactive process, right? Like

7:08

when you're learning a new tool for the first time, you have

7:10

to figure out what are the scenarios in which

7:12

should I use this tool, what are, like,

7:14

you know, the preconditions for using this tool. And,

7:16

you know, anytime you have an interactive process that

7:18

is very naturally RL focused. Given

7:21

the cost of like training, large

7:24

language models and their complexity, like

7:26

is there an online component

7:29

that RL makes possible for LLMs?

7:32

How do we implement online learning in LLMs?

7:35

So first of all, online learning is still a

7:37

very difficult problem. So like this idea of, you

7:39

know, continual learning or like updating yourself as you

7:41

go, is quite difficult. I

7:44

mean, I think there are some naive solutions like,

7:46

you know, gather some data, fine tune the LLM

7:48

on the data. And

7:50

maybe those work okay, but it's kind of hacky, right?

7:53

Like it's not really the solution we're looking for. But

7:56

I think you bring up another good point, which is, and this is

7:58

one of the other, like one of the main things we're investigating

8:00

in the paper is these different paradigms of

8:02

doing RL training. So like there's this online

8:04

on policy paradigm where you have some agent

8:06

in an environment, right, and you're rolling out

8:09

the agent and then you're

8:11

training the like policy, you're updating the

8:13

policy using the like data which has

8:15

just been collected by the agent. So

8:17

that's online and on policy. The agent's

8:19

doing exploration and you're updating the agent

8:22

using data which is generated. But

8:25

there are kind of like more off policy

8:27

versions of this where you can

8:29

generate lots of data and then you can,

8:31

the agent you're training and updating is like

8:33

updated on a kind of old or stale

8:36

data. So, like, to

8:38

put concrete names behind this: proximal

8:40

policy optimization, PPO, is

8:42

an example of like a you know a

8:44

mostly online on policy algorithm and

8:46

that's one of the algorithms we looked at

8:48

in the paper. Expert iteration, you have

8:51

a question and you sample the model

8:53

many times on the question for many

8:55

different solutions and then you fine-tune

8:57

the model on the correct solutions. That

9:00

is much more off policy because you

9:02

know you're generating lots of data and

9:04

then you're you know fine-tuning the model

9:07

on like trajectories which you know

9:10

may come from different points in, like,

9:12

the training cycle. So

9:15

like and usually like when you're

9:17

doing this type of training in the

9:19

classical RL setting on policy algorithms such

9:21

as PPO are, they're not more sample

9:23

efficient technically but you

9:25

need fewer iterations of, like, training

9:28

to fully converge because you're exploring

9:30

this complex state space and somehow

9:33

the updates you make to the policy are

9:35

like the most useful ones you're making in

9:37

the context of solving the problem and off

9:39

policy is somehow like converges slower in

9:41

comparison in classical RL. But this

9:43

is not at all what we found

9:45

when we were comparing these methodologies in

9:48

like RL fine-tuning for LLM. So when

9:50

you look at the sample complexity or

9:52

the number of samples you need to

9:54

reach a certain level of performance for

9:56

PPO versus expert iteration, it's roughly

9:58

the same, to be honest, despite the fact that

10:01

expert iteration is much more off policy and

10:03

potentially generating less useful data, which

10:06

I think a priori is a bit of a surprise because

10:08

this is, you know, again, not the case at all in

10:10

classical RL. So these two methods

10:12

have roughly the same sample complexity. And

10:15

like I mentioned before, like the overall sample

10:18

complexity you need for these models to converge

10:20

is like fairly, like pretty small, like only

10:22

a hundred thousand samples. So like somehow this

10:24

is an indication that like

10:26

definitely the models, you know, are learning

10:29

something and are like uncovering interesting behavior.

10:31

But somehow like, you know, they're not

10:33

really accessing, you know,

10:36

they're not somehow diversely exploring in the

10:38

way that we might want them to.

10:42

Because otherwise, if they were accessing, you

10:44

know, like new types of solutions and

10:46

new modes of behavior, you probably wouldn't

10:48

see such, like, fast convergence and, like,

10:50

such similar sample complexities between these methods.

10:53

Let's maybe punch into the paper

10:55

so that we can talk about some

10:57

of the broad RL concepts and their

11:00

applicability kind of more concretely. So

11:02

the paper is again

11:04

"Teaching Large Language Models to Reason

11:07

with Reinforcement Learning." Talk

11:09

a little bit about your motivations

11:12

with the paper. So

11:14

I already kind of touched on it, but you

11:17

know, in the RLHF space there are

11:19

many different algorithms now that are popular, right? So

11:21

at the beginning there was PPO, and PPO was, you know,

11:23

SOTA. And it was unclear whether you

11:26

could use other algorithms. But now,

11:28

you know, in, like, April 2024, it's

11:30

clear that, you know, not only are there

11:33

many competitors, but they're basically just as good,

11:35

right? So there was a recent paper from

11:37

Cohere which shows you that REINFORCE is basically

11:39

as good as PPO in this context. DPO

11:42

is obviously very popular now and does no

11:44

exploration at all, which

11:47

is also maybe another sign that

11:49

something is not quite right here. But

11:52

it's, you know, basically just as competitive

11:54

as PPO, which does lots

11:56

of exploration. But, you

11:58

know, when I was investigating these problems… This is not

12:00

clear. Like, it was not clear how these

12:02

algorithms compare. Like, where are the benefits of using

12:05

one algorithm versus another for,

12:07

you know, doing this fine tuning? And so that was

12:09

really our goal. You know, like, we're going

12:11

to take all these algorithms and we're going

12:14

to take a bunch of different ideas

12:16

in RL, and we're just gonna see what

12:18

works best when you're fine tuning. Is the RL here

12:20

specifically human feedback, or something

12:23

else? This was not RLHF. So in the

12:25

context of reasoning, usually you have a question and

12:27

then you have an answer, and that

12:29

answer has, like, a gold final answer you're trying

12:31

to reach. So, like, you know, maybe it's some

12:34

math word problem and you're trying to get

12:36

some integer solution to the math word problem.

12:38

And so, yeah, we can give a reward. In the

12:40

most basic sense, we can provide reward feedback

12:42

based on: did you get this correct final answer,

12:45

or did you not? And that's how you can start

12:47

to do the RL. What's the

12:49

relationship between the

12:51

policy and the language model

12:54

itself? Okay, that's,

12:56

that's another good question. So in this case,

12:59

the policy is exactly the language

13:01

model. So the language model generates

13:03

every single action

13:05

and then, you know, is supposed to come to

13:07

the correct final answer. But

13:10

I think this is a good question

13:12

because there are situations in which,

13:14

you know, the policy is not

13:16

necessarily just the model. So for example,

13:18

you can imagine combining a calculator with

13:20

an LLM to solve math problems. We

13:22

didn't have to do it here because we

13:24

were working with Llama 2 base models, and

13:26

Llama is pretty good at arithmetic, at least

13:28

the types of arithmetic we were doing. There

13:31

was no need for a calculator, but you can

13:33

imagine combining a basic tool like a calculator

13:35

with the LLM, and that forms, like, a

13:37

more generalized policy which is the combination of

13:39

an LLM and some tool. So yeah,

13:41

I just, like, wanted to drive home the

13:43

point that the policy doesn't necessarily have to

13:45

be the LLM. If the

13:47

policy is the language model, what

13:49

does it mean to optimize the language

13:52

model? What are you optimizing, or

13:54

what are your levers? Are they, like,

13:56

things like temperature and things that go

13:58

into a prompt? I'm assuming that

14:00

they're not the model itself. In our case it

14:02

is the model. So, like, we are literally, we're

14:04

literally, like, changing the weights of the model, depending

14:07

on, you know, the type of reward that, you

14:09

know, it is getting back. Like, did you get

14:11

this question correct? If yes, you know, like, change

14:13

the gradients in a way to, you know,

14:16

adjust for this. If you got it wrong, okay, maybe

14:18

change your gradients in a different way. Got

14:21

it. I was imagining that the model

14:23

was embedded in some larger system and

14:25

you were tweaking something about the larger system

14:27

as opposed to the model itself,

14:30

but here you are directly fine tuning the

14:32

model. Got it.

14:34

And so,

14:37

you alluded to

14:40

one of the big results, that

14:42

you saw that the algorithms all performed

14:44

largely the same. Is there a next level

14:46

of nuance or detail there?

14:49

Yeah, I think so. So let me

14:53

expand on that. So we went

14:55

in expecting, you know, PPO to do

14:57

quite well, like, get the best performance

14:59

and have pretty good sample efficiency,

15:01

because, like I said, you know, it's

15:03

more on policy, so it should, you

15:05

know, in some sense

15:07

converge faster.

15:10

But this is simply not

15:12

what we saw. So, like

15:14

I said, expert iteration, for example, where

15:16

you simply fine tune on the

15:18

correct answers, performs roughly the

15:20

same in our setup.

15:23

Technically it was, like, somewhat less sample

15:25

efficient, but we did some ablations, which

15:27

demonstrate... I guess, to be very concrete,

15:30

the way we implemented expert iteration is, you know, we

15:32

sampled the model, like, 96 times per

15:34

question or something, and then took

15:37

all the answers which were correct and fine

15:39

tuned the model on those, and iterated that.
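
[Editor's sketch] The expert iteration loop described here (sample each question many times, keep the samples whose final answer matches the ground truth, fine tune on those, repeat) could look roughly like the toy Python below. This is only an illustration under loud assumptions: `sample_answer` and `fine_tune` are hypothetical stand-ins invented for this sketch, the "policy" is just a per-question probability of answering correctly rather than an LLM, and the update rule is arbitrary. Only the sample/filter/fine-tune/iterate structure and the 96 samples per question come from the conversation.

```python
import random


def sample_answer(p_correct, gold, rng):
    """Hypothetical stand-in for sampling an LLM: correct with probability p_correct."""
    if rng.random() < p_correct:
        return gold
    return gold + rng.randint(1, 9)  # some wrong integer answer


def fine_tune(policy, kept):
    """Toy 'fine tuning' (assumption): raise p_correct for questions with kept samples."""
    updated = dict(policy)
    for qid in {qid for qid, _ in kept}:
        updated[qid] = updated[qid] + 0.5 * (1.0 - updated[qid])
    return updated


def expert_iteration(gold_answers, policy, samples_per_question=96, rounds=3, seed=0):
    """Sample, keep correct answers, fine tune on them, and iterate."""
    rng = random.Random(seed)
    for _ in range(rounds):
        kept = []
        for qid, gold in gold_answers.items():
            for _ in range(samples_per_question):
                answer = sample_answer(policy[qid], gold, rng)
                if answer == gold:  # exact match against the ground-truth final answer
                    kept.append((qid, answer))
        policy = fine_tune(policy, kept)
    return policy


if __name__ == "__main__":
    gold = {"q1": 42, "q2": 7}
    print(expert_iteration(gold, {"q1": 0.2, "q2": 0.2}))
```

Lowering `samples_per_question` (for example to 4) is the knob discussed a little later in the conversation for trading samples against performance.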

15:41

But then we did some

15:43

ablations. And in this case, is that determined by human

15:46

feedback? Correctness is, like, for each answer

15:48

you have, like, a ground truth correct

15:50

final answer, and that has

15:52

been labeled by a human. And then

15:54

it's really just, you know, doing text

15:57

comparison or something automated to

15:59

check. The correct answers there

16:01

[inaudible], but you do have the

16:03

correct answer because they're ultimately math.

16:05

It sounds like, okay, were they all

16:08

simple arithmetic, or is that just one category

16:10

of reasoning problems? Yeah, they

16:12

were, they were mostly simple arithmetic. We

16:14

also considered some like common sense question

16:16

answering, but the majority of it was

16:19

arithmetic, like GSM8K-style

16:21

questions. But just

16:23

to, like, finish describing this

16:25

very concretely: so, like, for expert iteration we have,

16:28

like, 96 samples per question, and

16:30

we fine tuned on all the correct answers.

16:33

But it turns

16:35

out, like, 96, you know,

16:38

96 answers per question is kind of

16:40

overkill. Like, you can reduce that to

16:42

four answers per question and get nearly the

16:44

same performance as you had before,

16:47

after each iteration, and then you can match

16:49

the sample complexity of PPO, and you

16:51

still beat PPO's performance. So,

16:54

yeah, and we investigated

16:56

some other setups as well. So we investigated

16:58

the offline case, where you

17:00

don't do any exploration. There's

17:03

this idea of, like, conditional fine tuning,

17:05

where you can, like, label each step as,

17:08

like, correct or incorrect using good

17:10

or bad tokens. And, like,

17:12

what we thought was, you know, why,

17:14

why was this the case? Like, why is

17:17

everything performing roughly the same?

17:19

Because we also tried some other

17:21

ideas as well. So, like, there are

17:23

ideas of training reward models in the

17:25

literature, such as this idea

17:27

of an outcome-based reward model, where you

17:29

can predict the correctness of the final

17:31

answer. And, like, somehow this

17:33

might, in a sense, smooth the type of

17:35

reward that the model is getting, as,

17:37

like, previously you only get a reward

17:39

of plus one if you got

17:41

the correct final answer, and you get a reward of zero

17:43

otherwise. But, like, maybe you were really close but just

17:45

slightly off in some way, so hopefully the reward

17:48

model can tell you that. So we

17:50

experimented with giving that as a reward

17:52

both, like, at the very

17:54

last token and also, like, interspersed

17:56

throughout, so this idea of sparse reward

17:58

versus dense reward. And

18:01

again, we found in both cases that,

18:03

you know, these mechanisms improve the sample

18:05

complexity of the learning algorithm. So, like,

18:07

if I give you an ORM as

18:09

a sparse reward or a dense reward, you

18:12

need slightly fewer samples to converge, but

18:14

the peak performance which you get to

18:16

is about the same

18:18

or slightly worse than if you just use, you

18:20

know, the ground truth exact match,

18:22

the exact match score.

18:26

And so, like,

18:28

you know, we were thinking: why is this the case?

18:30

You know, why is it that even when

18:32

you give this, you know, dense reward, the

18:34

algorithms all roughly converge to the same thing?
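
[Editor's sketch] The reward schemes discussed above (a ground-truth exact-match reward of plus one or zero, a sparse reward placed only on the last token, and a dense reward interspersed throughout the solution) can be illustrated with the small Python sketch below. The function names, the per-token list representation, and the idea of passing in precomputed reward-model scores at chosen positions are all assumptions made for illustration; the conversation does not specify how the scores were computed or where exactly they were placed.

```python
def exact_match_reward(predicted, gold):
    """+1 if the final answer string matches the ground truth exactly, else 0."""
    return 1.0 if predicted.strip() == gold.strip() else 0.0


def sparse_reward(num_tokens, final_reward):
    """Sparse scheme: the whole reward sits on the very last token."""
    rewards = [0.0] * num_tokens
    if num_tokens > 0:
        rewards[-1] = float(final_reward)
    return rewards


def dense_reward(num_tokens, positions, scores):
    """Dense scheme: (hypothetical) reward-model scores placed at chosen token positions."""
    rewards = [0.0] * num_tokens
    for position, score in zip(positions, scores):
        rewards[position] = float(score)
    return rewards


if __name__ == "__main__":
    final = exact_match_reward("42", "42")
    print(sparse_reward(6, final))
    # Positions and scores here are made-up values standing in for reward-model outputs.
    print(dense_reward(6, [1, 3, 5], [0.2, 0.5, 0.9]))
```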

18:38

And so our, you know,

18:40

main hypothesis was that, okay, if

18:42

these algorithms, especially PPO and expert iteration, are

18:45

all producing roughly the same data, if

18:47

they are producing roughly, like, the

18:49

same exploration of the types of answers

18:53

which constitute, you know, like,

18:55

the set of answers which they discover

18:57

in, like, that exploration process, then that

19:00

kind of, you know, explains the phenomenon,

19:03

right? Because, you know, regardless of the algorithm,

19:05

if I give you more or less the

19:07

same data, then they're probably going to perform

19:09

roughly the same, in this context at

19:12

least. So to verify that, we

19:14

had to look at, you know, some measures

19:16

of, you know, what,

19:18

what is it like,

19:20

the output that each of these methodologies

19:22

or algorithms is producing. And in

19:24

particular, we were interested in two metrics.

19:26

So we were interested in looking at

19:29

the diversity of the solutions which

19:31

these algorithms, like, which each

19:33

algorithm generated.

19:35

So, like, you know,

19:37

what is the number of unique solutions

19:40

for a given problem, where you judge two

19:43

solutions the same if, like, they have the

19:45

same order of operations or something like that, arithmetic

19:48

operations. So that was one metric we were

19:50

looking at, output diversity. And then we were

19:52

also interested in, if you look at the

19:54

pass at 96

19:56

metric of each of these, like, trained

19:58

models, which is: if I sample the

20:00

model 96 times on each test question and then check

20:03

if I ever get the same test, if

20:06

I ever get the correct answer, that's

20:09

another metric I can use to evaluate these models'

20:11

performance. And that's also some measure

20:15

of output diversity because there's

20:17

this, roughly speaking, the

20:20

idea of a model being more diverse

20:22

means that if you sample it many

20:24

times, it'll have a broader set, broader

20:26

coverage of the types of solutions, and

20:28

its odds of getting the correct solution

20:30

are better than the less

20:32

diverse model. And so their pass at 96 score

20:35

will be higher. So we were

20:37

comparing the pass at 96 score of these models,

20:39

and we found some very interesting things. So

20:42

for example, if you

20:44

do supervised fine tuning and

20:47

you train for two

20:49

epochs on some golden

20:51

ground truth data, your

20:54

pass at 96 score will be pretty good. So you can

20:56

get a pass at 96 score of 80% on Llama 2 7B if

21:03

you do this. But then the interesting thing is if

21:06

you train for longer than that, so if you

21:08

train for more than two epochs, your pass at

21:10

96 score will start to go down, and

21:13

suddenly it seems like your model's producing

21:15

less diverse solutions, and the exploration it's

21:17

engaging in is worse. So

21:20

somehow supervised fine tuning is not, too

21:24

much of it is damaging model diversity. I'm

21:27

not sure I'm clear on

21:29

the role and value

21:32

of diverse responses for

21:34

simple arithmetic problems. So

21:38

I agree, if you're trying to solve a math word problem

21:40

on the face of it, why do

21:42

you care if you are able to solve

21:44

it in multiple ways? So first of

21:46

all, I wanna clarify that even in these

21:48

simple math word problems, there are a surprisingly

21:50

diverse number of ways to solve them. So

21:52

on average, I can tell you

21:55

that the 7B model discovered five

21:57

different unique solutions per problem, which is kind

21:59

of... It's kind of surprising, but already

22:02

pretty diverse. So why is

22:04

this useful? I'm making the claim

22:06

it depends on your use case. So if you're interacting

22:08

with ChatGPT, for example, and

22:11

you're just asking ChatGPT, "ChatGPT,

22:13

solve me this word problem," then

22:16

I agree. You don't care about output

22:18

diversity. You just only care about ChatGPT

22:20

getting the right answer. And

22:23

if it can do that reliably in the first try, that's

22:25

great. And so yeah, you

22:28

don't care. But let's say

22:30

you're interested in like, maybe

22:33

you have like some two different types

22:35

of difficulty of questions. So like you

22:37

have an easy set of math problems

22:39

and you have a hard set of

22:41

math word problems. Maybe

22:44

I can convince you that if you can

22:47

generate lots of diverse types of solutions to

22:49

the easy math word problems, somehow

22:51

this makes it easier to also generalize

22:53

to solving the harder math word problems.

22:56

A model with a greater agility with

22:58

regard to reasoning can probably do better

23:00

with the more complex problems, whereas for

23:03

the easier problems, it might be more

23:05

like memory or retrieval or something like

23:07

that. The reason I think we care

23:09

about this very concretely in the RL

23:12

context is because the

23:14

ability of the agents

23:17

that you're fine tuning to

23:19

produce diverse solutions directly

23:21

impacts the quality of the exploration it's

23:23

doing as it's being trained. If

23:26

I have an LLM which is being fine

23:28

tuned with RL and it's

23:30

like super over fit, has no

23:32

diversity, it's like just not

23:34

going to be able to discover that much

23:36

and not going to be able to learn

23:38

that much beyond what it already knows because

23:40

it's simply not generating surprising or

23:43

unexpected solutions. Whereas if

23:45

you have an LLM which is not

23:47

over fit and is able to generate

23:49

more diverse, interesting solutions, then you're going

23:52

to start to see the more complicated

23:54

type of generalization that you're interested in.
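
[Editor's sketch] The two diversity measures discussed above, the pass at 96 score and the number of unique solutions per problem, can be sketched in Python as follows. The operator-sequence signature used to decide when two solutions count as "the same" is an assumption based on the speaker's loose description ("the same order of operations or something like that"), and `pass_at_k` here is the plain "any of the k samples is correct" fraction, not a statistical estimator.

```python
import re


def pass_at_k(samples_per_question, gold_answers):
    """Fraction of questions where at least one of the k sampled answers is correct."""
    if not gold_answers:
        return 0.0
    hits = sum(
        any(sample == gold for sample in samples)
        for samples, gold in zip(samples_per_question, gold_answers)
    )
    return hits / len(gold_answers)


def count_unique_solutions(solutions):
    """Count solutions that differ in their sequence of arithmetic operators (assumed criterion)."""
    signatures = {tuple(re.findall(r"[+\-*/]", solution)) for solution in solutions}
    return len(signatures)


if __name__ == "__main__":
    # Two questions; only the first has a correct sample among its candidates.
    print(pass_at_k([[1, 2, 3], [4, 5]], [3, 9]))
    # The first two solutions share the operator order (+ then *), the third differs.
    print(count_unique_solutions(["2+3*4", "5+6*7", "3*4+2"]))
```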

24:00

Exploration and diversity of responses, and temperature,

24:02

which I mentioned earlier? With

24:04

LLMs, temperature definitely will have an

24:07

effect on exploration and, like,

24:09

model output diversity.

24:12

In general, in our experiments [inaudible], so it

24:14

depends on the setup. We have two

24:16

broad setups: one, we supervised fine tune

24:18

the model beforehand and then do RL,

24:20

and the other, we fine tune the model

24:22

from the pretrained checkpoint and then do

24:24

RL. In

24:26

the case of the supervised fine tuning

24:28

setup, you can afford a

24:30

slightly higher temperature of, like,

24:32

0.7, because the

24:34

model already has a decent idea of what it has to do.
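
[Editor's sketch] For readers unfamiliar with sampling temperature, the Python below shows the standard temperature-scaled softmax and uses entropy as a rough proxy for how spread out (diverse) the sampling distribution is. The logits are made-up numbers, and the comparison of 0.7 against a lower value is only illustrative; nothing here is taken from the paper's implementation.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Temperature-scaled softmax: lower temperature sharpens the distribution."""
    scaled = [logit / temperature for logit in logits]
    largest = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(value - largest) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means more spread-out sampling."""
    return -sum(p * math.log(p) for p in probs if p > 0)


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.0]  # hypothetical next-token logits
    for temperature in (0.7, 0.2):
        probs = softmax_with_temperature(logits, temperature)
        print(temperature, [round(p, 3) for p in probs], round(entropy(probs), 3))
```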

24:38

When you start from scratch you need to

24:40

have a lower temperature of, like, 0.1

24:42

or 0.2, because otherwise the

24:44

model just doesn't have a good sense

24:46

of, like, you know, what actually constitutes a

24:48

correct solution. It's still learning the syntax, that

24:50

kind of thing, and so you can't really

24:53

afford to be slightly off. You

24:55

really need to, like, make the best of

24:57

all, like, the valid examples you have. But

24:59

what you can do is, as you go

25:01

through training in, like, the

25:03

starting-from-scratch case, you can

25:05

start to anneal the temperature so that

25:07

you're producing, like, higher

25:09

temperature solutions as the model is, you

25:11

know, better learning what constitutes a valid

25:14

solution. Are you tuning just the last

25:16

layer, kind of how deep you go

25:18

in, that kind of thing? Those are

25:20

all different things you could play with

25:22

there, but they're not fully specified,

25:24

I don't think, [inaudible]

25:26

by the RL algorithm. Ah, yeah,

25:29

yeah, I see, I had misunderstood.

25:31

So yeah, in that sense, like,

25:33

what layers of the model are you

25:35

fine tuning, that's all the same

25:37

between different algorithms.

25:41

Yeah. But, like, how you

25:43

actually compute the gradients and, like, the loss which

25:45

is used, that's [inaudible].

25:48

[inaudible] Yes,

25:50

another interesting thing is,

25:53

so, I mentioned, like, this

25:55

phenomenon, and this was observed in the original GSM8K

25:57

paper, actually, so I don't want to

25:59

claim credit, but this

26:01

phenomenon where, if you, you know, have

26:03

a model which you're supervised fine-tuning, if

26:05

you do, like, two epochs, it has

26:08

pretty good diversity, but the diversity pretty quickly

26:10

decays after that. So, like,

26:12

hopefully I've convinced you at this point that it's

26:14

in your interest to maintain model diversity if you want

26:16

good RL. So the nice thing

26:18

that happens when you do any kind

26:20

of RL fine-tuning, PPO, expert iteration, anything

26:22

really, is that that pass

26:24

at 96 metric kind of

26:27

maintains. Like, it converges pretty early in

26:29

training, but at least it doesn't decay. Like,

26:31

it maybe converges to 80% and then

26:33

it stays at 80% for the majority

26:35

of the RL training, versus

26:38

in SFT it's, like, this,

26:40

you know, unimodal thing where it increases and decreases.
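
The pass-at-96 number mentioned here is the chance that at least one of 96 sampled solutions is correct. As a minimal sketch, the widely used unbiased estimator for pass@k can be written as below. This is an illustration and not necessarily how the numbers in this research were computed; the function name `pass_at_k` and the sample counts in the example are invented.

```python
from math import comb

def pass_at_k(n, c, k):
    """Estimate pass@k from n samples of which c are correct.

    Computes 1 - C(n - c, k) / C(n, k): one minus the probability that a
    random size-k subset of the n samples contains no correct solution.
    """
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("need 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)

# Invented example: 200 samples for one problem, 3 of them correct.
print(pass_at_k(200, 3, 1))
print(pass_at_k(200, 3, 96))
```

A model whose samples collapse onto one answer can keep a good pass@1 while its pass@96 falls, which is why the metric is a useful proxy for output diversity.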

26:42

So somehow, like, RL fine-tuning is able

26:44

to reinforce and preserve,

26:47

like, the type of diversity that

26:49

the model has, like, picked up on

26:51

and started to learn. What you're really

26:53

doing during RL fine-tuning is you're expanding

26:55

the, kind of, like, this database of training

26:57

data you have, and just, like, training the

26:59

model on more diverse data allows it to

27:01

maintain more diverse outputs, is what I think

27:03

is really going on there. That's

27:06

kind of like one nice byproduct of

27:08

the RL fine-tuning, is that you

27:10

maintain this level of diversity you wouldn't

27:12

otherwise get with supervised fine-tuning. There's this

27:14

interesting literature in classical RL, which

27:16

is auto-curricula. So, like, say

27:18

you have some agents on a bunch of different

27:20

environments, some of the environments will be harder than

27:22

other environments, and say you want to, like, in

27:25

a tree or the Me the an. Increase.

27:27

the difficulty over time? Yeah, yeah, that has

27:29

worked pretty well for classical RL. So

27:32

we tried a couple different algorithms

27:34

like that for an artist. Ah,

27:37

with the RL fine-tuning.

27:39

To be concrete, we

27:41

tried PLR, which is Prioritized

27:43

Level Replay, as one algorithm, and

27:45

we also tried this idea called backtracking,

27:48

where you, like, start the LLM

27:50

with, like, the answer to...

27:52

or rather, you start the LLM

27:54

at, like, a solution which

27:57

is almost done but not quite, and so the LLM

27:59

only has to fill in the last one or

28:01

two steps. And then when the

28:03

LLM demonstrates it can fill in that last one

28:05

or two steps, it backs itself up slightly. And

28:08

now it has to fill in the last three

28:10

steps, if you can do that, the last four,

28:12

et cetera. So you're backtracking to

28:15

earlier and earlier steps in the solution. You

28:17

can see that as being easier than solving

28:19

everything from scratch. So

28:21

we tried these kind of ideas. And

28:24

again, there were maybe

28:26

some small benefits, but

28:29

they weren't really

28:31

inducing the type

28:33

of generalization

28:35

improvement that we were interested in. The

28:38

model is kind of already going to, you're going

28:40

to sample it, and it's going to produce solutions

28:43

in some specific way. And all

28:45

of these algorithms that we tried are not

28:47

really doing a good job of significantly changing

28:49

the sampling distribution of the model. And

28:52

so that's why the performance is all roughly the same. Given

28:56

what you learned in this research, if

29:01

you were working on Llama 3,

29:03

for example, and wanted to ensure that

29:05

it had better reasoning

29:08

capabilities, is there an obvious

29:10

way that you would integrate

29:12

in the results of this research? Yeah,

29:14

so I think, at this point, my

29:16

opinion is exploration is the fundamental problem

29:18

here. I think it's clear that there

29:20

are a couple of things which clearly do work well

29:23

at improving the quality of model outputs.

29:26

So if you somehow have really high

29:28

quality data coming from another source, like

29:30

humans, for example, or GPT-4, that

29:34

clearly works very well. You can fine tune

29:36

a model on those, and the

29:39

type of solutions it will produce is

29:41

significantly changed by this. I'm

29:44

sure Meta is already doing this, right? Like, of course, they're getting

29:46

humans to annotate high quality data.

29:49

If you want to take the extra step, what

29:51

you really want to do is you want

29:53

to start looking into comparing

29:56

and benchmarking how different algorithms

29:58

impact the quality and

30:00

diversity of, like, synthetic data that

30:02

you're generating. Like, people right now

30:05

are very interested in synthetic data

30:07

generation, but there isn't that much

30:09

investigation yet into, you know, how

30:11

does the quality of the data

30:14

generated by one algorithm compare to

30:16

the quality generated by another algorithm?

30:18

How does this, like, you know, compare

30:21

as a function of fixed compute, fixed

30:23

number of samples? Also, how's the diversity,

30:25

like, impacted? Like, I think really

30:28

you want to have, like, a rigorous benchmark

30:30

which tells you, if I have these different

30:32

ways of generating synthetic data, how

30:34

do I get the best bang for my

30:36

buck when I am, like, comparing some, you

30:39

know, target level of quality and diversity in

30:41

the output? You also have

30:43

recently been involved in a

30:45

benchmarking effort, a dataset

30:48

called ARB,

30:50

Advanced Reasoning Benchmark.

30:52

What's the relationship or

30:54

contrast between that and GSM8K,

30:56

which you referred to? Yeah,

30:59

so ARB, ARB is just

31:02

impossibly difficult. Like,

31:04

that's how difficult. Certainly I would not be

31:06

able to solve most of the questions without a

31:08

reference solution. And

31:10

like, this is really the direction most benchmarks

31:12

are going nowadays, because LLMs are, yeah, really

31:14

just so capable.

31:17

You know, you really, like, so, like,

31:19

a popular benchmark right now is graduate,

31:21

uh, GPQA, graduate-something

31:23

questions, which is similar to ARB.

31:25

I think it's, like, a bit different in the sense

31:27

that one of the things we're really going for

31:29

in ARB is we want, like, really high-quality,

31:32

really difficult questions, but we also

31:34

don't just want questions

31:36

which have, like, a numerical

31:39

final answer which you can check that way. So,

31:41

like, we certainly have those in the benchmark,

31:43

so really difficult questions where you have

31:46

to get the correct numerical final answer,

31:48

but we also have some questions which

31:50

investigate like the ability of the model

31:52

to do symbolic reasoning. So, like, maybe

31:54

the final answer is some polynomial or something

31:57

which you need to get to.

31:59

Like, a lot of, you know, more

32:01

advanced math questions or physics questions are like this,

32:03

where, like, your final answer

32:05

is not numerical, it's symbolic. And

32:07

you know, this is obviously an important

32:10

type of question to be able to answer.

32:12

Also proofs, like, we have a couple of high-level

32:14

proof questions in there that the model

32:16

needs to be able to solve. And

32:18

like Reddick events you perform. GPT-4

32:20

is, like, pretty good at doing arithmetic,

32:23

right? Like, I think a lot of people

32:25

have remarked on this. If you ask it to,

32:27

like, multiply two five-digit numbers, it probably

32:29

will be able to, but then if

32:31

you are, like, asking GPT-4 to add,

32:33

like, two very simple numbers, like asking it to

32:36

add, you know, ten and eleven or something

32:38

in the middle of a very long computation,

32:40

like, you know, it's integrating stuff and it's,

32:42

like, summing some power series and, like, doing

32:45

lots of complicated stuff, GPT-4 will tend

32:47

to mess up very simple, like, arithmetic operations

32:49

which it was completely fine and, like, able

32:51

to do in, like, a clean context, which is

32:53

kind of funny, right? Like, this happens to

32:56

me all the time, like, I'm doing some

32:58

hard math problem and, like, you know, very

33:00

high-level abstract stuff. But then when

33:02

I actually have to go and, like, yeah, simplify,

33:05

I, like, suddenly, you know, like, freeze and I'm

33:07

like, oh wait, I'm making mistakes, I have to

33:09

go back and check my work. So somehow it's,

33:11

like, you know, funny that GPT-4

33:13

exhibits this, like, same

33:16

characteristic that humans

33:18

do when they start to learn higher-level

33:21

mathematics, this kind of thing.

33:23

I don't know why exactly that is.

33:25

It might,

33:27

you know, in some way reflect just

33:29

the difficulty of reasoning over long contexts. Actually,

33:31

I'm sure it does. I'm

33:33

sure that's part of it. But also

33:35

I wonder if it in some way

33:38

reflects the training data itself. Like, maybe

33:40

the training data you have for

33:42

these types of longer, more complex questions is

33:45

filled with, like, arithmetic errors, this kind of

33:47

thing, which you just don't detect because, like,

33:49

you know, when you're looking at a graduate-level

33:51

question, you don't actually care too much about

33:54

the arithmetic itself, right? Like, you're more

33:56

interested in the more complex, sophisticated reasoning. Attention

33:59

is obviously a super overloaded word,

34:01

but it strikes me that there's maybe

34:03

something in what it's paying attention to

34:05

as it's, what it is, how it's

34:08

conceptualized the goal from a

34:10

distribution perspective of the output as

34:13

it's generating the tokens and

34:16

it maybe gets distracted by the

34:18

bigger picture and misses the smaller things or

34:20

something like that? It's distracted

34:22

by this integral and suddenly it can't

34:24

add like six and five or something

34:26

like that. It's expected to be a

34:28

lot more consistent. If it knows how to add six and five, it's

34:30

going to know how to add six and five. Right. Yeah,

34:33

which is fascinating that it's not like that. How's

34:35

it changed the way you think broadly about reasoning,

34:39

machine reasoning, LLM reasoning

34:41

and the future of

34:44

reasoning? We didn't get to touch on, nobody's

34:46

the algorithmic chain of thought. I'm

34:49

teaching LLMs to do addition or

34:51

multiplication, very simple algorithmic tasks. How

34:54

does their performance depend on like if I

34:56

start injecting little tidbits of noise here and

34:58

there, so like maybe I flip this digit

35:00

from five to seven or like I delete

35:02

this line of the calculation. It sounds a

35:04

little bit like kind of

35:06

the approaches that are often taken for like

35:09

interpretability research like let's fiddle with the input

35:11

data, make it a little noisy and see

35:13

how that changes things in

35:15

the model. Yeah, specifically looking at how noise affects

35:17

training for chain of thought. In

35:20

preparation for this work, I spent a lot of time

35:22

thinking about like okay, why is chain of thought useful

35:24

to begin with, right? It's

35:27

also clearly very useful to humans. Humans

35:30

had spoken language for a

35:32

long time before writing things down

35:34

became useful. And suddenly I

35:36

think it's clear that once

35:38

humans started writing things down, this was a

35:40

significant shift in the trajectory of the human

35:43

race or whatever. I'm no authority on

35:45

history, so this is just my take. But

35:49

there's clearly like a huge power

35:51

to writing things down. And I think like

35:53

one way of interpreting it is it's

35:55

like a way of maintaining state. You

35:58

know, like if I am doing

36:00

this. I'm adding up two seven-digit

36:02

numbers. I have to do all of that in

36:04

my head, and I have to maintain the state in

36:06

individual neurons in my head. Whereas

36:12

if I'm writing things down and doing step-by-step chain

36:14

of thought reasoning, I have to offload most of

36:17

that state onto the sheet of paper. And

36:19

then I can just refer back to the sheet of paper for

36:21

certain subsets of information that I need. And

36:27

that's hugely useful. That just simplifies

36:30

significantly the type of computation that needs to

36:32

be done normally, certainly for humans,

36:34

and I think also for LLMs as well.

36:37

It literally makes the next token

36:39

prediction task much, much easier. Chain

36:41

of thought has proven to be

36:43

useful based

36:47

on chains

36:49

of thought that have naturally emerged in

36:52

the training data, but there's some way

36:54

to maybe take

36:56

advantage of or kind

36:58

of design for the way

37:01

LLMs might think about chain of

37:03

thought and train on more specific

37:07

chains or something. Getting

37:09

LLMs to do refinement

37:11

and this kind of thing, right? You

37:14

have some solution, and you try to fix the solution. And

37:18

often when humans are doing refinement, there's this whole

37:20

intermediate thought process, right? Like you have some sequence

37:22

of thoughts. You say, oh, there's an error here.

37:24

You try something, and that doesn't work. You try

37:26

another thing. But that's not

37:28

written down anywhere, right? So oftentimes, we get

37:30

to the correct final answer, but

37:33

we don't show our work. And

37:35

so a lot of that work doesn't show up

37:38

in the LLM training data, but I think putting

37:40

those intermediate steps in the LLM training data will

37:43

be super helpful. I

37:45

think that's something probably pretty much everyone agrees on

37:47

and most people are doing. Our

37:49

key questions are, A, how

37:52

do you characterize this type of noise in

37:54

chain of thought data? Like, are there different

37:56

types of noise which impacts the model differently

37:58

during training? B,

38:00

how can you quantitatively model the impact

38:02

of these different levels of noise on

38:05

trained model performance? If

38:07

you're training a

38:10

GPT-N, you might

38:12

be able to profile some new

38:14

dataset that you have access to, characterize this

38:17

noise, and then understand the impact that it's

38:19

going to have on the model that's produced.

38:21

Yeah, that's exactly the idea. I

38:24

think it's important to be able to identify these factors

38:26

and say, like, yeah. So now

38:29

to explain what we found, so we

38:31

trained a small model. This was a

38:33

relatively small scale study. We trained a

38:35

small model, Pythia, 410 million

38:38

parameters, on just

38:41

sequences of text produced by

38:44

simple functions on integers. So we did

38:47

arithmetic. We did, if you have a list

38:49

and you need to find the median in the list, we did

38:51

that. We

38:53

did sorting. We

38:55

did a bunch of GCD, greatest common divisor.

38:57

We did a bunch of different algorithms, which

39:00

we produced chains of thought for. We

39:02

called them algorithmic chains of thought because they're algorithms.

39:05

And then we trained models on them.
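
To make "algorithmic chain of thought" concrete, here is a minimal sketch of a program that writes out the step-by-step trace of grade-school addition. The trace format and the name `addition_chain_of_thought` are assumptions made for illustration; the exact format used in the study is not given in this conversation.

```python
def addition_chain_of_thought(x, y):
    """Return the lines of a digit-by-digit addition trace for x + y."""
    if x < 0 or y < 0:
        raise ValueError("this sketch handles non-negative integers only")
    xs, ys = str(x)[::-1], str(y)[::-1]  # least-significant digit first
    lines = [f"{x} + {y}"]
    carry = 0
    digits = []
    for i in range(max(len(xs), len(ys))):
        a = int(xs[i]) if i < len(xs) else 0
        b = int(ys[i]) if i < len(ys) else 0
        total = a + b + carry
        lines.append(f"{a} + {b} + carry {carry} = {total}, "
                     f"write {total % 10}, carry {total // 10}")
        digits.append(str(total % 10))
        carry = total // 10
    if carry:
        digits.append(str(carry))  # leftover carry becomes the top digit
    lines.append("answer: " + "".join(reversed(digits)))
    return lines

for line in addition_chain_of_thought(478, 256):
    print(line)
```

Each line only depends on two digits and the previous carry, which is the sense in which writing the steps out makes every next-token prediction simpler than jumping straight from the question to the answer.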

39:09

And first of all, chain of

39:11

thought makes the task much easier.

39:13

If I am training even a

39:15

small GPT model to do arithmetic, it's much

39:17

easier if you write out the chain of

39:19

thought for the model versus training

39:21

from scratch, just going from X plus Y

39:24

to Z. That's much easier. So

39:27

then once we confirmed that, we

39:29

characterized different types of noise. So

39:32

we came up with two different types mainly.

39:34

So there's static noise, which is kind

39:37

of noise, which maybe makes local

39:39

changes to each chain

39:41

of thought. So maybe you

39:43

flip a digit or maybe

39:45

you delete a line, but that doesn't change

39:47

anything earlier or later in

39:49

the chain of thought. It's just that specific line. But

39:52

then we also have a kind of more

39:54

pernicious, more dangerous type of noise called dynamic

39:56

noise. And what dynamic noise does

39:59

is, when you're actually executing, like,

40:01

when you're actually getting the chain of thought

40:03

which you then feed into the LLM, you,

40:06

like, inject some

40:08

noise there, in the generation process. So,

40:10

like, maybe at a step where

40:12

I'm adding two numbers, I flip one

40:15

of the digits that I'm adding,

40:17

producing the wrong answer and then that

40:19

you know, impacts all your calculations downstream.
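
The difference between the two noise types can be sketched on a toy trace. The example below is an illustration under stated assumptions, not the paper's code: the task (a running sum over a list), the trace format, and the helpers `running_sum_trace` and `add_static_noise` are all hypothetical. Static noise edits one line of a finished trace, while dynamic noise corrupts a value during generation, so every later line inherits the error.

```python
def running_sum_trace(numbers, corrupt_step=None, error=1):
    """Trace a running sum; optionally inject dynamic noise at one step.

    If corrupt_step is set, the running total at that step is off by
    `error`, and all later steps build on the wrong total.
    """
    total = 0
    lines = []
    for i, n in enumerate(numbers):
        new_total = total + n
        if i == corrupt_step:
            new_total += error  # the mistake happens during generation
        lines.append(f"{total} + {n} = {new_total}")
        total = new_total
    return lines

def add_static_noise(lines, step, wrong_result):
    """Overwrite the result on one finished line; other lines are untouched."""
    noisy = list(lines)
    left, _ = noisy[step].rsplit(" = ", 1)
    noisy[step] = f"{left} = {wrong_result}"
    return noisy

clean = running_sum_trace([3, 4, 5, 6])
static = add_static_noise(clean, step=1, wrong_result=9)
dynamic = running_sum_trace([3, 4, 5, 6], corrupt_step=1, error=2)

for name, trace in (("clean", clean), ("static", static), ("dynamic", dynamic)):
    print(name, trace)
```

In the static trace only the second line differs from the clean one and the final answer is still 18. In the dynamic trace the same wrong value appears on the second line, but the third and fourth lines change as well and the final answer becomes 20.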

40:21

So suddenly dynamic noise has a

40:23

much larger, less localized impact on the

40:26

entire solution. Is the noise distinction basically

40:29

kind of where in the

40:31

chain of thought it is? Like,

40:33

static is towards the end and dynamic is

40:35

towards the beginning? So let's say I ask

40:38

you to solve the problem for me, like,

40:40

generate some training data. I take

40:42

that, and maybe, like, somehow there's

40:44

some noise in that. So, like,

40:46

you know, when you're communicating

40:49

it to me, I, like, maybe

40:51

accidentally flip a digit, or,

40:53

like, I accidentally miss a line. So, like,

40:55

I'm not changing any of the global

40:57

structure of your solution. I'm just, like,

40:59

changing very, like, local things. That's static

41:01

noise, whereas with dynamic noise, you make

41:04

a mistake when you're generating the solution.

41:06

And that impacts the rest of the answer.

41:08

So, like, now the mistake that you've made has,

41:10

like, you know, messed up step two,

41:12

step two messes up step three, all the way

41:15

down. That one is dynamic noise. Is that really noise

41:17

in the training data? It could be, yeah. Like,

41:19

it could be that, you know,

41:21

GPT-2 tries to solve a problem and

41:23

it starts from a pretty good place, but then it

41:25

totally veers off. Like, it actually

41:27

just, like, does something totally unreasonable. That would

41:29

be a form of dynamic noise. Whereas, like,

41:31

a practical form of static noise would be

41:33

that maybe GPT-2 knows, or GPT-4

41:36

knows, like, the high-level stuff for how

41:38

to get to a solution, so it does roughly

41:40

the right things, but when you look at

41:42

a particular step, it makes mistakes which don't

41:44

really matter, but, like, technically the thing

41:46

is wrong. That would be a form of static

41:48

noise. One of the main findings of

41:50

this paper was, when you have

41:52

the static noise, like, let's

41:54

say you have some dataset which has some level

41:56

of static noise. Let's say that

41:59

in every single training sample

42:01

you have some noise. Okay, so, like, we,

42:04

we investigated different regimes, like, in some

42:06

regimes only 30% of the training

42:08

dataset had noise. But in this regime,

42:10

I'm telling you, every single sample is infected.

42:12

So that's already pretty bad. Every

42:14

single sample has noise, and now I'm

42:16

also telling you. Okay, consider

42:19

the case where 70% of

42:21

all your digits in your

42:23

dataset are, like, wrong, or flipped or something.

42:25

Like, you've corrupted 70% of the digits

42:27

in your training dataset. So now, in

42:30

every single sample, 70%

42:32

of the numbers are just wrong. So

42:35

now fine-tune your model on

42:37

that. How do you expect it to do? I would think

42:39

that it would do okay. You think

42:41

so? Okay, well, you're right. So if

42:43

you add noise to every single

42:46

training sample and you, like, corrupt 70%

42:48

of the digits, it's

42:50

fine. It gets 100% accuracy,

42:53

which to me was pretty shocking, that you, like,

42:55

can add that much noise and, like,

42:57

have it be fine. The situation breaks down,

42:59

though, if you add noise

43:01

to 90% of the digits in every

43:03

single training sample. Then you start to

43:05

get in trouble and like your model just can't learn

43:07

anything. But I still

43:09

think it's, like, pretty remarkable that

43:12

transformer models are so robust that you

43:14

can inject that much noise and still

43:16

get 100% performance. It reminds

43:19

me of, like, you know, the Facebook posts

43:21

where, like, you've got some text but, like,

43:23

the, you know, there are numbers instead of

43:25

letters and things are backwards and stuff like

43:28

that, and, like, the brain figures that out,

43:30

I guess. So yeah, so, like, there's some,

43:32

there's some structural thing which it's still able

43:35

to learn and generalize. So I

43:37

think that was kind of the most, how

43:39

do I want to put it, sensational

43:42

finding, in my mind.

43:44

I am surprised though that there's

43:47

there's some threshold or something after which

43:49

it kind of falls apart.

43:51

In fact, I could have convinced

43:53

myself that, you

43:56

know, more errors are better

43:58

because it focuses the model

44:00

on the algorithm and not

44:03

the role of the numbers.

44:07

For example, an extreme case of that would be,

44:10

if you blanked out all of the individual numbers,

44:13

you're really focusing on the algorithm and

44:15

not the numbers. I agree. You

44:18

would think that some level of noise induces

44:20

more generalization. I think that's true if the

44:22

level of noise is low enough. We

44:26

found that this is also roughly the case if

44:28

you delete lines. You

44:30

can delete something like

44:32

30% of all the lines in each solution. The

44:36

model still is fine. But if you delete more than 30%,

44:38

then accuracy starts to go down. Dynamic

44:42

noise was much more destructive, as you can probably guess. If

44:46

you have even 10% of

44:48

all the samples infected with dynamic

44:50

noise, then you

44:52

can't really learn anything. You're in real trouble.
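
The static noise levels quoted in this part of the conversation (a share of digits corrupted, a share of lines deleted) can be pictured as per-sample corruption rates. The sketch below is one possible reading, not the study's implementation: it assumes each digit or line is corrupted independently with the given probability, and `corrupt_digits` and `delete_lines` are hypothetical helper names.

```python
import random

DIGITS = "0123456789"

def corrupt_digits(text, rate, rng):
    """Swap each ASCII digit for a different digit with probability `rate`."""
    out = []
    for ch in text:
        if ch in DIGITS and rng.random() < rate:
            out.append(rng.choice([d for d in DIGITS if d != ch]))
        else:
            out.append(ch)
    return "".join(out)

def delete_lines(lines, rate, rng):
    """Drop each line independently with probability `rate`."""
    return [line for line in lines if rng.random() >= rate]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
trace = ["0 + 3 = 3", "3 + 4 = 7", "7 + 5 = 12", "12 + 6 = 18"]
print([corrupt_digits(line, 0.7, rng) for line in trace])
print(delete_lines(trace, 0.3, rng))
```

Both helpers leave the layout of the trace intact (operators, line order, number of fields), which may be part of why this kind of noise is survivable: the structure of the algorithm is still visible in every sample.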

44:55

What I'm most optimistic for is – I think

44:58

lots of other people have more or less said this,

45:00

but if you take some LLM

45:03

and then you combine it with some exterior system

45:06

that allows it to do more

45:08

complicated planning or, in my case,

45:10

allows it to explore and generate

45:12

much more diverse types of solutions,

45:15

I think that's the most interesting thing. I'm skeptical that

45:17

you're going to be able to get everything you want

45:19

just by using the LLM itself. But

45:22

when you combine it with these other types of systems,

45:25

which somehow broaden the type of

45:27

solutions it's able to discover, I think that's where

45:29

you're really going to get superhuman reasoning

45:31

performance that people are looking for. Well,

45:34

Alex, thanks so much for taking the time

45:36

to share a bit about what you're working

45:38

on. Very interesting stuff. Of

45:40

course. Thank

45:52

you.
