1use std::path::PathBuf;
11
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14use tokio::net::TcpListener;
15use tracing::{info, warn};
16
17#[derive(Debug, Error)]
19pub enum ListenError {
20 #[error("TCP bind failed on {addr}: {source}")]
22 TcpBind {
23 addr: String,
24 source: std::io::Error,
25 },
26
27 #[error("Unix socket bind failed on {path}: {source}")]
29 UnixBind {
30 path: String,
31 source: std::io::Error,
32 },
33
34 #[error("systemd socket activation error: {0}")]
36 SystemdActivation(String),
37
38 #[error("I/O error: {0}")]
40 Io(#[from] std::io::Error),
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45#[serde(tag = "type", rename_all = "snake_case")]
46pub enum ListenConfig {
47 Tcp {
49 address: String,
51 tls: bool,
53 },
54 Unix {
56 path: PathBuf,
58 },
59 Systemd,
61}
62
63pub enum Listener {
65 Tcp(TcpListener),
67 #[cfg(unix)]
69 Unix(tokio::net::UnixListener),
70}
71
72impl Listener {
73 pub async fn bind(config: &ListenConfig) -> Result<Self, ListenError> {
75 match config {
76 ListenConfig::Tcp { address, tls } => {
77 let listener =
78 TcpListener::bind(address)
79 .await
80 .map_err(|e| ListenError::TcpBind {
81 addr: address.clone(),
82 source: e,
83 })?;
84 info!(
85 address = %address,
86 tls = %tls,
87 "TCP listener bound"
88 );
89 Ok(Self::Tcp(listener))
90 }
91 ListenConfig::Unix { path } => {
92 #[cfg(unix)]
93 {
94 if path.exists() {
96 warn!(path = %path.display(), "removing stale Unix socket");
97 std::fs::remove_file(path).ok();
98 }
99 let listener = tokio::net::UnixListener::bind(path).map_err(|e| {
100 ListenError::UnixBind {
101 path: path.display().to_string(),
102 source: e,
103 }
104 })?;
105 info!(path = %path.display(), "Unix socket listener bound");
106 Ok(Self::Unix(listener))
107 }
108 #[cfg(not(unix))]
109 {
110 let _ = path;
111 Err(ListenError::Io(std::io::Error::new(
112 std::io::ErrorKind::Unsupported,
113 "Unix sockets not supported on this platform",
114 )))
115 }
116 }
117 ListenConfig::Systemd => {
118 let listener = activate_systemd_socket()?;
119 info!("systemd socket activation listener acquired");
120 Ok(Self::Tcp(listener))
121 }
122 }
123 }
124}
125
126fn activate_systemd_socket() -> Result<TcpListener, ListenError> {
131 let listen_pid: u32 = std::env::var("LISTEN_PID")
132 .map_err(|_| {
133 ListenError::SystemdActivation("LISTEN_PID not set; not running under systemd".into())
134 })?
135 .parse()
136 .map_err(|e| ListenError::SystemdActivation(format!("invalid LISTEN_PID: {e}")))?;
137
138 let current_pid = std::process::id();
139 if listen_pid != current_pid {
140 return Err(ListenError::SystemdActivation(format!(
141 "LISTEN_PID {listen_pid} does not match current PID {current_pid}"
142 )));
143 }
144
145 let listen_fds: u32 = std::env::var("LISTEN_FDS")
146 .map_err(|_| ListenError::SystemdActivation("LISTEN_FDS not set".into()))?
147 .parse()
148 .map_err(|e| ListenError::SystemdActivation(format!("invalid LISTEN_FDS: {e}")))?;
149
150 if listen_fds == 0 {
151 return Err(ListenError::SystemdActivation(
152 "LISTEN_FDS is 0; no sockets passed".into(),
153 ));
154 }
155
156 const SD_LISTEN_FDS_START: i32 = 3;
158
159 #[cfg(unix)]
160 {
161 use std::os::unix::io::FromRawFd;
162 let std_listener = unsafe { std::net::TcpListener::from_raw_fd(SD_LISTEN_FDS_START) };
163 std_listener.set_nonblocking(true).map_err(|e| {
164 ListenError::SystemdActivation(format!("failed to set nonblocking: {e}"))
165 })?;
166 TcpListener::from_std(std_listener)
167 .map_err(|e| ListenError::SystemdActivation(format!("tokio wrap failed: {e}")))
168 }
169
170 #[cfg(not(unix))]
171 {
172 let _ = SD_LISTEN_FDS_START;
173 Err(ListenError::SystemdActivation(
174 "systemd socket activation not supported on this platform".into(),
175 ))
176 }
177}
178
179pub async fn shutdown_signal() {
184 #[cfg(unix)]
185 {
186 use tokio::signal::unix::{SignalKind, signal};
187 let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM");
188 let mut sigint = signal(SignalKind::interrupt()).expect("failed to register SIGINT");
189
190 tokio::select! {
191 _ = sigterm.recv() => info!("received SIGTERM, initiating graceful shutdown"),
192 _ = sigint.recv() => info!("received SIGINT, initiating graceful shutdown"),
193 }
194 }
195
196 #[cfg(not(unix))]
197 {
198 tokio::signal::ctrl_c()
199 .await
200 .expect("failed to register Ctrl+C handler");
201 info!("received Ctrl+C, initiating graceful shutdown");
202 }
203}