Merge remote-tracking branch 'origin/main' into fix/windows-key-permissions-warning
# Conflicts: # src/security/secrets.rs
This commit is contained in:
commit
6d68e89ef0
11 changed files with 390 additions and 103 deletions
|
|
@ -210,9 +210,7 @@ async fn get_max_rowid(db_path: &Path) -> anyhow::Result<i64> {
|
||||||
&path,
|
&path,
|
||||||
OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
|
OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
|
||||||
)?;
|
)?;
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare("SELECT MAX(ROWID) FROM message WHERE is_from_me = 0")?;
|
||||||
"SELECT MAX(ROWID) FROM message WHERE is_from_me = 0"
|
|
||||||
)?;
|
|
||||||
let rowid: Option<i64> = stmt.query_row([], |row| row.get(0))?;
|
let rowid: Option<i64> = stmt.query_row([], |row| row.get(0))?;
|
||||||
Ok(rowid.unwrap_or(0))
|
Ok(rowid.unwrap_or(0))
|
||||||
})
|
})
|
||||||
|
|
@ -228,31 +226,32 @@ async fn fetch_new_messages(
|
||||||
since_rowid: i64,
|
since_rowid: i64,
|
||||||
) -> anyhow::Result<Vec<(i64, String, String)>> {
|
) -> anyhow::Result<Vec<(i64, String, String)>> {
|
||||||
let path = db_path.to_path_buf();
|
let path = db_path.to_path_buf();
|
||||||
let results = tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<(i64, String, String)>> {
|
let results =
|
||||||
let conn = Connection::open_with_flags(
|
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<(i64, String, String)>> {
|
||||||
&path,
|
let conn = Connection::open_with_flags(
|
||||||
OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
|
&path,
|
||||||
)?;
|
OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
|
||||||
let mut stmt = conn.prepare(
|
)?;
|
||||||
"SELECT m.ROWID, h.id, m.text \
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT m.ROWID, h.id, m.text \
|
||||||
FROM message m \
|
FROM message m \
|
||||||
JOIN handle h ON m.handle_id = h.ROWID \
|
JOIN handle h ON m.handle_id = h.ROWID \
|
||||||
WHERE m.ROWID > ?1 \
|
WHERE m.ROWID > ?1 \
|
||||||
AND m.is_from_me = 0 \
|
AND m.is_from_me = 0 \
|
||||||
AND m.text IS NOT NULL \
|
AND m.text IS NOT NULL \
|
||||||
ORDER BY m.ROWID ASC \
|
ORDER BY m.ROWID ASC \
|
||||||
LIMIT 20"
|
LIMIT 20",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map([since_rowid], |row| {
|
let rows = stmt.query_map([since_rowid], |row| {
|
||||||
Ok((
|
Ok((
|
||||||
row.get::<_, i64>(0)?,
|
row.get::<_, i64>(0)?,
|
||||||
row.get::<_, String>(1)?,
|
row.get::<_, String>(1)?,
|
||||||
row.get::<_, String>(2)?,
|
row.get::<_, String>(2)?,
|
||||||
))
|
))
|
||||||
})?;
|
})?;
|
||||||
rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
|
rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
|
||||||
})
|
})
|
||||||
.await??;
|
.await??;
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -501,7 +500,7 @@ mod tests {
|
||||||
fn invalid_target_applescript_injection() {
|
fn invalid_target_applescript_injection() {
|
||||||
// Various injection attempts
|
// Various injection attempts
|
||||||
assert!(!is_valid_imessage_target(r#"test" & quit"#));
|
assert!(!is_valid_imessage_target(r#"test" & quit"#));
|
||||||
assert!(!is_valid_imessage_target(r#"test\ndo shell script"#));
|
assert!(!is_valid_imessage_target(r"test\ndo shell script"));
|
||||||
assert!(!is_valid_imessage_target("test\"; malicious code; \""));
|
assert!(!is_valid_imessage_target("test\"; malicious code; \""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -536,9 +535,9 @@ mod tests {
|
||||||
fn create_test_db() -> (tempfile::TempDir, std::path::PathBuf) {
|
fn create_test_db() -> (tempfile::TempDir, std::path::PathBuf) {
|
||||||
let dir = tempfile::tempdir().unwrap();
|
let dir = tempfile::tempdir().unwrap();
|
||||||
let db_path = dir.path().join("chat.db");
|
let db_path = dir.path().join("chat.db");
|
||||||
|
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
|
|
||||||
// Create minimal schema matching macOS Messages.app
|
// Create minimal schema matching macOS Messages.app
|
||||||
conn.execute_batch(
|
conn.execute_batch(
|
||||||
"CREATE TABLE handle (
|
"CREATE TABLE handle (
|
||||||
|
|
@ -551,9 +550,10 @@ mod tests {
|
||||||
text TEXT,
|
text TEXT,
|
||||||
is_from_me INTEGER DEFAULT 0,
|
is_from_me INTEGER DEFAULT 0,
|
||||||
FOREIGN KEY (handle_id) REFERENCES handle(ROWID)
|
FOREIGN KEY (handle_id) REFERENCES handle(ROWID)
|
||||||
);"
|
);",
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
(dir, db_path)
|
(dir, db_path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -569,11 +569,15 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn get_max_rowid_with_messages() {
|
async fn get_max_rowid_with_messages() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (100, 1, 'Hello', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (100, 1, 'Hello', 0)",
|
||||||
[]
|
[]
|
||||||
|
|
@ -588,7 +592,7 @@ mod tests {
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = get_max_rowid(&db_path).await.unwrap();
|
let result = get_max_rowid(&db_path).await.unwrap();
|
||||||
// Should return 200, not 300 (ignores is_from_me=1)
|
// Should return 200, not 300 (ignores is_from_me=1)
|
||||||
assert_eq!(result, 200);
|
assert_eq!(result, 200);
|
||||||
|
|
@ -612,12 +616,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_returns_correct_data() {
|
async fn fetch_new_messages_returns_correct_data() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (2, 'user@example.com')", []).unwrap();
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (2, 'user@example.com')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'First message', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'First message', 0)",
|
||||||
[]
|
[]
|
||||||
|
|
@ -627,21 +639,35 @@ mod tests {
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 2);
|
assert_eq!(result.len(), 2);
|
||||||
assert_eq!(result[0], (10, "+1234567890".to_string(), "First message".to_string()));
|
assert_eq!(
|
||||||
assert_eq!(result[1], (20, "user@example.com".to_string(), "Second message".to_string()));
|
result[0],
|
||||||
|
(10, "+1234567890".to_string(), "First message".to_string())
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result[1],
|
||||||
|
(
|
||||||
|
20,
|
||||||
|
"user@example.com".to_string(),
|
||||||
|
"Second message".to_string()
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_filters_by_rowid() {
|
async fn fetch_new_messages_filters_by_rowid() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Old message', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Old message', 0)",
|
||||||
[]
|
[]
|
||||||
|
|
@ -651,7 +677,7 @@ mod tests {
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch only messages after ROWID 15
|
// Fetch only messages after ROWID 15
|
||||||
let result = fetch_new_messages(&db_path, 15).await.unwrap();
|
let result = fetch_new_messages(&db_path, 15).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
|
|
@ -662,11 +688,15 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_excludes_sent_messages() {
|
async fn fetch_new_messages_excludes_sent_messages() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Received', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Received', 0)",
|
||||||
[]
|
[]
|
||||||
|
|
@ -676,7 +706,7 @@ mod tests {
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].2, "Received");
|
assert_eq!(result[0].2, "Received");
|
||||||
|
|
@ -685,21 +715,26 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_excludes_null_text() {
|
async fn fetch_new_messages_excludes_null_text() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Has text', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Has text', 0)",
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (20, 1, NULL, 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (20, 1, NULL, 0)",
|
||||||
[]
|
[],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].2, "Has text");
|
assert_eq!(result[0].2, "Has text");
|
||||||
|
|
@ -708,11 +743,15 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_respects_limit() {
|
async fn fetch_new_messages_respects_limit() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert 25 messages (limit is 20)
|
// Insert 25 messages (limit is 20)
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
for i in 1..=25 {
|
for i in 1..=25 {
|
||||||
conn.execute(
|
conn.execute(
|
||||||
&format!("INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES ({i}, 1, 'Message {i}', 0)"),
|
&format!("INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES ({i}, 1, 'Message {i}', 0)"),
|
||||||
|
|
@ -720,7 +759,7 @@ mod tests {
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 20); // Limited to 20
|
assert_eq!(result.len(), 20); // Limited to 20
|
||||||
assert_eq!(result[0].0, 1); // First message
|
assert_eq!(result[0].0, 1); // First message
|
||||||
|
|
@ -730,11 +769,15 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_ordered_by_rowid_asc() {
|
async fn fetch_new_messages_ordered_by_rowid_asc() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert messages out of order
|
// Insert messages out of order
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (30, 1, 'Third', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (30, 1, 'Third', 0)",
|
||||||
[]
|
[]
|
||||||
|
|
@ -748,7 +791,7 @@ mod tests {
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 3);
|
assert_eq!(result.len(), 3);
|
||||||
assert_eq!(result[0].0, 10);
|
assert_eq!(result[0].0, 10);
|
||||||
|
|
@ -766,17 +809,21 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_handles_special_characters() {
|
async fn fetch_new_messages_handles_special_characters() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
// Insert message with special characters (potential SQL injection patterns)
|
// Insert message with special characters (potential SQL injection patterns)
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Hello \"world'' OR 1=1; DROP TABLE message;--', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Hello \"world'' OR 1=1; DROP TABLE message;--', 0)",
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
// The special characters should be preserved, not interpreted as SQL
|
// The special characters should be preserved, not interpreted as SQL
|
||||||
|
|
@ -786,16 +833,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_handles_unicode() {
|
async fn fetch_new_messages_handles_unicode() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Hello 🦀 世界 مرحبا', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Hello 🦀 世界 مرحبا', 0)",
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
assert_eq!(result[0].2, "Hello 🦀 世界 مرحبا");
|
assert_eq!(result[0].2, "Hello 🦀 世界 مرحبا");
|
||||||
|
|
@ -804,16 +855,21 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_handles_empty_text() {
|
async fn fetch_new_messages_handles_empty_text() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, '', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, '', 0)",
|
||||||
[]
|
[],
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
let result = fetch_new_messages(&db_path, 0).await.unwrap();
|
||||||
// Empty string is NOT NULL, so it's included
|
// Empty string is NOT NULL, so it's included
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
|
|
@ -823,16 +879,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_negative_rowid_edge_case() {
|
async fn fetch_new_messages_negative_rowid_edge_case() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Test', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Test', 0)",
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Negative rowid should still work (fetch all messages with ROWID > -1)
|
// Negative rowid should still work (fetch all messages with ROWID > -1)
|
||||||
let result = fetch_new_messages(&db_path, -1).await.unwrap();
|
let result = fetch_new_messages(&db_path, -1).await.unwrap();
|
||||||
assert_eq!(result.len(), 1);
|
assert_eq!(result.len(), 1);
|
||||||
|
|
@ -841,16 +901,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fetch_new_messages_large_rowid_edge_case() {
|
async fn fetch_new_messages_large_rowid_edge_case() {
|
||||||
let (_dir, db_path) = create_test_db();
|
let (_dir, db_path) = create_test_db();
|
||||||
|
|
||||||
{
|
{
|
||||||
let conn = Connection::open(&db_path).unwrap();
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
conn.execute("INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')", []).unwrap();
|
conn.execute(
|
||||||
|
"INSERT INTO handle (ROWID, id) VALUES (1, '+1234567890')",
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Test', 0)",
|
"INSERT INTO message (ROWID, handle_id, text, is_from_me) VALUES (10, 1, 'Test', 0)",
|
||||||
[]
|
[]
|
||||||
).unwrap();
|
).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Very large rowid should return empty (no messages after this)
|
// Very large rowid should return empty (no messages after this)
|
||||||
let result = fetch_new_messages(&db_path, i64::MAX - 1).await.unwrap();
|
let result = fetch_new_messages(&db_path, i64::MAX - 1).await.unwrap();
|
||||||
assert!(result.is_empty());
|
assert!(result.is_empty());
|
||||||
|
|
|
||||||
|
|
@ -123,10 +123,12 @@ pub fn build_system_prompt(
|
||||||
" <description>{}</description>",
|
" <description>{}</description>",
|
||||||
skill.description
|
skill.description
|
||||||
);
|
);
|
||||||
let location = workspace_dir
|
let location = skill.location.clone().unwrap_or_else(|| {
|
||||||
.join("skills")
|
workspace_dir
|
||||||
.join(&skill.name)
|
.join("skills")
|
||||||
.join("SKILL.md");
|
.join(&skill.name)
|
||||||
|
.join("SKILL.md")
|
||||||
|
});
|
||||||
let _ = writeln!(prompt, " <location>{}</location>", location.display());
|
let _ = writeln!(prompt, " <location>{}</location>", location.display());
|
||||||
let _ = writeln!(prompt, " </skill>");
|
let _ = writeln!(prompt, " </skill>");
|
||||||
}
|
}
|
||||||
|
|
@ -825,6 +827,7 @@ mod tests {
|
||||||
tags: vec![],
|
tags: vec![],
|
||||||
tools: vec![],
|
tools: vec![],
|
||||||
prompts: vec!["Long prompt content that should NOT appear in system prompt".into()],
|
prompts: vec!["Long prompt content that should NOT appear in system prompt".into()],
|
||||||
|
location: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let prompt = build_system_prompt(ws.path(), "model", &[], &skills);
|
let prompt = build_system_prompt(ws.path(), "model", &[], &skills);
|
||||||
|
|
@ -937,8 +940,8 @@ mod tests {
|
||||||
calls: Arc::clone(&calls),
|
calls: Arc::clone(&calls),
|
||||||
});
|
});
|
||||||
|
|
||||||
let (_tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(1);
|
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(1);
|
||||||
let handle = spawn_supervised_listener(channel, _tx, 1, 1);
|
let handle = spawn_supervised_listener(channel, tx, 1, 1);
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_millis(80)).await;
|
tokio::time::sleep(Duration::from_millis(80)).await;
|
||||||
drop(rx);
|
drop(rx);
|
||||||
|
|
|
||||||
|
|
@ -294,7 +294,7 @@ mod tests {
|
||||||
assert_eq!(msgs[0].sender, "+1234567890");
|
assert_eq!(msgs[0].sender, "+1234567890");
|
||||||
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
|
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
|
||||||
assert_eq!(msgs[0].channel, "whatsapp");
|
assert_eq!(msgs[0].channel, "whatsapp");
|
||||||
assert_eq!(msgs[0].timestamp, 1699999999);
|
assert_eq!(msgs[0].timestamp, 1_699_999_999);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -281,9 +281,11 @@ mod tests {
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
fn test_config(tmp: &TempDir) -> Config {
|
fn test_config(tmp: &TempDir) -> Config {
|
||||||
let mut config = Config::default();
|
let config = Config {
|
||||||
config.workspace_dir = tmp.path().join("workspace");
|
workspace_dir: tmp.path().join("workspace"),
|
||||||
config.config_path = tmp.path().join("config.toml");
|
config_path: tmp.path().join("config.toml"),
|
||||||
|
..Config::default()
|
||||||
|
};
|
||||||
std::fs::create_dir_all(&config.workspace_dir).unwrap();
|
std::fs::create_dir_all(&config.workspace_dir).unwrap();
|
||||||
config
|
config
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -186,9 +186,11 @@ mod tests {
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
fn test_config(tmp: &TempDir) -> Config {
|
fn test_config(tmp: &TempDir) -> Config {
|
||||||
let mut config = Config::default();
|
let config = Config {
|
||||||
config.workspace_dir = tmp.path().join("workspace");
|
workspace_dir: tmp.path().join("workspace"),
|
||||||
config.config_path = tmp.path().join("config.toml");
|
config_path: tmp.path().join("config.toml"),
|
||||||
|
..Config::default()
|
||||||
|
};
|
||||||
std::fs::create_dir_all(&config.workspace_dir).unwrap();
|
std::fs::create_dir_all(&config.workspace_dir).unwrap();
|
||||||
config
|
config
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -215,9 +215,11 @@ mod tests {
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
fn test_config(tmp: &TempDir) -> Config {
|
fn test_config(tmp: &TempDir) -> Config {
|
||||||
let mut config = Config::default();
|
let config = Config {
|
||||||
config.workspace_dir = tmp.path().join("workspace");
|
workspace_dir: tmp.path().join("workspace"),
|
||||||
config.config_path = tmp.path().join("config.toml");
|
config_path: tmp.path().join("config.toml"),
|
||||||
|
..Config::default()
|
||||||
|
};
|
||||||
std::fs::create_dir_all(&config.workspace_dir).unwrap();
|
std::fs::create_dir_all(&config.workspace_dir).unwrap();
|
||||||
config
|
config
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -272,7 +272,7 @@ fn build_windows_icacls_grant_arg(username: &str) -> Option<String> {
|
||||||
/// Hex-decode a hex string to bytes.
|
/// Hex-decode a hex string to bytes.
|
||||||
#[allow(clippy::manual_is_multiple_of)]
|
#[allow(clippy::manual_is_multiple_of)]
|
||||||
fn hex_decode(hex: &str) -> Result<Vec<u8>> {
|
fn hex_decode(hex: &str) -> Result<Vec<u8>> {
|
||||||
if hex.len() % 2 != 0 {
|
if (hex.len() & 1) != 0 {
|
||||||
anyhow::bail!("Hex string has odd length");
|
anyhow::bail!("Hex string has odd length");
|
||||||
}
|
}
|
||||||
(0..hex.len())
|
(0..hex.len())
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,14 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use directories::UserDirs;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::process::Command;
|
||||||
|
use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
|
const OPEN_SKILLS_REPO_URL: &str = "https://github.com/besoeasy/open-skills";
|
||||||
|
const OPEN_SKILLS_SYNC_MARKER: &str = ".zeroclaw-open-skills-sync";
|
||||||
|
const OPEN_SKILLS_SYNC_INTERVAL_SECS: u64 = 60 * 60 * 24 * 7;
|
||||||
|
|
||||||
/// A skill is a user-defined or community-built capability.
|
/// A skill is a user-defined or community-built capability.
|
||||||
/// Skills live in `~/.zeroclaw/workspace/skills/<name>/SKILL.md`
|
/// Skills live in `~/.zeroclaw/workspace/skills/<name>/SKILL.md`
|
||||||
|
|
@ -19,6 +26,8 @@ pub struct Skill {
|
||||||
pub tools: Vec<SkillTool>,
|
pub tools: Vec<SkillTool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub prompts: Vec<String>,
|
pub prompts: Vec<String>,
|
||||||
|
#[serde(skip)]
|
||||||
|
pub location: Option<PathBuf>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A tool defined by a skill (shell command, HTTP call, etc.)
|
/// A tool defined by a skill (shell command, HTTP call, etc.)
|
||||||
|
|
@ -62,14 +71,29 @@ fn default_version() -> String {
|
||||||
|
|
||||||
/// Load all skills from the workspace skills directory
|
/// Load all skills from the workspace skills directory
|
||||||
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
|
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
|
||||||
|
let mut skills = Vec::new();
|
||||||
|
|
||||||
|
if let Some(open_skills_dir) = ensure_open_skills_repo() {
|
||||||
|
skills.extend(load_open_skills(&open_skills_dir));
|
||||||
|
}
|
||||||
|
|
||||||
|
skills.extend(load_workspace_skills(workspace_dir));
|
||||||
|
skills
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_workspace_skills(workspace_dir: &Path) -> Vec<Skill> {
|
||||||
let skills_dir = workspace_dir.join("skills");
|
let skills_dir = workspace_dir.join("skills");
|
||||||
|
load_skills_from_directory(&skills_dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_skills_from_directory(skills_dir: &Path) -> Vec<Skill> {
|
||||||
if !skills_dir.exists() {
|
if !skills_dir.exists() {
|
||||||
return Vec::new();
|
return Vec::new();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut skills = Vec::new();
|
let mut skills = Vec::new();
|
||||||
|
|
||||||
let Ok(entries) = std::fs::read_dir(&skills_dir) else {
|
let Ok(entries) = std::fs::read_dir(skills_dir) else {
|
||||||
return skills;
|
return skills;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -97,6 +121,172 @@ pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
|
||||||
skills
|
skills
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_open_skills(repo_dir: &Path) -> Vec<Skill> {
|
||||||
|
let mut skills = Vec::new();
|
||||||
|
|
||||||
|
let Ok(entries) = std::fs::read_dir(repo_dir) else {
|
||||||
|
return skills;
|
||||||
|
};
|
||||||
|
|
||||||
|
for entry in entries.flatten() {
|
||||||
|
let path = entry.path();
|
||||||
|
if !path.is_file() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_markdown = path
|
||||||
|
.extension()
|
||||||
|
.and_then(|ext| ext.to_str())
|
||||||
|
.is_some_and(|ext| ext.eq_ignore_ascii_case("md"));
|
||||||
|
if !is_markdown {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_readme = path
|
||||||
|
.file_name()
|
||||||
|
.and_then(|name| name.to_str())
|
||||||
|
.is_some_and(|name| name.eq_ignore_ascii_case("README.md"));
|
||||||
|
if is_readme {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(skill) = load_open_skill_md(&path) {
|
||||||
|
skills.push(skill);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
skills
|
||||||
|
}
|
||||||
|
|
||||||
|
fn open_skills_enabled() -> bool {
|
||||||
|
if let Ok(raw) = std::env::var("ZEROCLAW_OPEN_SKILLS_ENABLED") {
|
||||||
|
let value = raw.trim().to_ascii_lowercase();
|
||||||
|
return !matches!(value.as_str(), "0" | "false" | "off" | "no");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep tests deterministic and network-free by default.
|
||||||
|
!cfg!(test)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_open_skills_dir() -> Option<PathBuf> {
|
||||||
|
if let Ok(path) = std::env::var("ZEROCLAW_OPEN_SKILLS_DIR") {
|
||||||
|
let trimmed = path.trim();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
return Some(PathBuf::from(trimmed));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
UserDirs::new().map(|dirs| dirs.home_dir().join("open-skills"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ensure_open_skills_repo() -> Option<PathBuf> {
|
||||||
|
if !open_skills_enabled() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let repo_dir = resolve_open_skills_dir()?;
|
||||||
|
|
||||||
|
if !repo_dir.exists() {
|
||||||
|
if !clone_open_skills_repo(&repo_dir) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let _ = mark_open_skills_synced(&repo_dir);
|
||||||
|
return Some(repo_dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
if should_sync_open_skills(&repo_dir) {
|
||||||
|
if pull_open_skills_repo(&repo_dir) {
|
||||||
|
let _ = mark_open_skills_synced(&repo_dir);
|
||||||
|
} else {
|
||||||
|
tracing::warn!(
|
||||||
|
"open-skills update failed; using local copy from {}",
|
||||||
|
repo_dir.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(repo_dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clone_open_skills_repo(repo_dir: &Path) -> bool {
|
||||||
|
if let Some(parent) = repo_dir.parent() {
|
||||||
|
if let Err(err) = std::fs::create_dir_all(parent) {
|
||||||
|
tracing::warn!(
|
||||||
|
"failed to create open-skills parent directory {}: {err}",
|
||||||
|
parent.display()
|
||||||
|
);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = Command::new("git")
|
||||||
|
.args(["clone", "--depth", "1", OPEN_SKILLS_REPO_URL])
|
||||||
|
.arg(repo_dir)
|
||||||
|
.output();
|
||||||
|
|
||||||
|
match output {
|
||||||
|
Ok(result) if result.status.success() => {
|
||||||
|
tracing::info!("initialized open-skills at {}", repo_dir.display());
|
||||||
|
true
|
||||||
|
}
|
||||||
|
Ok(result) => {
|
||||||
|
let stderr = String::from_utf8_lossy(&result.stderr);
|
||||||
|
tracing::warn!("failed to clone open-skills: {stderr}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
tracing::warn!("failed to run git clone for open-skills: {err}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pull_open_skills_repo(repo_dir: &Path) -> bool {
|
||||||
|
// If user points to a non-git directory via env var, keep using it without pulling.
|
||||||
|
if !repo_dir.join(".git").exists() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = Command::new("git")
|
||||||
|
.arg("-C")
|
||||||
|
.arg(repo_dir)
|
||||||
|
.args(["pull", "--ff-only"])
|
||||||
|
.output();
|
||||||
|
|
||||||
|
match output {
|
||||||
|
Ok(result) if result.status.success() => true,
|
||||||
|
Ok(result) => {
|
||||||
|
let stderr = String::from_utf8_lossy(&result.stderr);
|
||||||
|
tracing::warn!("failed to pull open-skills updates: {stderr}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
tracing::warn!("failed to run git pull for open-skills: {err}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_sync_open_skills(repo_dir: &Path) -> bool {
|
||||||
|
let marker = repo_dir.join(OPEN_SKILLS_SYNC_MARKER);
|
||||||
|
let Ok(metadata) = std::fs::metadata(marker) else {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
let Ok(modified_at) = metadata.modified() else {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
let Ok(age) = SystemTime::now().duration_since(modified_at) else {
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
age >= Duration::from_secs(OPEN_SKILLS_SYNC_INTERVAL_SECS)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mark_open_skills_synced(repo_dir: &Path) -> Result<()> {
|
||||||
|
std::fs::write(repo_dir.join(OPEN_SKILLS_SYNC_MARKER), b"synced")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Load a skill from a SKILL.toml manifest
|
/// Load a skill from a SKILL.toml manifest
|
||||||
fn load_skill_toml(path: &Path) -> Result<Skill> {
|
fn load_skill_toml(path: &Path) -> Result<Skill> {
|
||||||
let content = std::fs::read_to_string(path)?;
|
let content = std::fs::read_to_string(path)?;
|
||||||
|
|
@ -110,6 +300,7 @@ fn load_skill_toml(path: &Path) -> Result<Skill> {
|
||||||
tags: manifest.skill.tags,
|
tags: manifest.skill.tags,
|
||||||
tools: manifest.tools,
|
tools: manifest.tools,
|
||||||
prompts: manifest.prompts,
|
prompts: manifest.prompts,
|
||||||
|
location: Some(path.to_path_buf()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -122,25 +313,47 @@ fn load_skill_md(path: &Path, dir: &Path) -> Result<Skill> {
|
||||||
.unwrap_or("unknown")
|
.unwrap_or("unknown")
|
||||||
.to_string();
|
.to_string();
|
||||||
|
|
||||||
// Extract description from first non-heading line
|
|
||||||
let description = content
|
|
||||||
.lines()
|
|
||||||
.find(|l| !l.starts_with('#') && !l.trim().is_empty())
|
|
||||||
.unwrap_or("No description")
|
|
||||||
.trim()
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
Ok(Skill {
|
Ok(Skill {
|
||||||
name,
|
name,
|
||||||
description,
|
description: extract_description(&content),
|
||||||
version: "0.1.0".to_string(),
|
version: "0.1.0".to_string(),
|
||||||
author: None,
|
author: None,
|
||||||
tags: Vec::new(),
|
tags: Vec::new(),
|
||||||
tools: Vec::new(),
|
tools: Vec::new(),
|
||||||
prompts: vec![content],
|
prompts: vec![content],
|
||||||
|
location: Some(path.to_path_buf()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn load_open_skill_md(path: &Path) -> Result<Skill> {
|
||||||
|
let content = std::fs::read_to_string(path)?;
|
||||||
|
let name = path
|
||||||
|
.file_stem()
|
||||||
|
.and_then(|n| n.to_str())
|
||||||
|
.unwrap_or("open-skill")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
Ok(Skill {
|
||||||
|
name,
|
||||||
|
description: extract_description(&content),
|
||||||
|
version: "open-skills".to_string(),
|
||||||
|
author: Some("besoeasy/open-skills".to_string()),
|
||||||
|
tags: vec!["open-skills".to_string()],
|
||||||
|
tools: Vec::new(),
|
||||||
|
prompts: vec![content],
|
||||||
|
location: Some(path.to_path_buf()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_description(content: &str) -> String {
|
||||||
|
content
|
||||||
|
.lines()
|
||||||
|
.find(|line| !line.starts_with('#') && !line.trim().is_empty())
|
||||||
|
.unwrap_or("No description")
|
||||||
|
.trim()
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
/// Build a system prompt addition from all loaded skills
|
/// Build a system prompt addition from all loaded skills
|
||||||
pub fn skills_to_prompt(skills: &[Skill]) -> String {
|
pub fn skills_to_prompt(skills: &[Skill]) -> String {
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
@ -468,6 +681,7 @@ command = "echo hello"
|
||||||
tags: vec![],
|
tags: vec![],
|
||||||
tools: vec![],
|
tools: vec![],
|
||||||
prompts: vec!["Do the thing.".to_string()],
|
prompts: vec!["Do the thing.".to_string()],
|
||||||
|
location: None,
|
||||||
}];
|
}];
|
||||||
let prompt = skills_to_prompt(&skills);
|
let prompt = skills_to_prompt(&skills);
|
||||||
assert!(prompt.contains("test"));
|
assert!(prompt.contains("test"));
|
||||||
|
|
@ -657,6 +871,7 @@ description = "Bare minimum"
|
||||||
args: HashMap::new(),
|
args: HashMap::new(),
|
||||||
}],
|
}],
|
||||||
prompts: vec![],
|
prompts: vec![],
|
||||||
|
location: None,
|
||||||
}];
|
}];
|
||||||
let prompt = skills_to_prompt(&skills);
|
let prompt = skills_to_prompt(&skills);
|
||||||
assert!(prompt.contains("weather"));
|
assert!(prompt.contains("weather"));
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod symlink_tests {
|
mod tests {
|
||||||
use crate::skills::skills_dir;
|
use crate::skills::skills_dir;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
|
||||||
|
|
@ -365,8 +365,8 @@ impl BrowserTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
|
#[async_trait]
|
||||||
impl Tool for BrowserTool {
|
impl Tool for BrowserTool {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
"browser"
|
"browser"
|
||||||
|
|
@ -750,7 +750,7 @@ mod tests {
|
||||||
let domains = vec![
|
let domains = vec![
|
||||||
" Example.COM ".into(),
|
" Example.COM ".into(),
|
||||||
"docs.example.com".into(),
|
"docs.example.com".into(),
|
||||||
"".into(),
|
String::new(),
|
||||||
];
|
];
|
||||||
let normalized = normalize_domains(domains);
|
let normalized = normalize_domains(domains);
|
||||||
assert_eq!(normalized, vec!["example.com", "docs.example.com"]);
|
assert_eq!(normalized, vec!["example.com", "docs.example.com"]);
|
||||||
|
|
|
||||||
|
|
@ -84,9 +84,8 @@ fn pattern_matches(pattern: &str, path: &str) -> bool {
|
||||||
fn is_excluded(patterns: &[String], path: &str) -> bool {
|
fn is_excluded(patterns: &[String], path: &str) -> bool {
|
||||||
let mut excluded = false;
|
let mut excluded = false;
|
||||||
for pattern in patterns {
|
for pattern in patterns {
|
||||||
if pattern.starts_with('!') {
|
if let Some(negated) = pattern.strip_prefix('!') {
|
||||||
// Negation pattern - re-include
|
// Negation pattern - re-include
|
||||||
let negated = &pattern[1..];
|
|
||||||
if pattern_matches(negated, path) {
|
if pattern_matches(negated, path) {
|
||||||
excluded = false;
|
excluded = false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue