feat(discord): add mention_only config for @-mention trigger (#529)
When mention_only is true, the bot only responds to messages that @-mention the bot. Other messages in the guild are silently ignored. Also strips the bot mention from content before processing. Co-authored-by: Will Sarg <12886992+willsarg@users.noreply.github.com>
This commit is contained in:
parent
a2986db3d6
commit
5b5d9fe77f
6 changed files with 56 additions and 14 deletions
|
|
@ -11,6 +11,7 @@ pub struct DiscordChannel {
|
||||||
guild_id: Option<String>,
|
guild_id: Option<String>,
|
||||||
allowed_users: Vec<String>,
|
allowed_users: Vec<String>,
|
||||||
listen_to_bots: bool,
|
listen_to_bots: bool,
|
||||||
|
mention_only: bool,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
|
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||||
}
|
}
|
||||||
|
|
@ -21,12 +22,14 @@ impl DiscordChannel {
|
||||||
guild_id: Option<String>,
|
guild_id: Option<String>,
|
||||||
allowed_users: Vec<String>,
|
allowed_users: Vec<String>,
|
||||||
listen_to_bots: bool,
|
listen_to_bots: bool,
|
||||||
|
mention_only: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
bot_token,
|
bot_token,
|
||||||
guild_id,
|
guild_id,
|
||||||
allowed_users,
|
allowed_users,
|
||||||
listen_to_bots,
|
listen_to_bots,
|
||||||
|
mention_only,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
typing_handle: std::sync::Mutex::new(None),
|
typing_handle: std::sync::Mutex::new(None),
|
||||||
}
|
}
|
||||||
|
|
@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip messages that don't @-mention the bot (when mention_only is enabled)
|
||||||
|
if self.mention_only {
|
||||||
|
let mention_tag = format!("<@{bot_user_id}>");
|
||||||
|
if !content.contains(&mention_tag) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the bot mention from content so the agent sees clean text
|
||||||
|
let clean_content = if self.mention_only {
|
||||||
|
let mention_tag = format!("<@{bot_user_id}>");
|
||||||
|
content.replace(&mention_tag, "").trim().to_string()
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||||
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||||
|
|
||||||
|
|
@ -354,7 +373,7 @@ impl Channel for DiscordChannel {
|
||||||
},
|
},
|
||||||
sender: author_id.to_string(),
|
sender: author_id.to_string(),
|
||||||
reply_to: channel_id.clone(),
|
reply_to: channel_id.clone(),
|
||||||
content: content.to_string(),
|
content: clean_content,
|
||||||
channel: "discord".to_string(),
|
channel: "discord".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
|
@ -424,7 +443,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn discord_channel_name() {
|
fn discord_channel_name() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
assert_eq!(ch.name(), "discord");
|
assert_eq!(ch.name(), "discord");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -445,21 +464,27 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_allowlist_denies_everyone() {
|
fn empty_allowlist_denies_everyone() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
assert!(!ch.is_user_allowed("12345"));
|
assert!(!ch.is_user_allowed("12345"));
|
||||||
assert!(!ch.is_user_allowed("anyone"));
|
assert!(!ch.is_user_allowed("anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn wildcard_allows_everyone() {
|
fn wildcard_allows_everyone() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
|
||||||
assert!(ch.is_user_allowed("12345"));
|
assert!(ch.is_user_allowed("12345"));
|
||||||
assert!(ch.is_user_allowed("anyone"));
|
assert!(ch.is_user_allowed("anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn specific_allowlist_filters() {
|
fn specific_allowlist_filters() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false);
|
let ch = DiscordChannel::new(
|
||||||
|
"fake".into(),
|
||||||
|
None,
|
||||||
|
vec!["111".into(), "222".into()],
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
);
|
||||||
assert!(ch.is_user_allowed("111"));
|
assert!(ch.is_user_allowed("111"));
|
||||||
assert!(ch.is_user_allowed("222"));
|
assert!(ch.is_user_allowed("222"));
|
||||||
assert!(!ch.is_user_allowed("333"));
|
assert!(!ch.is_user_allowed("333"));
|
||||||
|
|
@ -468,7 +493,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_is_exact_match_not_substring() {
|
fn allowlist_is_exact_match_not_substring() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||||
assert!(!ch.is_user_allowed("1111"));
|
assert!(!ch.is_user_allowed("1111"));
|
||||||
assert!(!ch.is_user_allowed("11"));
|
assert!(!ch.is_user_allowed("11"));
|
||||||
assert!(!ch.is_user_allowed("0111"));
|
assert!(!ch.is_user_allowed("0111"));
|
||||||
|
|
@ -476,20 +501,26 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_empty_string_user_id() {
|
fn allowlist_empty_string_user_id() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||||
assert!(!ch.is_user_allowed(""));
|
assert!(!ch.is_user_allowed(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_with_wildcard_and_specific() {
|
fn allowlist_with_wildcard_and_specific() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false);
|
let ch = DiscordChannel::new(
|
||||||
|
"fake".into(),
|
||||||
|
None,
|
||||||
|
vec!["111".into(), "*".into()],
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
);
|
||||||
assert!(ch.is_user_allowed("111"));
|
assert!(ch.is_user_allowed("111"));
|
||||||
assert!(ch.is_user_allowed("anyone_else"));
|
assert!(ch.is_user_allowed("anyone_else"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_case_sensitive() {
|
fn allowlist_case_sensitive() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
|
||||||
assert!(ch.is_user_allowed("ABC"));
|
assert!(ch.is_user_allowed("ABC"));
|
||||||
assert!(!ch.is_user_allowed("abc"));
|
assert!(!ch.is_user_allowed("abc"));
|
||||||
assert!(!ch.is_user_allowed("Abc"));
|
assert!(!ch.is_user_allowed("Abc"));
|
||||||
|
|
@ -664,14 +695,14 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn typing_handle_starts_as_none() {
|
fn typing_handle_starts_as_none() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
assert!(guard.is_none());
|
assert!(guard.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn start_typing_sets_handle() {
|
async fn start_typing_sets_handle() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let _ = ch.start_typing("123456").await;
|
let _ = ch.start_typing("123456").await;
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
assert!(guard.is_some());
|
assert!(guard.is_some());
|
||||||
|
|
@ -679,7 +710,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stop_typing_clears_handle() {
|
async fn stop_typing_clears_handle() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let _ = ch.start_typing("123456").await;
|
let _ = ch.start_typing("123456").await;
|
||||||
let _ = ch.stop_typing("123456").await;
|
let _ = ch.stop_typing("123456").await;
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
|
|
@ -688,14 +719,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stop_typing_is_idempotent() {
|
async fn stop_typing_is_idempotent() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
assert!(ch.stop_typing("123456").await.is_ok());
|
assert!(ch.stop_typing("123456").await.is_ok());
|
||||||
assert!(ch.stop_typing("123456").await.is_ok());
|
assert!(ch.stop_typing("123456").await.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn start_typing_replaces_existing_task() {
|
async fn start_typing_replaces_existing_task() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let _ = ch.start_typing("111").await;
|
let _ = ch.start_typing("111").await;
|
||||||
let _ = ch.start_typing("222").await;
|
let _ = ch.start_typing("222").await;
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
|
|
|
||||||
|
|
@ -620,6 +620,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
dc.allowed_users.clone(),
|
dc.allowed_users.clone(),
|
||||||
dc.listen_to_bots,
|
dc.listen_to_bots,
|
||||||
|
dc.mention_only,
|
||||||
)),
|
)),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
@ -906,6 +907,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
dc.allowed_users.clone(),
|
dc.allowed_users.clone(),
|
||||||
dc.listen_to_bots,
|
dc.listen_to_bots,
|
||||||
|
dc.mention_only,
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,6 +37,7 @@ mod tests {
|
||||||
guild_id: Some("123".into()),
|
guild_id: Some("123".into()),
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
};
|
};
|
||||||
|
|
||||||
let lark = LarkConfig {
|
let lark = LarkConfig {
|
||||||
|
|
|
||||||
|
|
@ -1319,6 +1319,10 @@ pub struct DiscordConfig {
|
||||||
/// The bot still ignores its own messages to prevent feedback loops.
|
/// The bot still ignores its own messages to prevent feedback loops.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub listen_to_bots: bool,
|
pub listen_to_bots: bool,
|
||||||
|
/// When true, only respond to messages that @-mention the bot.
|
||||||
|
/// Other messages in the guild are silently ignored.
|
||||||
|
#[serde(default)]
|
||||||
|
pub mention_only: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -2392,6 +2396,7 @@ tool_dispatcher = "xml"
|
||||||
guild_id: Some("12345".into()),
|
guild_id: Some("12345".into()),
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&dc).unwrap();
|
let json = serde_json::to_string(&dc).unwrap();
|
||||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
@ -2406,6 +2411,7 @@ tool_dispatcher = "xml"
|
||||||
guild_id: None,
|
guild_id: None,
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&dc).unwrap();
|
let json = serde_json::to_string(&dc).unwrap();
|
||||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
|
||||||
|
|
@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
dc.allowed_users.clone(),
|
dc.allowed_users.clone(),
|
||||||
dc.listen_to_bots,
|
dc.listen_to_bots,
|
||||||
|
dc.mention_only,
|
||||||
);
|
);
|
||||||
channel.send(output, target).await?;
|
channel.send(output, target).await?;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2586,6 +2586,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
||||||
allowed_users,
|
allowed_users,
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
2 => {
|
2 => {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue