i3status_rs/blocks/calendar/
auth.rs

1use base64::Engine as _;
2use oauth2::basic::{
3    BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
4    BasicTokenResponse,
5};
6use oauth2::{
7    AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EndpointNotSet,
8    EndpointSet, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope,
9    StandardRevocableToken, TokenResponse as _, TokenUrl,
10};
11use reqwest;
12use reqwest::Url;
13use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
14use std::path::{Path, PathBuf};
15use std::sync::LazyLock;
16use tokio::fs::File;
17use tokio::io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncWriteExt as _, BufReader};
18use tokio::net::TcpListener;
19
20use super::CalendarError;
21use crate::{APP_USER_AGENT, REQWEST_TIMEOUT};
22
23static REQWEST_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
24    reqwest::Client::builder()
25        .user_agent(APP_USER_AGENT)
26        .timeout(REQWEST_TIMEOUT)
27        // Following redirects opens the client up to SSRF vulnerabilities.
28        .redirect(reqwest::redirect::Policy::none())
29        .build()
30        .unwrap()
31});
32
33type BasicClient<
34    HasAuthUrl = EndpointSet,
35    HasDeviceAuthUrl = EndpointNotSet,
36    HasIntrospectionUrl = EndpointNotSet,
37    HasRevocationUrl = EndpointNotSet,
38    HasTokenUrl = EndpointSet,
39> = Client<
40    BasicErrorResponse,
41    BasicTokenResponse,
42    BasicTokenIntrospectionResponse,
43    StandardRevocableToken,
44    BasicRevocationErrorResponse,
45    HasAuthUrl,
46    HasDeviceAuthUrl,
47    HasIntrospectionUrl,
48    HasRevocationUrl,
49    HasTokenUrl,
50>;
51
52pub enum Auth {
53    Unauthenticated,
54    Basic(Basic),
55    OAuth2(Box<OAuth2>),
56}
57
58impl Auth {
59    pub fn oauth2(flow: OAuth2Flow, token_store: TokenStore, scopes: Vec<Scope>) -> Self {
60        Self::OAuth2(Box::new(OAuth2 {
61            flow,
62            token_store,
63            scopes,
64        }))
65    }
66    pub fn basic(username: String, password: String) -> Self {
67        Self::Basic(Basic { username, password })
68    }
69    pub async fn headers(&mut self) -> HeaderMap {
70        match self {
71            Auth::Unauthenticated => HeaderMap::new(),
72            Auth::Basic(auth) => auth.headers().await,
73            Auth::OAuth2(auth) => auth.headers().await,
74        }
75    }
76
77    pub async fn handle_error(&mut self, error: reqwest::Error) -> Result<(), CalendarError> {
78        match self {
79            Auth::Unauthenticated | Auth::Basic(_) => Err(CalendarError::Http(error)),
80            Auth::OAuth2(auth) => auth.handle_error(error).await,
81        }
82    }
83
84    pub async fn authorize(&mut self) -> Result<Authorize, CalendarError> {
85        match self {
86            Auth::Unauthenticated | Auth::Basic(_) => Ok(Authorize::Completed),
87            Auth::OAuth2(auth) => Ok(Authorize::AskUser(auth.authorize().await?)),
88        }
89    }
90    pub async fn ask_user(&mut self, authorize_url: AuthorizeUrl) -> Result<(), CalendarError> {
91        match self {
92            Auth::Unauthenticated | Auth::Basic(_) => Ok(()),
93            Auth::OAuth2(auth) => auth.ask_user(authorize_url).await,
94        }
95    }
96}
97
98pub struct Basic {
99    username: String,
100    password: String,
101}
102
103impl Basic {
104    pub async fn headers(&mut self) -> HeaderMap {
105        let mut headers = HeaderMap::new();
106        let header =
107            base64::prelude::BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password));
108        let mut header_value = HeaderValue::from_str(format!("Basic {header}").as_str())
109            .expect("A valid basic header");
110        header_value.set_sensitive(true);
111        headers.insert(AUTHORIZATION, header_value);
112        headers
113    }
114}
115
116pub struct OAuth2 {
117    flow: OAuth2Flow,
118    token_store: TokenStore,
119    scopes: Vec<Scope>,
120}
121
122impl OAuth2 {
123    pub async fn headers(&mut self) -> HeaderMap {
124        let mut headers = HeaderMap::new();
125        if let Some(token) = self.token_store.get().await {
126            let mut auth_value =
127                HeaderValue::from_str(format!("Bearer {}", token.access_token().secret()).as_str())
128                    .expect("A valid access token");
129            auth_value.set_sensitive(true);
130            headers.insert(AUTHORIZATION, auth_value);
131        }
132        headers
133    }
134
135    async fn handle_error(&mut self, error: reqwest::Error) -> Result<(), CalendarError> {
136        if let Some(status) = error.status() {
137            if status == 401 {
138                match self
139                    .token_store
140                    .get()
141                    .await
142                    .and_then(|t| t.refresh_token().cloned())
143                {
144                    Some(refresh_token) => {
145                        let mut token = self.flow.refresh_token_exchange(&refresh_token).await?;
146                        if token.refresh_token().is_none() {
147                            token.set_refresh_token(Some(refresh_token));
148                        }
149                        self.token_store.store(token).await?;
150                        return Ok(());
151                    }
152                    None => return Err(CalendarError::AuthRequired),
153                }
154            }
155            if status == 403 {
156                return Err(CalendarError::AuthRequired);
157            }
158        }
159        Err(CalendarError::Http(error))
160    }
161
162    async fn authorize(&mut self) -> Result<AuthorizeUrl, CalendarError> {
163        Ok(self.flow.authorize_url(self.scopes.clone()))
164    }
165
166    async fn ask_user(&mut self, authorize_url: AuthorizeUrl) -> Result<(), CalendarError> {
167        let token = self.flow.redirect(authorize_url).await?;
168        self.token_store.store(token).await?;
169        Ok(())
170    }
171}
172pub struct OAuth2Flow {
173    client: BasicClient,
174    redirect_port: u16,
175}
176
177impl OAuth2Flow {
178    pub fn new(
179        client_id: ClientId,
180        client_secret: ClientSecret,
181        auth_url: AuthUrl,
182        token_url: TokenUrl,
183        redirect_port: u16,
184    ) -> Self {
185        Self {
186            client: BasicClient::new(client_id)
187                .set_client_secret(client_secret)
188                .set_auth_uri(auth_url)
189                .set_token_uri(token_url)
190                .set_redirect_uri(
191                    RedirectUrl::new(format!("http://localhost:{redirect_port}").to_string())
192                        .expect("A valid redirect URL"),
193                ),
194            redirect_port,
195        }
196    }
197
198    pub fn authorize_url(&self, scopes: Vec<Scope>) -> AuthorizeUrl {
199        let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
200        let (authorize_url, csrf_token) = self
201            .client
202            .authorize_url(CsrfToken::new_random)
203            .add_scopes(scopes)
204            .set_pkce_challenge(pkce_code_challenge.clone())
205            .url();
206        AuthorizeUrl {
207            pkce_code_verifier,
208            url: authorize_url,
209            csrf_token,
210        }
211    }
212
213    pub async fn refresh_token_exchange(
214        &self,
215        token: &RefreshToken,
216    ) -> Result<BasicTokenResponse, CalendarError> {
217        self.client
218            .exchange_refresh_token(token)
219            .request_async(&*REQWEST_CLIENT)
220            .await
221            .map_err(|e| CalendarError::RequestToken(e.to_string()))
222    }
223
224    pub async fn redirect(
225        &self,
226        authorize_url: AuthorizeUrl,
227    ) -> Result<BasicTokenResponse, CalendarError> {
228        let client = self.client.clone();
229        let redirect_port = self.redirect_port;
230        let listener = TcpListener::bind(format!("127.0.0.1:{redirect_port}")).await?;
231        let (mut stream, _) = listener.accept().await?;
232        let mut request_line = String::new();
233        let mut reader = BufReader::new(&mut stream);
234        reader.read_line(&mut request_line).await?;
235
236        let redirect_url = request_line
237            .split_whitespace()
238            .nth(1)
239            .ok_or(CalendarError::RequestToken("Invalid redirect url".into()))?;
240        let url = Url::parse(&("http://localhost".to_string() + redirect_url))
241            .map_err(|e| CalendarError::RequestToken(e.to_string()))?;
242
243        let (_, code_value) =
244            url.query_pairs()
245                .find(|(key, _)| key == "code")
246                .ok_or(CalendarError::RequestToken(
247                    "code query param is missing".into(),
248                ))?;
249        let code = AuthorizationCode::new(code_value.into_owned());
250        let (_, state_value) = url.query_pairs().find(|(key, _)| key == "state").ok_or(
251            CalendarError::RequestToken("state query param is missing".into()),
252        )?;
253        let state = CsrfToken::new(state_value.into_owned());
254        if state.secret() != authorize_url.csrf_token.secret() {
255            return Err(CalendarError::RequestToken(
256                "Received state and csrf token are different".to_string(),
257            ));
258        }
259
260        let message = "Now your i3status-rust calendar is authorized";
261        let response = format!(
262            "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
263            message.len(),
264            message
265        );
266        stream.write_all(response.as_bytes()).await?;
267
268        client
269            .exchange_code(code)
270            .set_pkce_verifier(authorize_url.pkce_code_verifier)
271            .request_async(&*REQWEST_CLIENT)
272            .await
273            .map_err(|e| CalendarError::RequestToken(e.to_string()))
274    }
275}
276
/// Outcome of starting an authorization.
#[derive(Debug)]
pub enum Authorize {
    /// Nothing further is required from the user.
    Completed,
    /// The user has to visit the contained authorization URL.
    AskUser(AuthorizeUrl),
}
282
/// Authorization URL together with the secrets needed to complete the flow.
#[derive(Debug)]
pub struct AuthorizeUrl {
    // PKCE verifier matching the challenge embedded in `url`; sent with the code exchange.
    pkce_code_verifier: PkceCodeVerifier,
    /// URL the user has to open to grant access.
    pub url: Url,
    // CSRF token compared against the `state` parameter of the redirect.
    csrf_token: CsrfToken,
}
289
290#[derive(Debug)]
291pub struct TokenStore {
292    path: PathBuf,
293    token: Option<BasicTokenResponse>,
294}
295
296impl TokenStore {
297    pub fn new(path: &Path) -> Self {
298        Self {
299            path: path.into(),
300            token: None,
301        }
302    }
303
304    pub async fn store(&mut self, token: BasicTokenResponse) -> Result<(), TokenStoreError> {
305        let mut file = File::create(&self.path).await?;
306        let value = serde_json::to_string(&token)?;
307        file.write_all(value.as_bytes()).await?;
308        self.token = Some(token);
309        Ok(())
310    }
311
312    pub async fn get(&mut self) -> Option<BasicTokenResponse> {
313        if self.token.is_none()
314            && let Ok(mut file) = File::open(&self.path).await
315        {
316            let mut content = vec![];
317            file.read_to_end(&mut content).await.ok()?;
318            self.token = serde_json::from_slice(&content).ok();
319        }
320        self.token.clone()
321    }
322}
323
/// Errors raised while persisting a token with `TokenStore`.
#[derive(thiserror::Error, Debug)]
pub enum TokenStoreError {
    /// The token file could not be created or written.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// The token could not be serialized to JSON.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),
}
330}