Skip to content

Commit 61a696b

Browse files
authored
add file names in return values (#888)
1 parent 3e01ada commit 61a696b

File tree

3 files changed

+74
-10
lines changed

3 files changed

+74
-10
lines changed

py/llama_cloud_services/beta/classifier/client.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from llama_cloud_services.constants import POLLING_TIMEOUT_SECONDS
1818
from llama_cloud_services.utils import is_terminal_status, augment_async_errors
1919
from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
20+
from llama_cloud_services.beta.classifier.types import (
21+
ClassifyJobResultsWithFiles,
22+
)
2023

2124

2225
class ClassificationOutput(BaseModel):
@@ -170,19 +173,20 @@ async def aclassify_file_path(
170173
file_input_path: str,
171174
parsing_configuration: Optional[ClassifyParsingConfiguration] = None,
172175
raise_on_error: bool = True,
173-
) -> ClassifyJobResults:
176+
) -> ClassifyJobResultsWithFiles:
174177
file = await self.file_client.upload_file(file_input_path)
175-
return await self.aclassify_file_ids(
178+
results = await self.aclassify_file_ids(
176179
rules, [file.id], parsing_configuration, raise_on_error
177180
)
181+
return ClassifyJobResultsWithFiles.from_classify_job_results(results, [file])
178182

179183
def classify_file_path(
180184
self,
181185
rules: list[ClassifierRule],
182186
file_input_path: str,
183187
parsing_configuration: Optional[ClassifyParsingConfiguration] = None,
184188
raise_on_error: bool = True,
185-
) -> ClassifyJobResults:
189+
) -> ClassifyJobResultsWithFiles:
186190
with augment_async_errors():
187191
return asyncio.run(
188192
self.aclassify_file_path(
@@ -198,25 +202,26 @@ async def aclassify_file_paths(
198202
raise_on_error: bool = True,
199203
workers: int = DEFAULT_NUM_WORKERS,
200204
show_progress: bool = False,
201-
) -> ClassifyJobResults:
205+
) -> ClassifyJobResultsWithFiles:
202206
coroutines = [self.file_client.upload_file(path) for path in file_input_paths]
203207
files: list[File] = await run_jobs(
204208
coroutines,
205209
show_progress=show_progress,
206210
workers=workers,
207211
desc="Uploading files for classification",
208212
)
209-
return await self.aclassify_file_ids(
213+
results = await self.aclassify_file_ids(
210214
rules, [file.id for file in files], parsing_configuration, raise_on_error
211215
)
216+
return ClassifyJobResultsWithFiles.from_classify_job_results(results, files)
212217

213218
def classify_file_paths(
214219
self,
215220
rules: list[ClassifierRule],
216221
file_input_paths: list[str],
217222
parsing_configuration: Optional[ClassifyParsingConfiguration] = None,
218223
raise_on_error: bool = True,
219-
) -> ClassifyJobResults:
224+
) -> ClassifyJobResultsWithFiles:
220225
with augment_async_errors():
221226
return asyncio.run(
222227
self.aclassify_file_paths(
py/llama_cloud_services/beta/classifier/types.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from llama_cloud.types.classify_job_results import ClassifyJobResults
2+
from llama_cloud.types.file_classification import FileClassification
3+
from llama_cloud.types.file import File
4+
5+
6+
class FileClassificationWithFile(FileClassification):
    """
    File classification with file object.

    Carries every field of ``FileClassification`` plus the ``File`` the
    classification refers to, so callers can read file attributes directly
    from the result item.
    """

    # The file this classification belongs to; its ``id`` equals ``file_id``
    # when the instance is built via ``from_file_classification``.
    file: File

    @classmethod
    def from_file_classification(
        cls, file_classification: FileClassification, file: File
    ) -> "FileClassificationWithFile":
        """
        Build an instance from a plain classification and its file.

        Args:
            file_classification: The classification result to extend.
            file: The file object whose ``id`` must equal
                ``file_classification.file_id``.

        Returns:
            A new ``FileClassificationWithFile`` holding the classification's
            fields and ``file``.

        Raises:
            ValueError: If ``file_classification.file_id`` differs from
                ``file.id``.
        """
        if file_classification.file_id != file.id:
            # Report the file ID that was actually compared, not only the
            # classification's own ID, so the mismatch can be diagnosed.
            raise ValueError(
                f"File classification {file_classification.id} has file ID "
                f"{file_classification.file_id} that does not match file ID {file.id}"
            )
        ctor_args = {
            **file_classification.dict(),
            "file": file,
        }
        return cls(**ctor_args)
26+
27+
28+
class ClassifyJobResultsWithFiles(ClassifyJobResults):
    """
    Classify job results with file objects.

    Same as ``ClassifyJobResults`` except that every entry in ``items`` also
    holds the ``File`` it was produced for.
    """

    items: list[FileClassificationWithFile]

    @classmethod
    def from_classify_job_results(
        cls, classify_job_results: ClassifyJobResults, files: list[File]
    ) -> "ClassifyJobResultsWithFiles":
        """
        Attach the given files to the items of a classify job result.

        Args:
            classify_job_results: The job results to extend.
            files: The file objects to pair with the result items; matched
                to items by file ID, not by position.

        Returns:
            A new ``ClassifyJobResultsWithFiles`` whose items each carry
            their matching file.

        Raises:
            ValueError: If the number of items differs from the number of
                files, or if an item's ``file_id`` matches none of ``files``.
        """
        if len(classify_job_results.items) != len(files):
            raise ValueError(
                f"Number of classify job results {len(classify_job_results.items)} does not match number of files {len(files)}"
            )

        # Phase 1: resolve every item to its file before building anything,
        # so an unknown file ID is reported ahead of any construction work.
        files_by_id: dict[str, File] = {}
        for candidate in files:
            files_by_id[candidate.id] = candidate

        matched_pairs: list[tuple[FileClassification, File]] = []
        for item in classify_job_results.items:
            if item.file_id in files_by_id:
                matched_pairs.append((item, files_by_id[item.file_id]))
                continue
            raise ValueError(
                f"File classification result {item.id} has file ID {item.file_id} that does not match any provided file ID"
            )

        # Phase 2: copy the job-level fields and swap in the enriched items.
        payload = classify_job_results.dict()
        enriched_items = []
        for classification, matched_file in matched_pairs:
            enriched_items.append(
                FileClassificationWithFile.from_file_classification(
                    classification, matched_file
                )
            )
        payload["items"] = enriched_items
        return cls(**payload)

py/tests/classifier/test_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33
from llama_cloud.client import AsyncLlamaCloud
44
from llama_cloud.types import Project, ClassifierRule, ClassifyJobResults
5+
from llama_cloud_services.beta.classifier.types import ClassifyJobResultsWithFiles
56
from llama_cloud_services.beta.classifier.client import ClassifyClient
67
from llama_cloud_services.files.client import FileClient
78
from llama_cloud.errors.unprocessable_entity_error import UnprocessableEntityError
@@ -187,7 +188,7 @@ async def test_classify_file_path(
187188
rules=classification_rules, file_input_path=simple_pdf_file_path
188189
)
189190

190-
assert isinstance(results, ClassifyJobResults)
191+
assert isinstance(results, ClassifyJobResultsWithFiles)
191192
assert len(results.items) == 1
192193

193194
# Verify the file got classified
@@ -218,7 +219,7 @@ async def test_classify_file_paths(
218219
file_input_paths=[simple_pdf_file_path, research_paper_path],
219220
)
220221

221-
assert isinstance(results, ClassifyJobResults)
222+
assert isinstance(results, ClassifyJobResultsWithFiles)
222223
assert len(results.items) == 2
223224

224225
file_name_to_expected_type = {
@@ -227,8 +228,7 @@ async def test_classify_file_paths(
227228
}
228229
# Verify each file got classified
229230
for item in results.items:
230-
file = await file_client.get_file(item.file_id)
231-
expected_type = file_name_to_expected_type[file.name]
231+
expected_type = file_name_to_expected_type[item.file.name]
232232
assert item.result.type == expected_type
233233

234234

0 commit comments

Comments
 (0)