Skip to content

Commit 3e01ada

Browse files
authored
add alternative builder method (#887)
* add alternative builder method * fix test
1 parent 37393b7 commit 3e01ada

File tree

3 files changed

+2339
-2283
lines changed

3 files changed

+2339
-2283
lines changed

py/llama_cloud_services/beta/classifier/client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,24 @@ def __init__(
5252
self.file_client = FileClient(client, project_id, organization_id)
5353
self.polling_timeout = polling_timeout
5454

55+
@classmethod
56+
def from_api_key(
57+
cls,
58+
api_key: str,
59+
project_id: Optional[str] = None,
60+
organization_id: Optional[str] = None,
61+
base_url: Optional[str] = None,
62+
) -> "ClassifyClient":
63+
"""
64+
Create a classify client from an API key.
65+
"""
66+
client = AsyncLlamaCloud(token=api_key, base_url=base_url)
67+
return cls(
68+
client,
69+
project_id,
70+
organization_id,
71+
)
72+
5573
async def acreate_classify_job(
5674
self,
5775
rules: list[ClassifierRule],

py/tests/classifier/test_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,44 @@ async def test_classify_file_ids(
130130
assert item.result.type == expected_type
131131

132132

133+
@pytest.mark.asyncio
134+
async def test_classify_file_ids_from_api_key(
135+
e2e_test_settings: EndToEndTestSettings,
136+
file_client: FileClient,
137+
simple_pdf_file_path: str,
138+
research_paper_path: str,
139+
classification_rules: list[ClassifierRule],
140+
):
141+
"""Test classifying files by their IDs"""
142+
# Upload test files first to get their IDs
143+
pdf_file = await file_client.upload_file(simple_pdf_file_path)
144+
research_paper_file = await file_client.upload_file(research_paper_path)
145+
146+
classify_client = ClassifyClient.from_api_key(
147+
api_key=e2e_test_settings.LLAMA_CLOUD_API_KEY.get_secret_value(),
148+
base_url=e2e_test_settings.LLAMA_CLOUD_BASE_URL,
149+
project_id=pdf_file.project_id,
150+
organization_id=e2e_test_settings.LLAMA_CLOUD_ORGANIZATION_ID,
151+
)
152+
153+
# Classify the uploaded files
154+
results = await classify_client.aclassify_file_ids(
155+
rules=classification_rules, file_ids=[pdf_file.id, research_paper_file.id]
156+
)
157+
158+
assert isinstance(results, ClassifyJobResults)
159+
assert len(results.items) == 2
160+
161+
file_id_to_expected_type = {
162+
pdf_file.id: "number",
163+
research_paper_file.id: "research_paper",
164+
}
165+
# Verify each file got classified
166+
for item in results.items:
167+
expected_type = file_id_to_expected_type[item.file_id]
168+
assert item.result.type == expected_type
169+
170+
133171
@parameterize_sync_and_async
134172
@pytest.mark.asyncio
135173
async def test_classify_file_path(

0 commit comments

Comments
 (0)