1use crate::protocol::Tool;
2use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream};
3use crate::utils::errors::enrich_http_error;
4use crate::{protocol::*, utils::sse::parse_sse};
5use async_stream::stream;
6use makepad_widgets::*;
7use makepad_widgets::{Cx, LiveNew, WidgetRef};
8use reqwest::header::{HeaderMap, HeaderName};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::{
12 str::FromStr,
13 sync::{Arc, RwLock},
14};
15use widgets::deep_inquire_content::DeepInquireContentWidgetRefExt;
16
17pub(crate) mod widgets;
18
19#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
21pub struct Article {
22 pub title: String,
23 pub url: String,
24 pub snippet: String,
25 pub source: String,
26 pub relevance: usize,
27}
28
29#[derive(Clone, Debug, Serialize)]
31struct OutcomingMessage {
32 pub content: String,
33 pub role: Role,
34}
35
36impl TryFrom<Message> for OutcomingMessage {
37 type Error = ();
38
39 fn try_from(message: Message) -> Result<Self, Self::Error> {
40 let role = match message.from {
41 EntityId::User => Ok(Role::User),
42 EntityId::System => Ok(Role::System),
43 EntityId::Bot(_) => Ok(Role::Assistant),
44 EntityId::Tool => Err(()), EntityId::App => Err(()),
46 }?;
47
48 Ok(Self {
49 content: message.content.text,
50 role,
51 })
52 }
53}
54
55#[derive(Clone, Debug, Serialize, Deserialize)]
57enum Role {
58 #[serde(rename = "system")]
59 System,
60 #[serde(rename = "user")]
61 User,
62 #[serde(rename = "assistant")]
63 Assistant,
64}
65
66#[derive(Clone, Debug, Deserialize)]
68struct DeltaContent {
69 content: String,
70 #[serde(default)]
71 articles: Vec<Article>,
72 metadata: Metadata,
73 #[serde(default)]
74 r#type: String,
75 id: String,
76}
77
78#[derive(Clone, Debug, Deserialize)]
79struct Metadata {
80 #[serde(default)]
81 stage: String,
82}
83#[derive(Clone, Debug, Deserialize)]
85#[allow(dead_code)]
86struct DeltaChoice {
87 pub delta: DeltaContent,
88 index: usize,
89 finish_reason: Option<String>,
90}
91
92#[derive(Clone, Debug, Deserialize)]
94struct DeepInquireResponse {
95 choices: Vec<DeltaChoice>,
96}
97
98#[derive(Clone, Debug, Deserialize, Serialize)]
99pub struct Stage {
100 pub id: String,
101 pub citations: Vec<Article>,
102 pub substages: Vec<SubStage>,
103 pub stage_type: StageType,
104}
105
106#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Default, Live, LiveHook, LiveRead)]
107pub enum StageType {
108 #[default]
109 Thinking,
110 #[pick]
111 Content,
112 Completion,
113}
114
115impl FromStr for StageType {
116 type Err = ();
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 match s {
120 "thinking" => Ok(StageType::Thinking),
121 "content" => Ok(StageType::Content),
122 "completion" => Ok(StageType::Completion),
123 _ => Err(()),
124 }
125 }
126}
127
128#[derive(Clone, Debug, Deserialize, Serialize, Default)]
129pub struct Data {
130 stages: Vec<Stage>,
131}
132
133#[derive(Clone, Debug, Deserialize, Serialize, Default)]
134pub struct SubStage {
135 pub id: String,
136 pub text: String,
137 pub name: String,
138}
139
140#[derive(Clone, Debug)]
141struct DeepInquireClientInner {
142 url: String,
143 headers: HeaderMap,
144 client: reqwest::Client,
145}
146
147#[derive(Debug)]
149pub struct DeepInquireClient(Arc<RwLock<DeepInquireClientInner>>);
150
151impl Clone for DeepInquireClient {
152 fn clone(&self) -> Self {
153 Self(self.0.clone())
154 }
155}
156
157impl From<DeepInquireClientInner> for DeepInquireClient {
158 fn from(inner: DeepInquireClientInner) -> Self {
159 Self(Arc::new(RwLock::new(inner)))
160 }
161}
162
163impl DeepInquireClient {
164 pub fn new(url: String) -> Self {
166 let headers = HeaderMap::new();
167 let client = default_client();
168
169 DeepInquireClientInner {
170 url,
171 headers,
172 client,
173 }
174 .into()
175 }
176
177 pub fn set_header(&mut self, key: &str, value: &str) -> Result<(), &'static str> {
178 let header_name = HeaderName::from_str(key).map_err(|_| "Invalid header name")?;
179
180 let header_value = value.parse().map_err(|_| "Invalid header value")?;
181
182 self.0
183 .write()
184 .unwrap()
185 .headers
186 .insert(header_name, header_value);
187
188 Ok(())
189 }
190
191 pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> {
192 self.set_header("Authorization", &format!("Bearer {}", key))
193 }
194}
195
196impl BotClient for DeepInquireClient {
197 fn bots(&self) -> BoxPlatformSendFuture<'static, ClientResult<Vec<Bot>>> {
198 let inner = self.0.read().unwrap().clone();
199
200 let bot = Bot {
202 id: BotId::new("DeepInquire", &inner.url),
203 name: "DeepInquire".to_string(),
204 avatar: Picture::Grapheme("D".into()),
205 capabilities: BotCapabilities::new().with_capability(BotCapability::Attachments),
206 };
207
208 let future = async move { ClientResult::new_ok(vec![bot]) };
209
210 Box::pin(future)
211 }
212
213 fn clone_box(&self) -> Box<dyn BotClient> {
214 Box::new(self.clone())
215 }
216
217 fn send(
218 &mut self,
219 bot_id: &BotId,
220 messages: &[Message],
221 _tools: &[Tool],
222 ) -> BoxPlatformSendStream<'static, ClientResult<MessageContent>> {
223 let inner = self.0.read().unwrap().clone();
224
225 let url = format!("{}/chat/completions", inner.url);
226 let headers = inner.headers;
227
228 let moly_messages: Vec<OutcomingMessage> = messages
229 .iter()
230 .filter_map(|m| m.clone().try_into().ok())
231 .collect();
232
233 let request = inner
234 .client
235 .post(&url)
236 .headers(headers)
237 .json(&serde_json::json!({
238 "model": bot_id.id(),
239 "messages": moly_messages,
240 "stream": true
241 }));
242
243 let stream = stream! {
244 let response = match request.send().await {
245 Ok(response) => {
246 if response.status().is_success() {
247 response
248 } else {
249 let status_code = response.status();
250 let body = response.text().await.unwrap();
251 let original = format!("Request failed with status {}", status_code);
252 let enriched = enrich_http_error(status_code, &original, Some(&body));
253
254 yield ClientError::new(
255 ClientErrorKind::Response,
256 enriched,
257 ).into();
258 return;
259 }
260 }
261 Err(error) => {
262 ::log::error!("SSE stream unexpectedly interrupted while reading from {}: {:?}", url, error);
263 yield ClientError::new_with_source(
264 ClientErrorKind::Network,
265 format!("The connection was unexpectedly closed while streaming the response from {url}. This could be due to network issues, server problems, or timeouts."),
266 Some(error),
267 ).into();
268 return;
269 }
270 };
271
272 let events = parse_sse(response.bytes_stream());
273 let mut content = MessageContent::default();
274 let mut consecutive_timeouts = 0;
275 let max_consecutive_timeouts = 3;
276 let mut message_count = 0;
277 let yield_frequency = 10;
279
280 for await event in events {
281 let event = match event {
282 Ok(chunk) => {
283 consecutive_timeouts = 0; message_count += 1;
285 chunk
286 },
287 Err(error) => {
288 if error.is_timeout() {
289 consecutive_timeouts += 1;
290
291 if consecutive_timeouts >= max_consecutive_timeouts {
292 ::log::error!("SSE stream error while reading from {}: {:?}", url, error);
293 yield ClientError::new_with_source(
294 ClientErrorKind::Network,
295 format!("Too many consecutive timeouts ({}) while reading from {url}. Giving up.", consecutive_timeouts),
296 Some(error),
297 ).into();
298 return;
299 }
300
301 continue;
302 } else {
303 ::log::error!("SSE stream error while reading from {}: {:?}", url, error);
304 yield ClientError::new_with_source(
305 ClientErrorKind::Network,
306 format!("The connection was unexpectedly closed while streaming the response from {url}. This could be due to network issues, server problems, or timeouts."),
307 Some(error),
308 ).into();
309 return;
310 }
311 }
312 };
313
314 let response: DeepInquireResponse = match serde_json::from_str(&event) {
315 Ok(c) => c,
316 Err(error) => {
317 ::log::error!("Could not parse the SSE message from {url} as JSON or its structure does not match the expected format. {}", error);
318 yield ClientError::new_with_source(
319 ClientErrorKind::Format,
320 format!("Could not parse the SSE message from {url} as JSON or its structure does not match the expected format."),
321 Some(error),
322 ).into();
323 return;
324 }
325 };
326
327 apply_response_to_content(response, &mut content);
328
329 if message_count % yield_frequency == 0 || message_count < 20 {
332 yield ClientResult::new_ok(content.clone());
333 }
334 }
335
336 yield ClientResult::new_ok(content.clone());
338 };
339
340 Box::pin(stream)
341 }
342
343 fn content_widget(
344 &mut self,
345 cx: &mut Cx,
346 previous_widget: WidgetRef,
347 templates: &HashMap<LiveId, LivePtr>,
348 content: &MessageContent,
349 ) -> Option<WidgetRef> {
350 let Some(template) = templates.get(&live_id!(DeepInquireContent)).copied() else {
351 return None;
352 };
353
354 let Some(data) = content.data.as_deref() else {
355 return None;
356 };
357
358 let Ok(_) = serde_json::from_str::<Data>(data) else {
359 return None;
360 };
361
362 let widget = if previous_widget.as_deep_inquire_content().borrow().is_some() {
363 previous_widget
364 } else {
365 WidgetRef::new_from_ptr(cx, Some(template))
366 };
367
368 widget
369 .as_deep_inquire_content()
370 .borrow_mut()
371 .unwrap()
372 .set_content(cx, content);
373
374 Some(widget)
375 }
376}
377
378fn apply_response_to_content(response: DeepInquireResponse, content: &mut MessageContent) {
379 for choice in response.choices {
380 let delta = choice.delta;
381
382 let stage_id = delta.id.split('.').next().unwrap_or(&delta.id).to_string();
386 let substage_id = delta.metadata.stage.clone();
387 let stage_type = StageType::from_str(&delta.r#type).unwrap(); create_or_update_stage(content, stage_type, stage_id, move |existing_stage| {
390 let existing_substage = existing_stage
392 .substages
393 .iter_mut()
394 .find(|s| s.name == delta.metadata.stage);
395 if let Some(existing_substage) = existing_substage {
396 existing_substage.text.push_str(&delta.content);
398 } else {
399 existing_stage.substages.push(SubStage {
401 id: substage_id,
402 name: delta.metadata.stage,
403 text: delta.content,
404 })
405 }
406
407 let new_citations: Vec<_> = delta
409 .articles
410 .into_iter()
411 .filter(|citation| !existing_stage.citations.contains(citation))
412 .collect();
413
414 existing_stage.citations.extend(new_citations);
415
416 return;
417 });
418 }
419}
420
421fn create_or_update_stage(
422 content: &mut MessageContent,
423 stage_type: StageType,
424 stage_id: String,
425 update_fn: impl FnOnce(&mut Stage),
426) {
427 let mut data: Data = content
428 .data
429 .as_ref()
430 .and_then(|d| serde_json::from_str(d).ok())
431 .unwrap_or_default();
432
433 if let Some(mut existing_stage) = data.stages.iter_mut().find(|s| s.stage_type == stage_type) {
435 update_fn(&mut existing_stage);
436 } else {
437 let mut new_stage = Stage {
438 id: stage_id,
439 substages: vec![],
440 citations: vec![],
441 stage_type,
442 };
443
444 update_fn(&mut new_stage);
445 data.stages.push(new_stage);
446 }
447
448 content.data = Some(serde_json::to_string(&data).unwrap());
449}
450
451#[cfg(not(target_arch = "wasm32"))]
452fn default_client() -> reqwest::Client {
453 use std::time::Duration;
454
455 reqwest::Client::builder()
456 .connect_timeout(Duration::from_secs(360))
458 .read_timeout(Duration::from_secs(360))
460 .build()
461 .unwrap()
462}
463
/// Builds the HTTP client used on `wasm32`, with reqwest's default settings
/// (no timeouts are configured here, unlike on native targets).
#[cfg(target_arch = "wasm32")]
fn default_client() -> reqwest::Client {
    reqwest::Client::new()
}