1use crate::protocol::Tool;
2use async_stream::stream;
3use makepad_widgets::*;
4use reqwest::header::{HeaderMap, HeaderName};
5use serde::{Deserialize, Serialize};
6use std::{
7 collections::HashMap,
8 str::FromStr,
9 sync::{Arc, RwLock},
10};
11
12use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream};
13use crate::utils::{serde::deserialize_null_default, sse::parse_sse};
14use crate::{protocol::*, utils::errors::enrich_http_error};
15
16#[derive(Clone, Debug, Deserialize, PartialEq)]
18struct Model {
19 id: String,
20}
21
22#[derive(Clone, Debug, Deserialize, PartialEq)]
24struct Models {
25 pub data: Vec<Model>,
26}
27
28#[derive(Serialize, Deserialize, Debug, Clone)]
30struct ImageUrlDetail {
31 url: String,
32 }
35
36#[derive(Serialize, Deserialize, Debug, Clone)]
38struct File {
39 filename: String,
40 file_data: String,
41}
42
43#[derive(Serialize, Deserialize, Debug, Clone)]
45#[serde(tag = "type")]
46#[serde(rename_all = "snake_case")]
47enum ContentPart {
48 Text { text: String },
49 ImageUrl { image_url: ImageUrlDetail },
50 File { file: File },
51}
52
53#[derive(Serialize, Deserialize, Debug, Clone)]
55#[serde(untagged)] enum Content {
57 Text(String),
58 Parts(Vec<ContentPart>),
59}
60
61impl Default for Content {
62 fn default() -> Self {
63 Content::Text(String::new())
64 }
65}
66
67impl Content {
68 pub fn text(&self) -> String {
70 match self {
71 Content::Text(text) => text.clone(),
72 Content::Parts(parts) => parts
73 .iter()
74 .filter_map(|part| match part {
75 ContentPart::Text { text } => Some(text.clone()),
76 _ => None,
77 })
78 .collect::<Vec<String>>()
79 .join(" "),
80 }
81 }
82}
83
84#[derive(Serialize)]
85struct FunctionDefinition {
86 name: String,
87 description: String,
88 parameters: serde_json::Value,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 strict: Option<bool>,
91}
92
93#[derive(Serialize)]
95struct FunctionTool {
96 #[serde(rename = "type")]
97 tool_type: String,
98 function: FunctionDefinition,
99}
100
101impl From<&Tool> for FunctionTool {
102 fn from(tool: &Tool) -> Self {
103 let mut parameters_map = (*tool.input_schema).clone();
105
106 parameters_map.insert(
108 "additionalProperties".to_string(),
109 serde_json::Value::Bool(false),
110 );
111
112 if parameters_map.get("type") == Some(&serde_json::Value::String("object".to_string())) {
114 if !parameters_map.contains_key("properties") {
115 parameters_map.insert(
116 "properties".to_string(),
117 serde_json::Value::Object(serde_json::Map::new()),
118 );
119 }
120 }
121
122 let parameters = serde_json::Value::Object(parameters_map);
123
124 FunctionTool {
125 tool_type: "function".to_string(),
126 function: FunctionDefinition {
127 name: tool.name.clone(),
128 description: tool.description.as_deref().unwrap_or("").to_string(),
129 parameters,
130 strict: Some(false),
131 },
132 }
133 }
134}
135
136#[derive(Clone, Debug, Deserialize)]
138struct OpenAIToolCall {
139 #[serde(default)]
140 pub id: String,
141 #[serde(rename = "type")]
142 #[serde(default)]
143 #[allow(dead_code)] pub tool_type: String,
145 pub function: OpenAIFunctionCall,
146}
147
148#[derive(Clone, Debug, Deserialize)]
150struct OpenAIFunctionCall {
151 #[serde(default)]
152 pub name: String,
153 #[serde(default)]
154 pub arguments: String, }
156
157#[derive(Clone, Debug, Deserialize)]
167struct IncomingMessage {
168 #[serde(default)]
169 #[serde(deserialize_with = "deserialize_null_default")]
170 pub content: Content,
171 #[serde(default)]
178 #[serde(deserialize_with = "deserialize_null_default")]
179 #[serde(alias = "reasoning_content")]
180 pub reasoning: String,
181 #[serde(default)]
183 pub tool_calls: Vec<OpenAIToolCall>,
184}
185#[derive(Clone, Debug, Serialize)]
187struct OutgoingMessage {
188 pub content: Content,
189 pub role: Role,
190 #[serde(skip_serializing_if = "Option::is_none")]
191 pub tool_calls: Option<Vec<serde_json::Value>>,
192 #[serde(skip_serializing_if = "Option::is_none")]
193 pub tool_call_id: Option<String>,
194}
195
196async fn to_outgoing_message(message: Message) -> Result<OutgoingMessage, String> {
197 if !message.content.tool_results.is_empty() {
199 return outgoing_tool_result_message(message);
200 }
201
202 let role = match message.from {
203 EntityId::User => Ok(Role::User),
204 EntityId::System => Ok(Role::System),
205 EntityId::Bot(_) => Ok(Role::Assistant),
206 EntityId::Tool => Ok(Role::Tool),
207 EntityId::App => Err("App messages cannot be sent to OpenAI".to_string()),
208 }?;
209
210 let content = if message.content.attachments.is_empty() {
211 Content::Text(message.content.text)
212 } else {
213 let mut parts = Vec::new();
214
215 for attachment in message.content.attachments {
216 if !attachment.is_available() {
217 makepad_widgets::warning!("Skipping unavailable attachment: {}", attachment.name);
218 continue;
219 }
220
221 let content = attachment
222 .read_base64()
223 .await
224 .map_err(|e| format!("Failed to read attachment '{}': {}", attachment.name, e))?;
225 let data_url = format!(
226 "data:{};base64,{}",
227 attachment
228 .content_type
229 .as_deref()
230 .unwrap_or("application/octet-stream"),
231 content
232 );
233
234 if attachment.is_image() {
235 parts.push(ContentPart::ImageUrl {
236 image_url: ImageUrlDetail { url: data_url },
237 });
238 } else if attachment.is_pdf() {
239 parts.push(ContentPart::File {
240 file: File {
241 filename: attachment.name,
242 file_data: data_url,
243 },
244 });
245 } else {
246 match decode_base64_to_text(&content) {
248 Ok(text_content) => {
249 parts.push(ContentPart::Text {
250 text: format!("[File: {}]\n{}", attachment.name, text_content),
251 });
252 }
253 Err(_) => {
254 return Err(format!(
256 "File '{}' is not supported. Only images, PDFs, and text files can be sent through the Chat Completions API.",
257 attachment.name
258 ));
259 }
260 }
261 }
262 }
263
264 parts.push(ContentPart::Text {
265 text: message.content.text,
266 });
267 Content::Parts(parts)
268 };
269
270 let tool_calls =
272 if !message.content.tool_calls.is_empty() {
273 Some(message.content.tool_calls.iter().map(|tc| {
274 serde_json::json!({
275 "id": tc.id,
276 "type": "function",
277 "function": {
278 "name": tc.name,
279 "arguments": serde_json::to_string(&tc.arguments).unwrap_or_default()
280 }
281 })
282 }).collect())
283 } else {
284 None
285 };
286
287 Ok(OutgoingMessage {
288 content,
289 role,
290 tool_calls,
291 tool_call_id: None,
292 })
293}
294
295fn outgoing_tool_result_message(message: Message) -> Result<OutgoingMessage, String> {
299 let role = Role::Tool;
300 let content = Content::Text(
301 message
302 .content
303 .tool_results
304 .iter()
305 .map(|result| truncate_tool_result(&result.content))
306 .collect::<Vec<_>>()
307 .join("\n"),
308 );
309
310 let tool_call_id = message
311 .content
312 .tool_results
313 .first()
314 .map(|r| r.tool_call_id.clone());
315
316 return Ok(OutgoingMessage {
317 content,
318 role,
319 tool_calls: None,
320 tool_call_id,
321 });
322}
323
/// Caps a tool's output at 16384 characters (Unicode scalar values).
///
/// Output within the limit is returned unchanged. Longer output keeps its first
/// 16384 characters and gets a `... [truncated]` suffix. The limit is counted in
/// characters, not bytes, so multibyte text within the limit is never marked truncated.
fn truncate_tool_result(content: &str) -> String {
    const MAX_TOOL_OUTPUT_CHARS: usize = 16384;

    // `nth(MAX)` exists only if there are more than MAX chars; its byte offset is the
    // char-boundary cut point, so the slice below is always valid UTF-8.
    match content.char_indices().nth(MAX_TOOL_OUTPUT_CHARS) {
        Some((cut, _)) => format!("{}... [truncated]", &content[..cut]),
        None => content.to_string(),
    }
}
336
337fn decode_base64_to_text(base64: &str) -> Result<String, ()> {
340 use base64::Engine;
341
342 let bytes = base64::engine::general_purpose::STANDARD
344 .decode(base64)
345 .map_err(|_| ())?;
346
347 String::from_utf8(bytes).map_err(|_| ())
349}
350
351fn finalize_remaining_tool_calls(
354 content: &mut MessageContent,
355 tool_argument_buffers: &mut HashMap<String, String>,
356 tool_names: &mut HashMap<String, String>,
357 tool_call_ids_by_index: &mut HashMap<usize, String>,
358) {
359 for (tool_call_id, buffered_args) in tool_argument_buffers.drain() {
361 let arguments = if buffered_args.is_empty() || buffered_args == "{}" {
362 serde_json::Map::new()
363 } else {
364 match serde_json::from_str::<serde_json::Value>(&buffered_args) {
365 Ok(serde_json::Value::Object(args)) => args,
366 Ok(serde_json::Value::Null) => serde_json::Map::new(),
367 Ok(_) => serde_json::Map::new(),
368 Err(_) => serde_json::Map::new(),
369 }
370 };
371
372 if let Some(name) = tool_names.get(&tool_call_id) {
374 let tool_call = ToolCall {
375 id: tool_call_id.clone(),
376 name: name.clone(),
377 arguments,
378 ..Default::default()
379 };
380 content.tool_calls.push(tool_call);
381 }
382 }
383
384 tool_names.clear();
386 tool_call_ids_by_index.clear();
387}
388
389#[derive(Clone, Debug, Serialize, Deserialize)]
391enum Role {
392 #[serde(rename = "system")]
396 System,
397 #[serde(rename = "user")]
398 User,
399 #[serde(rename = "assistant")]
400 Assistant,
401 #[serde(rename = "tool")]
402 Tool,
403}
404
405#[derive(Clone, Debug, Deserialize)]
407struct Choice {
408 pub delta: IncomingMessage,
409 pub finish_reason: Option<String>,
410}
411
412#[derive(Clone, Debug, Deserialize)]
414struct Completion {
415 pub choices: Vec<Choice>,
416 #[serde(default)]
417 pub citations: Vec<String>,
418}
419
420#[derive(Clone, Debug)]
421struct OpenAIClientInner {
422 url: String,
423 headers: HeaderMap,
424 client: reqwest::Client,
425 tools_enabled: bool,
426}
427
428#[derive(Debug)]
430pub struct OpenAIClient(Arc<RwLock<OpenAIClientInner>>);
431
432impl Clone for OpenAIClient {
433 fn clone(&self) -> Self {
434 Self(self.0.clone())
435 }
436}
437
438impl From<OpenAIClientInner> for OpenAIClient {
439 fn from(inner: OpenAIClientInner) -> Self {
440 Self(Arc::new(RwLock::new(inner)))
441 }
442}
443
444impl OpenAIClient {
445 pub fn new(url: String) -> Self {
447 let headers = HeaderMap::new();
448 let client = default_client();
449
450 OpenAIClientInner {
451 url,
452 headers,
453 client,
454 tools_enabled: true, }
456 .into()
457 }
458
459 pub fn set_header(&mut self, key: &str, value: &str) -> Result<(), &'static str> {
460 let header_name = HeaderName::from_str(key).map_err(|_| "Invalid header name")?;
461
462 let header_value = value.parse().map_err(|_| "Invalid header value")?;
463
464 self.0
465 .write()
466 .unwrap()
467 .headers
468 .insert(header_name, header_value);
469
470 Ok(())
471 }
472
473 pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> {
474 self.set_header("Authorization", &format!("Bearer {}", key))?;
475
476 if self.0.read().unwrap().url.contains("anthropic") {
478 self.set_header("x-api-key", key)?;
479 self.set_header("anthropic-version", "2023-06-01")?;
482 }
483
484 Ok(())
485 }
486
487 pub fn set_tools_enabled(&mut self, enabled: bool) {
488 self.0.write().unwrap().tools_enabled = enabled;
489 }
490}
491
492impl BotClient for OpenAIClient {
493 fn bots(&self) -> BoxPlatformSendFuture<'static, ClientResult<Vec<Bot>>> {
494 let inner = self.0.read().unwrap().clone();
495
496 let url = format!("{}/models", inner.url);
497 let headers = inner.headers;
498
499 let request = inner.client.get(&url).headers(headers);
500
501 let future = async move {
502 let response = match request.send().await {
503 Ok(response) => response,
504 Err(error) => {
505 return ClientError::new_with_source(
506 ClientErrorKind::Network,
507 format!("An error ocurred sending a request to {url}."),
508 Some(error),
509 )
510 .into();
511 }
512 };
513
514 if !response.status().is_success() {
515 let code = response.status().as_u16();
516 return ClientError::new(
517 ClientErrorKind::Response,
518 format!("Got unexpected HTTP status code {code} from {url}."),
519 )
520 .into();
521 }
522
523 let text = match response.text().await {
524 Ok(text) => text,
525 Err(error) => {
526 return ClientError::new_with_source(
527 ClientErrorKind::Format,
528 format!("Could not parse the response from {url} as valid text."),
529 Some(error),
530 )
531 .into();
532 }
533 };
534
535 if text.is_empty() {
536 return ClientError::new(
537 ClientErrorKind::Format,
538 format!("The response from {url} is empty."),
539 )
540 .into();
541 }
542
543 let models: Models = match serde_json::from_str(&text) {
544 Ok(models) => models,
545 Err(error) => {
546 return ClientError::new_with_source(
547 ClientErrorKind::Format,
548 format!("Could not parse the response from {url} as JSON or its structure does not match the expected format."),
549 Some(error),
550 ).into();
551 }
552 };
553
554 let mut bots: Vec<Bot> = models
555 .data
556 .iter()
557 .map(|m| Bot {
558 id: BotId::new(&m.id, &inner.url),
559 name: m.id.clone(),
560 avatar: Picture::Grapheme(
561 m.id.chars().next().unwrap().to_string().to_uppercase(),
562 ),
563 capabilities: BotCapabilities::new()
566 .with_capability(BotCapability::Attachments),
567 })
568 .filter(|b| {
569 !b.id.id().starts_with("dall-e") && !b.id.id().starts_with("gpt-image")
571 })
572 .collect();
573
574 bots.sort_by(|a, b| a.name.cmp(&b.name));
575
576 ClientResult::new_ok(bots)
577 };
578
579 Box::pin(future)
580 }
581
582 fn clone_box(&self) -> Box<dyn BotClient> {
583 Box::new(self.clone())
584 }
585
586 fn send(
588 &mut self,
589 bot_id: &BotId,
590 messages: &[Message],
591 tools: &[Tool],
592 ) -> BoxPlatformSendStream<'static, ClientResult<MessageContent>> {
593 let bot_id = bot_id.clone();
594 let messages = messages.to_vec();
595
596 let inner = self.0.read().unwrap().clone();
597 let url = format!("{}/chat/completions", inner.url);
598 let headers = inner.headers;
599
600 let tools: Vec<FunctionTool> = if inner.tools_enabled {
602 tools.iter().map(|t| t.into()).collect()
603 } else {
604 Vec::new()
605 };
606
607 let stream = stream! {
608 let mut outgoing_messages: Vec<OutgoingMessage> = Vec::with_capacity(messages.len());
609 for message in messages {
610 match to_outgoing_message(message.clone()).await {
611 Ok(outgoing_message) => outgoing_messages.push(outgoing_message),
612 Err(err) => {
613 error!("Could not convert message to outgoing format: {}", err);
614 yield ClientError::new(
615 ClientErrorKind::Format,
616 err,
617 ).into();
618 return;
619 }
620 }
621 }
622
623 let mut json = serde_json::json!({
624 "model": bot_id.id(),
625 "messages": outgoing_messages,
626 "stream": true
629 });
630
631 if !tools.is_empty() {
633 json["tools"] = serde_json::json!(tools);
634 }
635
636
637 let request = inner
638 .client
639 .post(&url)
640 .headers(headers)
641 .json(&json);
642
643 let response = match request.send().await {
644 Ok(response) => {
645 if response.status().is_success() {
646 response
647 } else {
648 let status_code = response.status();
649 let body = response.text().await.unwrap();
650 let original = format!("Request failed with status {}", status_code);
651 let enriched = enrich_http_error(status_code, &original, Some(&body));
652
653 error!("Error sending request to {}: status {}", url, status_code);
654 yield ClientError::new(
655 ClientErrorKind::Response,
656 enriched,
657 ).into();
658 return;
659 }
660 }
661 Err(error) => {
662 error!("Error sending request to {}: {:?}", url, error);
663 yield ClientError::new_with_source(
664 ClientErrorKind::Network,
665 format!("Could not send request to {url}. Verify your connection and the server status."),
666 Some(error),
667 ).into();
668 return;
669 }
670 };
671
672 let mut content = MessageContent::default();
673 let mut full_text = String::default();
674 let mut tool_argument_buffers: HashMap<String, String> = HashMap::new();
675 let mut tool_names: HashMap<String, String> = HashMap::new();
676 let mut tool_call_ids_by_index: HashMap<usize, String> = HashMap::new();
677 let events = parse_sse(response.bytes_stream());
678
679 for await event in events {
680 let event = match event {
681 Ok(event) => event,
682 Err(error) => {
683 error!("Response streaming got interrupted while reading from {}: {:?}", url, error);
684 yield ClientError::new_with_source(
685 ClientErrorKind::Network,
686 format!("Response streaming got interrupted while reading from {url}. This may be a problem with your connection or the server."),
687 Some(error),
688 ).into();
689 return;
690 }
691 };
692
693 let completion: Completion = match serde_json::from_str(&event) {
694 Ok(c) => c,
695 Err(error) => {
696 error!("Could not parse the SSE message from {url} as JSON or its structure does not match the expected format. {}", error);
697 yield ClientError::new_with_source(
698 ClientErrorKind::Format,
699 format!("Could not parse the SSE message from {url} as JSON or its structure does not match the expected format."),
700 Some(error),
701 ).into();
702 return;
703 }
704 };
705
706 let is_tool_calls_finished = completion.choices.iter()
708 .any(|choice| choice.finish_reason.as_deref() == Some("tool_calls"));
709
710 let mut should_yield_content = true;
711
712 if is_tool_calls_finished {
713 finalize_remaining_tool_calls(
714 &mut content,
715 &mut tool_argument_buffers,
716 &mut tool_names,
717 &mut tool_call_ids_by_index,
718 );
719 } else if !tool_argument_buffers.is_empty() || !tool_names.is_empty() {
720 should_yield_content = false;
722 }
723
724 for choice in &completion.choices {
726 full_text.push_str(&choice.delta.content.text());
728
729 let (reasoning, text) = split_reasoning_tag(&full_text);
731
732 content.text = text.to_string();
734
735 if reasoning.is_empty() {
736 content.reasoning.push_str(&choice.delta.reasoning);
738 } else {
739 content.reasoning = reasoning.to_string();
741 }
742
743 for (index, tool_call) in choice.delta.tool_calls.iter().enumerate() {
745 let tool_call_id = if !tool_call.id.is_empty() {
747 tool_call_ids_by_index.insert(index, tool_call.id.clone());
749 tool_call.id.clone()
750 } else {
751 if let Some(existing_id) = tool_call_ids_by_index.get(&index) {
753 existing_id.clone()
754 } else {
755 continue;
756 }
757 };
758
759 let buffer_entry = tool_argument_buffers.entry(tool_call_id.clone()).or_default();
761 buffer_entry.push_str(&tool_call.function.arguments);
762
763 if !tool_call.function.name.is_empty() {
766 tool_names.insert(tool_call_id.clone(), tool_call.function.name.clone());
767 }
768
769 if !buffer_entry.is_empty() {
771 let arguments = if buffer_entry == "{}" {
773 Some(serde_json::Map::new())
776 } else {
777 match serde_json::from_str::<serde_json::Value>(buffer_entry) {
778 Ok(serde_json::Value::Object(args)) => Some(args),
782 Ok(serde_json::Value::Null) => Some(serde_json::Map::new()),
785 Ok(_) => Some(serde_json::Map::new()),
788 Err(_) => None,
792 }
793 };
794
795 if let (Some(arguments), Some(name)) = (arguments, tool_names.get(&tool_call_id)) {
797 let tool_call = ToolCall {
798 id: tool_call_id.clone(),
799 name: name.clone(),
800 arguments,
801 ..Default::default()
802 };
803 content.tool_calls.push(tool_call);
804 tool_argument_buffers.remove(&tool_call_id);
805 tool_names.remove(&tool_call_id);
806 }
807 }
808 }
809 }
810
811 for citation in completion.citations {
812 if !content.citations.contains(&citation) {
813 content.citations.push(citation.clone());
814 }
815 }
816
817 if should_yield_content {
818 yield ClientResult::new_ok(content.clone());
819 }
820 }
821 };
822
823 Box::pin(stream)
824 }
825}
826
827#[cfg(not(target_arch = "wasm32"))]
828fn default_client() -> reqwest::Client {
829 use std::time::Duration;
830
831 reqwest::Client::builder()
834 .connect_timeout(Duration::from_secs(90))
836 .read_timeout(Duration::from_secs(90))
842 .build()
843 .unwrap()
844}
845
/// Builds the HTTP client used on `wasm32` targets.
///
/// Uses reqwest's default configuration; no timeouts are set here, unlike the
/// native variant.
#[cfg(target_arch = "wasm32")]
fn default_client() -> reqwest::Client {
    reqwest::Client::new()
}
852
/// Splits inline reasoning out of model output, returning `(reasoning, text)`.
///
/// If the text (ignoring leading whitespace) starts with `<think>`, everything up to
/// `</think>` is the reasoning and everything after it is the text. Without a closing
/// tag, the reasoning is still in progress and the text is empty. Otherwise there is
/// no reasoning and the input is returned unchanged as the text.
fn split_reasoning_tag(text: &str) -> (&str, &str) {
    const START_TAG: &str = "<think>";
    const END_TAG: &str = "</think>";

    let after_start = match text.trim_start().strip_prefix(START_TAG) {
        Some(rest) => rest,
        None => return ("", text),
    };

    match after_start.split_once(END_TAG) {
        Some(pair) => pair,
        None => (after_start, ""),
    }
}