Commit d8dc8f1

feat(python-client): add new parameters (#118)
1 parent 55bd4fe commit d8dc8f1

9 files changed: 278 additions, 40 deletions

clients/python/README.md
Lines changed: 18 additions & 0 deletions

@@ -133,6 +133,22 @@ class FinishReason(Enum):
     StopSequence = "stop_sequence"


+# Additional sequences when using the `best_of` parameter
+class BestOfSequence:
+    # Generated text
+    generated_text: str
+    # Generation finish reason
+    finish_reason: FinishReason
+    # Number of generated tokens
+    generated_tokens: int
+    # Sampling seed if sampling was activated
+    seed: Optional[int]
+    # Prompt tokens
+    prefill: List[PrefillToken]
+    # Generated tokens
+    tokens: List[Token]
+
+
 # `generate` details
 class Details:
     # Generation finish reason
@@ -145,6 +161,8 @@ class Details:
     prefill: List[PrefillToken]
     # Generated tokens
     tokens: List[Token]
+    # Additional sequences when using the `best_of` parameter
+    best_of_sequences: Optional[List[BestOfSequence]]


 # `generate` return value
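The README additions above document the new `best_of` behavior: when `best_of` is set, the extra candidate sequences come back on `Details.best_of_sequences` as `BestOfSequence` objects. A minimal usage sketch against these types, assuming a running text-generation-inference server (the endpoint URL and prompt are placeholder assumptions, not part of this commit):

    from text_generation import Client

    # Placeholder URL; any reachable text-generation-inference endpoint works.
    client = Client("http://127.0.0.1:8080")

    # best_of generates several candidates server-side and returns the best one;
    # per the tests in this commit, best_of > 1 requires do_sample=True.
    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)

    print(response.generated_text)
    # The remaining candidates are exposed as BestOfSequence entries.
    for sequence in response.details.best_of_sequences:
        print(sequence.generated_text, sequence.seed)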

clients/python/pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "text-generation"
-version = "0.2.1"
+version = "0.3.0"
 description = "Hugging Face Text Generation Python Client"
 license = "Apache-2.0"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

clients/python/tests/conftest.py
Lines changed: 0 additions & 5 deletions

@@ -4,11 +4,6 @@
 from huggingface_hub.utils import build_hf_headers


-@pytest.fixture
-def bloom_model():
-    return "bigscience/bloom"
-
-
 @pytest.fixture
 def flan_t5_xxl():
     return "google/flan-t5-xxl"

clients/python/tests/test_client.py
Lines changed: 26 additions & 20 deletions

@@ -5,24 +5,32 @@
 from text_generation.types import FinishReason, PrefillToken, Token


-def test_generate(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )


+def test_generate_best_of(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", max_new_tokens=1, best_of=2, do_sample=True)
+
+    assert response.details.seed is not None
+    assert response.details.best_of_sequences is not None
+    assert len(response.details.best_of_sequences) == 1
+    assert response.details.best_of_sequences[0].seed is not None
+
+
 def test_generate_not_found(fake_url, hf_headers):
     client = Client(fake_url, hf_headers)
     with pytest.raises(NotFoundError):
@@ -35,16 +43,16 @@ def test_generate_validation_error(flan_t5_xxl_url, hf_headers):
         client.generate("test", max_new_tokens=10_000)


-def test_generate_stream(bloom_url, hf_headers):
-    client = Client(bloom_url, hf_headers)
+def test_generate_stream(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
     responses = [
         response for response in client.generate_stream("test", max_new_tokens=1)
     ]

     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
@@ -63,21 +71,19 @@ def test_generate_stream_validation_error(flan_t5_xxl_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     response = await client.generate("test", max_new_tokens=1)

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None
     assert len(response.details.prefill) == 1
-    assert response.details.prefill[0] == PrefillToken(
-        id=9234, text="test", logprob=None
-    )
+    assert response.details.prefill[0] == PrefillToken(id=0, text="<pad>", logprob=None)
     assert len(response.details.tokens) == 1
     assert response.details.tokens[0] == Token(
-        id=17, text=".", logprob=-1.75, special=False
+        id=3, text=" ", logprob=-1.984375, special=False
     )


@@ -96,16 +102,16 @@ async def test_generate_async_validation_error(flan_t5_xxl_url, hf_headers):


 @pytest.mark.asyncio
-async def test_generate_stream_async(bloom_url, hf_headers):
-    client = AsyncClient(bloom_url, hf_headers)
+async def test_generate_stream_async(flan_t5_xxl_url, hf_headers):
+    client = AsyncClient(flan_t5_xxl_url, hf_headers)
     responses = [
         response async for response in client.generate_stream("test", max_new_tokens=1)
     ]

     assert len(responses) == 1
     response = responses[0]

-    assert response.generated_text == "."
+    assert response.generated_text == ""
     assert response.details.finish_reason == FinishReason.Length
     assert response.details.generated_tokens == 1
     assert response.details.seed is None

clients/python/tests/test_inference_api.py
Lines changed: 4 additions & 4 deletions

@@ -14,8 +14,8 @@ def test_get_supported_models():
     assert isinstance(get_supported_models(), list)


-def test_client(bloom_model):
-    client = InferenceAPIClient(bloom_model)
+def test_client(flan_t5_xxl):
+    client = InferenceAPIClient(flan_t5_xxl)
     assert isinstance(client, Client)


@@ -24,8 +24,8 @@ def test_client_unsupported_model(unsupported_model):
         InferenceAPIClient(unsupported_model)


-def test_async_client(bloom_model):
-    client = InferenceAPIAsyncClient(bloom_model)
+def test_async_client(flan_t5_xxl):
+    client = InferenceAPIAsyncClient(flan_t5_xxl)
     assert isinstance(client, AsyncClient)


clients/python/tests/test_types.py
Lines changed: 45 additions & 2 deletions

@@ -1,10 +1,20 @@
 import pytest

-from text_generation.types import Parameters
+from text_generation.types import Parameters, Request
 from text_generation.errors import ValidationError


 def test_parameters_validation():
+    # Test best_of
+    Parameters(best_of=1)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=0)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=-1)
+    Parameters(best_of=2, do_sample=True)
+    with pytest.raises(ValidationError):
+        Parameters(best_of=2)
+
     # Test repetition_penalty
     Parameters(repetition_penalty=1)
     with pytest.raises(ValidationError):
@@ -32,8 +42,41 @@ def test_parameters_validation():
         Parameters(top_k=-1)

     # Test top_p
-    Parameters(top_p=1)
+    Parameters(top_p=0.5)
     with pytest.raises(ValidationError):
         Parameters(top_p=0)
     with pytest.raises(ValidationError):
         Parameters(top_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(top_p=1)
+
+    # Test truncate
+    Parameters(truncate=1)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=0)
+    with pytest.raises(ValidationError):
+        Parameters(truncate=-1)
+
+    # Test typical_p
+    Parameters(typical_p=0.5)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=0)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=-1)
+    with pytest.raises(ValidationError):
+        Parameters(typical_p=1)
+
+
+def test_request_validation():
+    Request(inputs="test")
+
+    with pytest.raises(ValidationError):
+        Request(inputs="")
+
+    Request(inputs="test", stream=True)
+    Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))
+
+    with pytest.raises(ValidationError):
+        Request(
+            inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
+        )
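Taken together, these new validation tests pin down the constraints on the added parameters: `best_of` must be at least 1 and, when greater than 1, requires `do_sample=True` and cannot be combined with streaming; `truncate` must be positive; and `top_p` and `typical_p` must lie strictly between 0 and 1. A small illustrative sketch of those rules (the concrete values are arbitrary examples, not defaults):

    from text_generation.types import Parameters, Request

    # Valid: sampling is enabled, so best_of > 1 is accepted.
    params = Parameters(best_of=2, do_sample=True, truncate=128, top_p=0.9, typical_p=0.5)
    Request(inputs="test", parameters=params)

    # Each of these violates a constraint exercised above and raises
    # text_generation.errors.ValidationError:
    #   Parameters(best_of=2)          # best_of > 1 without do_sample
    #   Parameters(top_p=1)            # top_p must be strictly below 1
    #   Parameters(truncate=0)         # truncate must be positive
    #   Request(inputs="test", parameters=params, stream=True)  # best_of with streaming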

clients/python/text_generation/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.2.1"
+__version__ = "0.3.0"

 from text_generation.client import Client, AsyncClient
 from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
