Commit b1482d9
breaking(router): modify /generate API to only return generated text (#50)
@njhill, @yk FYI: generated_text was concatenated to the user prompt for legacy reasons. We want to remove this behaviour, as we don't think it is useful and it is even detrimental to usability. We also remove the unused Vec wrapping the response.
1 parent 7b870e1 commit b1482d9
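
For API consumers, the break is twofold: the /generate response body is no longer a one-element array, and generated_text no longer repeats the prompt as a prefix. Below is a minimal sketch of the old and new shapes, using only the Python standard library; the payloads are illustrative, modeled on the bloom_560m.json fixture further down.

import json

# Before this commit: a one-element array, with the prompt ("Test request")
# echoed at the start of generated_text.
old_body = '[{"generated_text": "Test request.get(\\"action\\");"}]'
old_text = json.loads(old_body)[0]["generated_text"]

# After this commit: a single object whose generated_text holds only the
# newly generated tokens.
new_body = '{"generated_text": ".get(\\"action\\");"}'
new_text = json.loads(new_body)["generated_text"]

print(old_text)  # Test request.get("action");
print(new_text)  # .get("action");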

9 files changed: 23 additions, 24 deletions

launcher/tests/bloom_560m.json
Lines changed: 1 addition & 1 deletion

@@ -118,6 +118,6 @@
       ]
     ]
   },
-  "generated_text": "Test request.get(\"action\");\n    if (action == null) {\n        throw new RuntimeException"
+  "generated_text": ".get(\"action\");\n    if (action == null) {\n        throw new RuntimeException"
 }
]

launcher/tests/integration_tests.rs
Lines changed: 2 additions & 2 deletions

@@ -97,8 +97,8 @@ fn test_model(
     launcher.terminate().unwrap();
     launcher.wait().unwrap();

-    let mut results: Vec<GeneratedText> = res.unwrap().json().unwrap();
-    results.pop().unwrap()
+    let result: GeneratedText = res.unwrap().json().unwrap();
+    result
 }

 fn read_json(name: &str) -> GeneratedText {

router/src/lib.rs
Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
-mod infer;
 /// Text Generation Inference Webserver
+
+mod infer;
 mod queue;
 pub mod server;
 mod validation;

router/src/server.rs
Lines changed: 2 additions & 2 deletions

@@ -125,10 +125,10 @@ async fn generate(
     tracing::info!("Output: {}", response.generated_text.text);

     // Send response
-    let response = vec![GenerateResponse {
+    let response = GenerateResponse {
         generated_text: response.generated_text.text,
         details,
-    }];
+    };
     Ok((headers, Json(response)))
 }
server/tests/models/test_bloom.py
Lines changed: 6 additions & 6 deletions

@@ -141,7 +141,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -165,7 +165,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[1].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[1].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -188,7 +188,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
@@ -261,7 +261,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "TestTestTestTestTestTest"
+    assert generations[2].generated_text.text == "TestTestTestTestTest"
     assert (
         generations[2].request_id == default_multi_requests_bloom_batch.requests[1].id
     )
@@ -284,7 +284,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
@@ -307,7 +307,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTestTest"
+        == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id

server/tests/models/test_causal_lm.py
Lines changed: 6 additions & 6 deletions

@@ -138,7 +138,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -161,7 +161,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[1].generated_text.text == "Test.java:784)"
+    assert generations[1].generated_text.text == ".java:784)"
     assert (
         generations[1].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -183,7 +183,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id
@@ -255,7 +255,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 3
-    assert generations[2].generated_text.text == "Test.java:784)"
+    assert generations[2].generated_text.text == ".java:784)"
     assert (
         generations[2].request_id
         == default_multi_requests_causal_lm_batch.requests[1].id
@@ -277,7 +277,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generations) == 2
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert generations[0].request_id == default_causal_lm_batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -297,7 +297,7 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "Test.java:784) at net.minecraft."
+    assert generations[0].generated_text.text == ".java:784) at net.minecraft."
     assert (
         generations[0].request_id
         == default_multi_requests_causal_lm_batch.requests[0].id

server/tests/models/test_santacoder.py
Lines changed: 2 additions & 2 deletions

@@ -57,7 +57,7 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch)
     assert next_batch is None

     assert len(generations) == 1
-    assert generations[0].generated_text.text == "def test_get_all_users_with_"
+    assert generations[0].generated_text.text == " test_get_all_users_with_"
     assert generations[0].request_id == batch.requests[0].id
     assert (
         generations[0].generated_text.generated_tokens
@@ -84,7 +84,7 @@ def test_fim_santacoder_generate_token_completion(
     assert len(generations) == 1
     assert (
         generations[0].generated_text.text
-        == """<fim-prefix>def<fim-suffix>world<fim-middle>ineProperty(exports, "__esModule", { value"""
+        == """ineProperty(exports, "__esModule", { value"""
     )
     assert generations[0].request_id == batch.requests[0].id
     assert (

server/text_generation/models/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@
 def get_model(
     model_name: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name, revision=revision)

     if config.model_type == "bloom":
         if sharded:
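
An incidental fix rides along here: get_model previously resolved the config from the hub's default branch even when a specific revision was requested. A minimal sketch of the corrected call, assuming the transformers AutoConfig API; the model name and revision are illustrative.

from transformers import AutoConfig

# Forwarding `revision` keeps the config in sync with the weights when a
# non-default branch or commit is requested.
config = AutoConfig.from_pretrained("bigscience/bloom-560m", revision="main")
print(config.model_type)  # "bloom"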

server/text_generation/models/causal_lm.py
Lines changed: 1 addition & 3 deletions

@@ -360,11 +360,9 @@ def generate_token(

         if stop:
             # Decode generated tokens
-            generated_text = self.decode(
+            output_text = self.decode(
                 all_input_ids[-stopping_criteria.current_tokens :, 0]
             )
-            output_text = request.inputs + generated_text
-
             # Get seed
             if isinstance(next_token_chooser.choice, Sampling):
                 seed = next_token_chooser.choice.seed
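
This deletion is the root of every expected-string change in the test files above: the decoded tokens are now returned as-is instead of being prepended with request.inputs. A toy illustration with hypothetical values, where the strings stand in for the real prompt and decoded output.

# Hypothetical values for illustration only.
prompt = "Test request"           # request.inputs
generated = '.get("action");'     # self.decode(...) of the new tokens

old_output = prompt + generated   # before: 'Test request.get("action");'
new_output = generated            # after:  '.get("action");'

print(old_output)
print(new_output)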
