diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index b5cf4a6..afe5642 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1032,6 +1032,21 @@ dependencies = [ "new_debug_unreachable", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1039,6 +1054,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1106,6 +1122,7 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", "futures-io", "futures-macro", @@ -3920,6 +3937,8 @@ dependencies = [ name = "tauri-app" version = "0.1.0" dependencies = [ + "bytes", + "futures", "reqwest", "serde", "serde_json", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 2210f88..3437c16 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -22,7 +22,9 @@ tauri = { version = "2", features = [] } tauri-plugin-opener = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", features = ["json", "stream"] } tokio = { version = "1", features = ["full"] } uuid = { version = "1", features = ["v4"] } +futures = "0.3" +bytes = "1" diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index f95d38a..b6bbc1b 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -2,6 +2,8 @@ use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; use uuid::Uuid; +use futures::StreamExt; +use tauri::Emitter; #[derive(Debug, Clone, Serialize, Deserialize)] struct ApiConfig { @@ -10,6 +12,8 @@ struct ApiConfig { model: String, #[serde(default)] active_character_id: Option, + #[serde(default)] + stream: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -66,6 +70,30 @@ struct ModelsResponse { data: Vec, } +#[derive(Debug, Serialize, Deserialize)] +struct StreamChatRequest { + model: String, + max_tokens: u32, + messages: Vec, + stream: bool, +} + +#[derive(Debug, Serialize, Deserialize)] +struct StreamChoice { + delta: Delta, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Delta { + #[serde(default)] + content: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct StreamResponse { + choices: Vec, +} + fn get_config_path() -> PathBuf { let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); PathBuf::from(home).join(".config/claudia/config.json") @@ -228,7 +256,7 @@ async fn validate_api(base_url: String, api_key: String) -> Result, } #[tauri::command] -async fn save_api_config(base_url: String, api_key: String, model: String) -> Result<(), String> { +async fn save_api_config(base_url: String, api_key: String, model: String, stream: bool) -> Result<(), String> { // Preserve existing active_character_id if it exists let active_character_id = load_config().and_then(|c| c.active_character_id); @@ -237,6 +265,7 @@ async fn save_api_config(base_url: String, api_key: String, model: String) -> Re api_key, model, active_character_id, + stream, }; save_config(&config) } @@ -316,6 +345,107 @@ async fn chat(message: String) -> Result { Ok(assistant_message) } +#[tauri::command] +async fn chat_stream(app_handle: tauri::AppHandle, message: String) -> Result { + let config = load_config().ok_or_else(|| "API not configured".to_string())?; + let character = get_active_character(); + let mut history = load_history(&character.id); + + // Add user message to history + history.messages.push(Message { + role: "user".to_string(), + content: message.clone(), + }); + + let client = reqwest::Client::new(); + let base = config.base_url.trim_end_matches('/'); + let url = if base.ends_with("/v1") { + format!("{}/chat/completions", base) + } else { + format!("{}/v1/chat/completions", base) + }; + + // Build messages with system prompt first + let mut api_messages = vec![Message { + role: "system".to_string(), + content: character.system_prompt.clone(), + }]; + api_messages.extend(history.messages.clone()); + + let request = StreamChatRequest { + model: config.model.clone(), + max_tokens: 4096, + messages: api_messages, + stream: true, + }; + + let response = client + .post(&url) + .header("authorization", format!("Bearer {}", &config.api_key)) + .header("content-type", "application/json") + .json(&request) + .send() + .await + .map_err(|e| format!("Request failed: {}", e))?; + + if !response.status().is_success() { + return Err(format!("API error: {}", response.status())); + } + + // Process streaming response + let mut full_content = String::new(); + let mut stream = response.bytes_stream(); + + let mut buffer = String::new(); + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(|e| format!("Stream error: {}", e))?; + let chunk_str = String::from_utf8_lossy(&chunk); + buffer.push_str(&chunk_str); + + // Process complete lines + while let Some(line_end) = buffer.find('\n') { + let line = buffer[..line_end].trim().to_string(); + buffer = buffer[line_end + 1..].to_string(); + + // Parse SSE data lines + if line.starts_with("data: ") { + let data = &line[6..]; + + // Check for stream end + if data == "[DONE]" { + break; + } + + // Parse JSON and extract content + if let Ok(stream_response) = serde_json::from_str::(data) { + if let Some(choice) = stream_response.choices.first() { + if let Some(content) = &choice.delta.content { + full_content.push_str(content); + + // Emit token to frontend + let _ = app_handle.emit_to("main", "chat-token", content.clone()); + } + } + } + } + } + } + + // Add assistant message to history + history.messages.push(Message { + role: "assistant".to_string(), + content: full_content.clone(), + }); + + // Save history + save_history(&character.id, &history).ok(); + + // Emit completion event + let _ = app_handle.emit_to("main", "chat-complete", ()); + + Ok(full_content) +} + #[tauri::command] fn get_chat_history() -> Result, String> { let character = get_active_character(); @@ -415,6 +545,7 @@ pub fn run() { .plugin(tauri_plugin_opener::init()) .invoke_handler(tauri::generate_handler![ chat, + chat_stream, validate_api, save_api_config, get_api_config, diff --git a/src/index.html b/src/index.html index d7b9cad..bcfe0e9 100644 --- a/src/index.html +++ b/src/index.html @@ -104,6 +104,13 @@ +
+ +
+