moly_kit/clients/
deep_inquire.rs

1use crate::protocol::Tool;
2use crate::utils::asynchronous::{BoxPlatformSendFuture, BoxPlatformSendStream};
3use crate::utils::errors::enrich_http_error;
4use crate::{protocol::*, utils::sse::parse_sse};
5use async_stream::stream;
6use makepad_widgets::*;
7use makepad_widgets::{Cx, LiveNew, WidgetRef};
8use reqwest::header::{HeaderMap, HeaderName};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::{
12    str::FromStr,
13    sync::{Arc, RwLock},
14};
15use widgets::deep_inquire_content::DeepInquireContentWidgetRefExt;
16
17pub(crate) mod widgets;
18
19/// Article reference in a DeepInquire response
20#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
21pub struct Article {
22    pub title: String,
23    pub url: String,
24    pub snippet: String,
25    pub source: String,
26    pub relevance: usize,
27}
28
29/// A message being sent to the DeepInquire API
30#[derive(Clone, Debug, Serialize)]
31struct OutcomingMessage {
32    pub content: String,
33    pub role: Role,
34}
35
36impl TryFrom<Message> for OutcomingMessage {
37    type Error = ();
38
39    fn try_from(message: Message) -> Result<Self, Self::Error> {
40        let role = match message.from {
41            EntityId::User => Ok(Role::User),
42            EntityId::System => Ok(Role::System),
43            EntityId::Bot(_) => Ok(Role::Assistant),
44            EntityId::Tool => Err(()), // DeepInquire doesn't support tool role
45            EntityId::App => Err(()),
46        }?;
47
48        Ok(Self {
49            content: message.content.text,
50            role,
51        })
52    }
53}
54
55/// Role of a message in the DeepInquire API
56#[derive(Clone, Debug, Serialize, Deserialize)]
57enum Role {
58    #[serde(rename = "system")]
59    System,
60    #[serde(rename = "user")]
61    User,
62    #[serde(rename = "assistant")]
63    Assistant,
64}
65
66/// The delta content as part of a streaming response
67#[derive(Clone, Debug, Deserialize)]
68struct DeltaContent {
69    content: String,
70    #[serde(default)]
71    articles: Vec<Article>,
72    metadata: Metadata,
73    #[serde(default)]
74    r#type: String,
75    id: String,
76}
77
78#[derive(Clone, Debug, Deserialize)]
79struct Metadata {
80    #[serde(default)]
81    stage: String,
82}
83/// The Choice object in a streaming response
84#[derive(Clone, Debug, Deserialize)]
85#[allow(dead_code)]
86struct DeltaChoice {
87    pub delta: DeltaContent,
88    index: usize,
89    finish_reason: Option<String>,
90}
91
92/// Response from the DeepInquire API
93#[derive(Clone, Debug, Deserialize)]
94struct DeepInquireResponse {
95    choices: Vec<DeltaChoice>,
96}
97
98#[derive(Clone, Debug, Deserialize, Serialize)]
99pub struct Stage {
100    pub id: String,
101    pub citations: Vec<Article>,
102    pub substages: Vec<SubStage>,
103    pub stage_type: StageType,
104}
105
106#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Default, Live, LiveHook, LiveRead)]
107pub enum StageType {
108    #[default]
109    Thinking,
110    #[pick]
111    Content,
112    Completion,
113}
114
115impl FromStr for StageType {
116    type Err = ();
117
118    fn from_str(s: &str) -> Result<Self, Self::Err> {
119        match s {
120            "thinking" => Ok(StageType::Thinking),
121            "content" => Ok(StageType::Content),
122            "completion" => Ok(StageType::Completion),
123            _ => Err(()),
124        }
125    }
126}
127
128#[derive(Clone, Debug, Deserialize, Serialize, Default)]
129pub struct Data {
130    stages: Vec<Stage>,
131}
132
133#[derive(Clone, Debug, Deserialize, Serialize, Default)]
134pub struct SubStage {
135    pub id: String,
136    pub text: String,
137    pub name: String,
138}
139
140#[derive(Clone, Debug)]
141struct DeepInquireClientInner {
142    url: String,
143    headers: HeaderMap,
144    client: reqwest::Client,
145}
146
147/// A client for interacting with the DeepInquire API
148#[derive(Debug)]
149pub struct DeepInquireClient(Arc<RwLock<DeepInquireClientInner>>);
150
151impl Clone for DeepInquireClient {
152    fn clone(&self) -> Self {
153        Self(self.0.clone())
154    }
155}
156
157impl From<DeepInquireClientInner> for DeepInquireClient {
158    fn from(inner: DeepInquireClientInner) -> Self {
159        Self(Arc::new(RwLock::new(inner)))
160    }
161}
162
163impl DeepInquireClient {
164    /// Creates a new client with the given DeepInquire API URL
165    pub fn new(url: String) -> Self {
166        let headers = HeaderMap::new();
167        let client = default_client();
168
169        DeepInquireClientInner {
170            url,
171            headers,
172            client,
173        }
174        .into()
175    }
176
177    pub fn set_header(&mut self, key: &str, value: &str) -> Result<(), &'static str> {
178        let header_name = HeaderName::from_str(key).map_err(|_| "Invalid header name")?;
179
180        let header_value = value.parse().map_err(|_| "Invalid header value")?;
181
182        self.0
183            .write()
184            .unwrap()
185            .headers
186            .insert(header_name, header_value);
187
188        Ok(())
189    }
190
191    pub fn set_key(&mut self, key: &str) -> Result<(), &'static str> {
192        self.set_header("Authorization", &format!("Bearer {}", key))
193    }
194}
195
196impl BotClient for DeepInquireClient {
197    fn bots(&self) -> BoxPlatformSendFuture<'static, ClientResult<Vec<Bot>>> {
198        let inner = self.0.read().unwrap().clone();
199
200        // For now we return a hardcoded bot because DeepInquire does not support a /models endpoint
201        let bot = Bot {
202            id: BotId::new("DeepInquire", &inner.url),
203            name: "DeepInquire".to_string(),
204            avatar: Picture::Grapheme("D".into()),
205            capabilities: BotCapabilities::new().with_capability(BotCapability::Attachments),
206        };
207
208        let future = async move { ClientResult::new_ok(vec![bot]) };
209
210        Box::pin(future)
211    }
212
213    fn clone_box(&self) -> Box<dyn BotClient> {
214        Box::new(self.clone())
215    }
216
217    fn send(
218        &mut self,
219        bot_id: &BotId,
220        messages: &[Message],
221        _tools: &[Tool],
222    ) -> BoxPlatformSendStream<'static, ClientResult<MessageContent>> {
223        let inner = self.0.read().unwrap().clone();
224
225        let url = format!("{}/chat/completions", inner.url);
226        let headers = inner.headers;
227
228        let moly_messages: Vec<OutcomingMessage> = messages
229            .iter()
230            .filter_map(|m| m.clone().try_into().ok())
231            .collect();
232
233        let request = inner
234            .client
235            .post(&url)
236            .headers(headers)
237            .json(&serde_json::json!({
238                "model": bot_id.id(),
239                "messages": moly_messages,
240                "stream": true
241            }));
242
243        let stream = stream! {
244            let response = match request.send().await {
245                Ok(response) => {
246                    if response.status().is_success() {
247                        response
248                    } else {
249                        let status_code = response.status();
250                        let body = response.text().await.unwrap();
251                        let original = format!("Request failed with status {}", status_code);
252                        let enriched = enrich_http_error(status_code, &original, Some(&body));
253
254                        yield ClientError::new(
255                            ClientErrorKind::Response,
256                            enriched,
257                        ).into();
258                        return;
259                    }
260                }
261                Err(error) => {
262                    ::log::error!("SSE stream unexpectedly interrupted while reading from {}: {:?}", url, error);
263                    yield ClientError::new_with_source(
264                        ClientErrorKind::Network,
265                        format!("The connection was unexpectedly closed while streaming the response from {url}. This could be due to network issues, server problems, or timeouts."),
266                        Some(error),
267                    ).into();
268                    return;
269                }
270            };
271
272            let events = parse_sse(response.bytes_stream());
273            let mut content = MessageContent::default();
274            let mut consecutive_timeouts = 0;
275            let max_consecutive_timeouts = 3;
276            let mut message_count = 0;
277            // Only yield to UI every 10 messages to reduce back-pressure
278            let yield_frequency = 10;
279
280            for await event in events {
281                let event = match event {
282                    Ok(chunk) => {
283                        consecutive_timeouts = 0; // Reset timeout counter on success
284                        message_count += 1;
285                        chunk
286                    },
287                    Err(error) => {
288                        if error.is_timeout() {
289                            consecutive_timeouts += 1;
290
291                            if consecutive_timeouts >= max_consecutive_timeouts {
292                                ::log::error!("SSE stream error while reading from {}: {:?}", url, error);
293                                yield ClientError::new_with_source(
294                                    ClientErrorKind::Network,
295                                    format!("Too many consecutive timeouts ({}) while reading from {url}. Giving up.", consecutive_timeouts),
296                                    Some(error),
297                                ).into();
298                                return;
299                            }
300
301                            continue;
302                        } else {
303                            ::log::error!("SSE stream error while reading from {}: {:?}", url, error);
304                            yield ClientError::new_with_source(
305                                ClientErrorKind::Network,
306                                format!("The connection was unexpectedly closed while streaming the response from {url}. This could be due to network issues, server problems, or timeouts."),
307                                Some(error),
308                            ).into();
309                            return;
310                        }
311                    }
312                };
313
314                let response: DeepInquireResponse = match serde_json::from_str(&event) {
315                    Ok(c) => c,
316                    Err(error) => {
317                        ::log::error!("Could not parse the SSE message from {url} as JSON or its structure does not match the expected format. {}", error);
318                        yield ClientError::new_with_source(
319                            ClientErrorKind::Format,
320                            format!("Could not parse the SSE message from {url} as JSON or its structure does not match the expected format."),
321                            Some(error),
322                        ).into();
323                        return;
324                    }
325                };
326
327                apply_response_to_content(response, &mut content);
328
329                // Only yield to UI periodically to reduce back-pressure
330                // The first 20 messages are yielded immediately to ensure the UI is updated
331                if message_count % yield_frequency == 0 || message_count < 20 {
332                    yield ClientResult::new_ok(content.clone());
333                }
334            }
335
336            // Final yield to ensure the last state is captured
337            yield ClientResult::new_ok(content.clone());
338        };
339
340        Box::pin(stream)
341    }
342
343    fn content_widget(
344        &mut self,
345        cx: &mut Cx,
346        previous_widget: WidgetRef,
347        templates: &HashMap<LiveId, LivePtr>,
348        content: &MessageContent,
349    ) -> Option<WidgetRef> {
350        let Some(template) = templates.get(&live_id!(DeepInquireContent)).copied() else {
351            return None;
352        };
353
354        let Some(data) = content.data.as_deref() else {
355            return None;
356        };
357
358        let Ok(_) = serde_json::from_str::<Data>(data) else {
359            return None;
360        };
361
362        let widget = if previous_widget.as_deep_inquire_content().borrow().is_some() {
363            previous_widget
364        } else {
365            WidgetRef::new_from_ptr(cx, Some(template))
366        };
367
368        widget
369            .as_deep_inquire_content()
370            .borrow_mut()
371            .unwrap()
372            .set_content(cx, content);
373
374        Some(widget)
375    }
376}
377
378fn apply_response_to_content(response: DeepInquireResponse, content: &mut MessageContent) {
379    for choice in response.choices {
380        let delta = choice.delta;
381
382        // The id from the server follows the format <substage_id>.<stream_chunk_id>, which is not useful for tracking
383        // the substages. Thefore we use the metadata 'stage' (a name for the substage) as the id to use for keeping track and upadting each substage.
384        // For the stage_id we use the id from the delta without the stream_chunk_id.
385        let stage_id = delta.id.split('.').next().unwrap_or(&delta.id).to_string();
386        let substage_id = delta.metadata.stage.clone();
387        let stage_type = StageType::from_str(&delta.r#type).unwrap(); // TODO: Handle this gracefully
388
389        create_or_update_stage(content, stage_type, stage_id, move |existing_stage| {
390            // Check if the substage arriving in the response is already present in the accumulated content
391            let existing_substage = existing_stage
392                .substages
393                .iter_mut()
394                .find(|s| s.name == delta.metadata.stage);
395            if let Some(existing_substage) = existing_substage {
396                // SubStage exists, apply delta
397                existing_substage.text.push_str(&delta.content);
398            } else {
399                // SubStage does not exist, create a new one
400                existing_stage.substages.push(SubStage {
401                    id: substage_id,
402                    name: delta.metadata.stage,
403                    text: delta.content,
404                })
405            }
406
407            // Citations are at the Stage level, extend without duplicating
408            let new_citations: Vec<_> = delta
409                .articles
410                .into_iter()
411                .filter(|citation| !existing_stage.citations.contains(citation))
412                .collect();
413
414            existing_stage.citations.extend(new_citations);
415
416            return;
417        });
418    }
419}
420
421fn create_or_update_stage(
422    content: &mut MessageContent,
423    stage_type: StageType,
424    stage_id: String,
425    update_fn: impl FnOnce(&mut Stage),
426) {
427    let mut data: Data = content
428        .data
429        .as_ref()
430        .and_then(|d| serde_json::from_str(d).ok())
431        .unwrap_or_default();
432
433    // Find the existing stage by matching the enum variant
434    if let Some(mut existing_stage) = data.stages.iter_mut().find(|s| s.stage_type == stage_type) {
435        update_fn(&mut existing_stage);
436    } else {
437        let mut new_stage = Stage {
438            id: stage_id,
439            substages: vec![],
440            citations: vec![],
441            stage_type,
442        };
443
444        update_fn(&mut new_stage);
445        data.stages.push(new_stage);
446    }
447
448    content.data = Some(serde_json::to_string(&data).unwrap());
449}
450
451#[cfg(not(target_arch = "wasm32"))]
452fn default_client() -> reqwest::Client {
453    use std::time::Duration;
454
455    reqwest::Client::builder()
456        // Only considered while establishing the connection
457        .connect_timeout(Duration::from_secs(360))
458        // Keep high read timeout for word-by-word streaming
459        .read_timeout(Duration::from_secs(360))
460        .build()
461        .unwrap()
462}
463
/// Builds the HTTP client used on the web.
///
/// On web, reqwest timeouts are not configurable, but it uses the browser's
/// fetch API under the hood, which handles connection issues properly.
#[cfg(target_arch = "wasm32")]
fn default_client() -> reqwest::Client {
    reqwest::Client::new()
}