@@ -14,15 +14,21 @@
import pytest
from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError
from docker.errors import NotFound
-from loguru import logger
-from test_model import TEST_CONFIGS
-from text_generation import AsyncClient
-from text_generation.types import Response
+import logging
+from huggingface_hub import AsyncInferenceClient, TextGenerationOutput
+import huggingface_hub
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s | %(levelname)-8s | %(name)s:%(funcName)s:%(lineno)d - %(message)s",
+    stream=sys.stdout,
+)
+logger = logging.getLogger(__file__)


# Use the latest image from the local docker build
DOCKER_IMAGE = os.getenv("DOCKER_IMAGE", "tgi-gaudi")
DOCKER_VOLUME = os.getenv("DOCKER_VOLUME", None)
-HF_TOKEN = os.getenv("HF_TOKEN", None)
+HF_TOKEN = huggingface_hub.get_token()

assert (
    HF_TOKEN is not None
@@ -48,12 +54,6 @@
    "cap_add": ["sys_nice"],
}

-logger.add(
-    sys.stderr,
-    format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
-    level="INFO",
-)
-

def stream_container_logs(container, test_name):
    """Stream container logs in a separate thread."""
@@ -69,9 +69,15 @@ def stream_container_logs(container, test_name):
        logger.error(f"Error streaming container logs: {str(e)}")


+class TestClient(AsyncInferenceClient):
+    def __init__(self, service_name: str, base_url: str):
+        super().__init__(model=base_url)
+        self.service_name = service_name
+
+
class LauncherHandle:
-    def __init__(self, port: int):
-        self.client = AsyncClient(f"http://localhost:{port}", timeout=3600)
+    def __init__(self, service_name: str, port: int):
+        self.client = TestClient(service_name, f"http://localhost:{port}")

    def _inner_health(self):
        raise NotImplementedError
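For orientation, the new client path can be exercised directly against a running endpoint, roughly as below; the URL and prompt are illustrative, not taken from the tests.

```python
import asyncio
from huggingface_hub import AsyncInferenceClient

async def main():
    # Illustrative endpoint: a locally launched TGI server.
    client = AsyncInferenceClient(model="http://localhost:8080")
    # With details=True the call returns a TextGenerationOutput carrying
    # token-level metadata, not just the generated string.
    out = await client.text_generation(
        "What is Deep Learning?", max_new_tokens=16, details=True
    )
    print(out.generated_text)

asyncio.run(main())
```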
@@ -87,7 +93,7 @@ async def health(self, timeout: int = 60):
            raise RuntimeError("Launcher crashed")

        try:
-            await self.client.generate("test")
+            await self.client.text_generation("test", max_new_tokens=1)
            elapsed = time.time() - start_time
            logger.info(f"Health check passed after {elapsed:.1f}s")
            return
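The hunk above swaps the readiness probe to a cheap one-token generation. A minimal sketch of that wait-for-ready pattern, with the loop structure assumed from the visible imports and log lines rather than copied from the file:

```python
import asyncio
import time

from aiohttp import ClientConnectorError, ClientOSError, ServerDisconnectedError

async def wait_until_ready(client, timeout: int = 60):
    # Poll with a one-token generation until the server answers or the
    # timeout elapses; connection errors mean "not up yet".
    start = time.time()
    while time.time() - start < timeout:
        try:
            await client.text_generation("test", max_new_tokens=1)
            return
        except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
            await asyncio.sleep(1)
    raise RuntimeError(f"Health check failed after {timeout}s")
```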
@@ -111,7 +117,8 @@ async def health(self, timeout: int = 60):

class ContainerLauncherHandle(LauncherHandle):
    def __init__(self, docker_client, container_name, port: int):
-        super(ContainerLauncherHandle, self).__init__(port)
+        service_name = container_name  # Use container name as service name
+        super(ContainerLauncherHandle, self).__init__(service_name, port)
        self.docker_client = docker_client
        self.container_name = container_name

@@ -132,7 +139,8 @@ def _inner_health(self) -> bool:

class ProcessLauncherHandle(LauncherHandle):
    def __init__(self, process, port: int):
-        super(ProcessLauncherHandle, self).__init__(port)
+        service_name = "process"  # Use generic name for process launcher
+        super(ProcessLauncherHandle, self).__init__(service_name, port)
        self.process = process

    def _inner_health(self) -> bool:
@@ -151,11 +159,13 @@ def data_volume():


@pytest.fixture(scope="module")
-def launcher(data_volume):
+def gaudi_launcher():
    @contextlib.contextmanager
    def docker_launcher(
        model_id: str,
        test_name: str,
+        tgi_args: List[str] = None,
+        env_config: dict = None,
    ):
        logger.info(
            f"Starting docker launcher for model {model_id} and test {test_name}"
@@ -183,32 +193,40 @@ def get_free_port():
            )
            container.stop()
            container.wait()
+            container.remove()
+            logger.info(f"Removed existing container {container_name}")
        except NotFound:
            pass
        except Exception as e:
            logger.error(f"Error handling existing container: {str(e)}")

-        model_name = next(
-            name for name, cfg in TEST_CONFIGS.items() if cfg["model_id"] == model_id
-        )
-
-        tgi_args = TEST_CONFIGS[model_name]["args"].copy()
+        if tgi_args is None:
+            tgi_args = []
+        else:
+            tgi_args = tgi_args.copy()

        env = BASE_ENV.copy()

        # Add model_id to env
        env["MODEL_ID"] = model_id

-        # Add env config that is definied in the fixture parameter
-        if "env_config" in TEST_CONFIGS[model_name]:
-            env.update(TEST_CONFIGS[model_name]["env_config"].copy())
+        # Add env config that is defined in the fixture parameter
+        if env_config is not None:
+            env.update(env_config.copy())

-        volumes = [f"{DOCKER_VOLUME}:/data"]
+        volumes = []
+        if DOCKER_VOLUME:
+            volumes = [f"{DOCKER_VOLUME}:/data"]
        logger.debug(f"Using volume {volumes}")

        try:
+            logger.debug(f"Using command {tgi_args}")
            logger.info(f"Creating container with name {container_name}")

+            logger.debug(f"Using environment {env}")
+            logger.debug(f"Using volumes {volumes}")
+            logger.debug(f"HABANA_RUN_ARGS {HABANA_RUN_ARGS}")
+
            # Log equivalent docker run command for debugging, this is not actually executed
            container = client.containers.run(
                DOCKER_IMAGE,
@@ -271,15 +289,16 @@ def get_free_port():


@pytest.fixture(scope="module")
-def generate_load():
+def gaudi_generate_load():
    async def generate_load_inner(
-        client: AsyncClient, prompt: str, max_new_tokens: int, n: int
-    ) -> List[Response]:
+        client: AsyncInferenceClient, prompt: str, max_new_tokens: int, n: int
+    ) -> List[TextGenerationOutput]:
        try:
            futures = [
-                client.generate(
+                client.text_generation(
                    prompt,
                    max_new_tokens=max_new_tokens,
+                    details=True,
                    decoder_input_details=True,
                )
                for _ in range(n)
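Finally, a hypothetical test combining the two renamed fixtures, assuming pytest-asyncio for the async test and reusing the illustrative `tgi_service` fixture sketched earlier; names and values are invented:

```python
import pytest

@pytest.mark.asyncio
async def test_concurrent_generation(tgi_service, gaudi_generate_load):
    # Wait for the server, then issue four concurrent requests.
    await tgi_service.health(300)
    responses = await gaudi_generate_load(
        tgi_service.client, "What is Deep Learning?", max_new_tokens=32, n=4
    )
    assert len(responses) == 4
    # details=True means each item is a TextGenerationOutput.
    assert all(r.generated_text for r in responses)
```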