5
5
from text_generation .types import FinishReason , PrefillToken , Token
6
6
7
7
8
def test_generate(flan_t5_xxl_url, hf_headers):
    """Check a one-token synchronous generation against the flan-t5-xxl endpoint.

    Both arguments are pytest fixtures and are passed straight to ``Client``.
    """
    client = Client(flan_t5_xxl_url, hf_headers)
    response = client.generate("test", max_new_tokens=1)

    # The single generated token has text " ", yet the generated text is
    # expected to be empty.
    assert response.generated_text == ""

    details = response.details
    # With max_new_tokens=1 the run is expected to stop on the length limit.
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    # No sampling was requested, so no seed is expected.
    assert details.seed is None

    # Exactly one prefill entry is expected: the "<pad>" token with id 0.
    assert len(details.prefill) == 1
    assert details.prefill[0] == PrefillToken(
        id=0, text="<pad>", logprob=None
    )

    # NOTE: the logprob is compared by exact equality.
    assert len(details.tokens) == 1
    assert details.tokens[0] == Token(
        id=3, text=" ", logprob=-1.984375, special=False
    )
24
22
25
23
24
def test_generate_best_of(flan_t5_xxl_url, hf_headers):
    """Check that a sampled ``best_of=2`` request reports seeds and extra sequences.

    Both arguments are pytest fixtures and are passed straight to ``Client``.
    """
    client = Client(flan_t5_xxl_url, hf_headers)
    response = client.generate(
        "test", max_new_tokens=1, best_of=2, do_sample=True
    )

    details = response.details
    # Sampling is enabled, so a seed is expected on the response.
    assert details.seed is not None

    # For best_of=2 the test expects exactly one entry here.
    # NOTE(review): presumably the winning sequence is the response itself and
    # only the remaining candidates are listed -- confirm against the server.
    assert details.best_of_sequences is not None
    assert len(details.best_of_sequences) == 1
    assert details.best_of_sequences[0].seed is not None
32
+
33
+
26
34
def test_generate_not_found (fake_url , hf_headers ):
27
35
client = Client (fake_url , hf_headers )
28
36
with pytest .raises (NotFoundError ):
@@ -35,16 +43,16 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
35
43
client .generate ("test" , max_new_tokens = 10_000 )
36
44
37
45
38
- def test_generate_stream (bloom_url , hf_headers ):
39
- client = Client (bloom_url , hf_headers )
46
+ def test_generate_stream (flan_t5_xxl_url , hf_headers ):
47
+ client = Client (flan_t5_xxl_url , hf_headers )
40
48
responses = [
41
49
response for response in client .generate_stream ("test" , max_new_tokens = 1 )
42
50
]
43
51
44
52
assert len (responses ) == 1
45
53
response = responses [0 ]
46
54
47
- assert response .generated_text == ". "
55
+ assert response .generated_text == ""
48
56
assert response .details .finish_reason == FinishReason .Length
49
57
assert response .details .generated_tokens == 1
50
58
assert response .details .seed is None
@@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):
63
71
64
72
65
73
@pytest.mark.asyncio
async def test_generate_async(flan_t5_xxl_url, hf_headers):
    """Check a one-token generation through ``AsyncClient`` on flan-t5-xxl.

    Makes the same assertions as the synchronous ``test_generate``, but awaits
    the request. Both arguments are pytest fixtures and are passed straight to
    ``AsyncClient``.
    """
    client = AsyncClient(flan_t5_xxl_url, hf_headers)
    response = await client.generate("test", max_new_tokens=1)

    # The single generated token has text " ", yet the generated text is
    # expected to be empty.
    assert response.generated_text == ""

    details = response.details
    # With max_new_tokens=1 the run is expected to stop on the length limit.
    assert details.finish_reason == FinishReason.Length
    assert details.generated_tokens == 1
    # No sampling was requested, so no seed is expected.
    assert details.seed is None

    # Exactly one prefill entry is expected: the "<pad>" token with id 0.
    assert len(details.prefill) == 1
    assert details.prefill[0] == PrefillToken(
        id=0, text="<pad>", logprob=None
    )

    # NOTE: the logprob is compared by exact equality.
    assert len(details.tokens) == 1
    assert details.tokens[0] == Token(
        id=3, text=" ", logprob=-1.984375, special=False
    )
82
88
83
89
@@ -96,16 +102,16 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):
96
102
97
103
98
104
@pytest .mark .asyncio
99
- async def test_generate_stream_async (bloom_url , hf_headers ):
100
- client = AsyncClient (bloom_url , hf_headers )
105
+ async def test_generate_stream_async (flan_t5_xxl_url , hf_headers ):
106
+ client = AsyncClient (flan_t5_xxl_url , hf_headers )
101
107
responses = [
102
108
response async for response in client .generate_stream ("test" , max_new_tokens = 1 )
103
109
]
104
110
105
111
assert len (responses ) == 1
106
112
response = responses [0 ]
107
113
108
- assert response .generated_text == ". "
114
+ assert response .generated_text == ""
109
115
assert response .details .finish_reason == FinishReason .Length
110
116
assert response .details .generated_tokens == 1
111
117
assert response .details .seed is None
0 commit comments