@@ -36,6 +36,28 @@ def run_tool_calling_llm(llm, request_params):
36
36
]
37
37
request_params ["tools" ] = [tool_schema ]
38
38
39
+ last_tool_id = 0
40
+ for i , message in enumerate (request_params ["messages" ]):
41
+ if "function_call" in message :
42
+ last_tool_id += 1
43
+ function = message .pop ("function_call" )
44
+ message ["tool_calls" ] = [
45
+ {
46
+ "id" : "toolu_" + str (last_tool_id ),
47
+ "type" : "function" ,
48
+ "function" : function ,
49
+ }
50
+ ]
51
+ if message ["role" ] == "function" :
52
+ if i != 0 and request_params ["messages" ][i - 1 ]["role" ] == "tool" :
53
+ request_params ["messages" ][i ]["content" ] += message ["content" ]
54
+ message = None
55
+ else :
56
+ message ["role" ] = "tool"
57
+ message ["tool_call_id" ] = "toolu_" + str (last_tool_id )
58
+
59
+ request_params ["messages" ] = [m for m in request_params ["messages" ] if m != None ]
60
+
39
61
# Add OpenAI's recommended function message
40
62
# request_params["messages"][0][
41
63
# "content"
@@ -46,6 +68,9 @@ def run_tool_calling_llm(llm, request_params):
46
68
accumulated_deltas = {}
47
69
language = None
48
70
code = ""
71
+ function_call_detected = False
72
+ accumulated_review = ""
73
+ review_category = None
49
74
50
75
for chunk in llm .completions (** request_params ):
51
76
if "choices" not in chunk or len (chunk ["choices" ]) == 0 :
@@ -55,18 +80,57 @@ def run_tool_calling_llm(llm, request_params):
55
80
delta = chunk ["choices" ][0 ]["delta" ]
56
81
57
82
# Convert tool call into function call, which we have great parsing logic for below
58
- if "tool_calls" in delta :
59
- if (
60
- len (delta ["tool_calls" ]) > 0
61
- and "function_call" in delta ["tool_calls" ][0 ]
62
- ):
63
- delta ["function_call" ] = delta ["tool_calls" ][0 ]["function_call" ]
83
+ if "tool_calls" in delta and delta ["tool_calls" ]:
84
+ function_call_detected = True
85
+
86
+ # import pdb; pdb.set_trace()
87
+ if len (delta ["tool_calls" ]) > 0 and delta ["tool_calls" ][0 ].function :
88
+ delta = {
89
+ # "id": delta["tool_calls"][0],
90
+ "function_call" : {
91
+ "name" : delta ["tool_calls" ][0 ].function .name ,
92
+ "arguments" : delta ["tool_calls" ][0 ].function .arguments ,
93
+ }
94
+ }
64
95
65
96
# Accumulate deltas
66
97
accumulated_deltas = merge_deltas (accumulated_deltas , delta )
67
98
68
99
if "content" in delta and delta ["content" ]:
69
- yield {"type" : "message" , "content" : delta ["content" ]}
100
+ if function_call_detected :
101
+ # More content after a code block? This is a code review by a judge layer.
102
+
103
+ # print("Code safety review:", delta["content"])
104
+
105
+ if review_category == None :
106
+ accumulated_review += delta ["content" ]
107
+
108
+ if "<unsafe>" in accumulated_review :
109
+ review_category = "unsafe"
110
+ if "<warning>" in accumulated_review :
111
+ review_category = "warning"
112
+ if "<safe>" in accumulated_review :
113
+ review_category = "safe"
114
+
115
+ if review_category != None :
116
+ for tag in [
117
+ "<safe>" ,
118
+ "</safe>" ,
119
+ "<warning>" ,
120
+ "</warning>" ,
121
+ "<unsafe>" ,
122
+ "</unsafe>" ,
123
+ ]:
124
+ delta ["content" ] = delta ["content" ].replace (tag , "" )
125
+
126
+ yield {
127
+ "type" : "review" ,
128
+ "format" : review_category ,
129
+ "content" : delta ["content" ],
130
+ }
131
+
132
+ else :
133
+ yield {"type" : "message" , "content" : delta ["content" ]}
70
134
71
135
if (
72
136
accumulated_deltas .get ("function_call" )
0 commit comments