// i3status_rs/blocks/calendar/auth.rs

1use base64::Engine as _;
2use oauth2::basic::{BasicClient, BasicTokenType};
3use oauth2::reqwest::async_http_client;
4use oauth2::{
5    AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
6    PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse,
7    TokenResponse as _, TokenUrl,
8};
9use reqwest::Url;
10use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
11use std::path::{Path, PathBuf};
12use tokio::fs::File;
13use tokio::io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncWriteExt as _, BufReader};
14use tokio::net::TcpListener;
15
16use super::CalendarError;
17
18pub enum Auth {
19    Unauthenticated,
20    Basic(Basic),
21    OAuth2(Box<OAuth2>),
22}
23
24impl Auth {
25    pub fn oauth2(flow: OAuth2Flow, token_store: TokenStore, scopes: Vec<Scope>) -> Self {
26        Self::OAuth2(Box::new(OAuth2 {
27            flow,
28            token_store,
29            scopes,
30        }))
31    }
32    pub fn basic(username: String, password: String) -> Self {
33        Self::Basic(Basic { username, password })
34    }
35    pub async fn headers(&mut self) -> HeaderMap {
36        match self {
37            Auth::Unauthenticated => HeaderMap::new(),
38            Auth::Basic(auth) => auth.headers().await,
39            Auth::OAuth2(auth) => auth.headers().await,
40        }
41    }
42
43    pub async fn handle_error(&mut self, error: reqwest::Error) -> Result<(), CalendarError> {
44        match self {
45            Auth::Unauthenticated | Auth::Basic(_) => Err(CalendarError::Http(error)),
46            Auth::OAuth2(auth) => auth.handle_error(error).await,
47        }
48    }
49
50    pub async fn authorize(&mut self) -> Result<Authorize, CalendarError> {
51        match self {
52            Auth::Unauthenticated | Auth::Basic(_) => Ok(Authorize::Completed),
53            Auth::OAuth2(auth) => Ok(Authorize::AskUser(auth.authorize().await?)),
54        }
55    }
56    pub async fn ask_user(&mut self, authorize_url: AuthorizeUrl) -> Result<(), CalendarError> {
57        match self {
58            Auth::Unauthenticated | Auth::Basic(_) => Ok(()),
59            Auth::OAuth2(auth) => auth.ask_user(authorize_url).await,
60        }
61    }
62}
63
/// Credentials for HTTP Basic authentication.
pub struct Basic {
    /// User name; placed before the `:` separator in the encoded credentials.
    username: String,
    /// Password; placed after the `:` separator in the encoded credentials.
    password: String,
}
68
69impl Basic {
70    pub async fn headers(&mut self) -> HeaderMap {
71        let mut headers = HeaderMap::new();
72        let header =
73            base64::prelude::BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password));
74        let mut header_value = HeaderValue::from_str(format!("Basic {header}").as_str())
75            .expect("A valid basic header");
76        header_value.set_sensitive(true);
77        headers.insert(AUTHORIZATION, header_value);
78        headers
79    }
80}
81
82pub struct OAuth2 {
83    flow: OAuth2Flow,
84    token_store: TokenStore,
85    scopes: Vec<Scope>,
86}
87
88impl OAuth2 {
89    pub async fn headers(&mut self) -> HeaderMap {
90        let mut headers = HeaderMap::new();
91        if let Some(token) = self.token_store.get().await {
92            let mut auth_value =
93                HeaderValue::from_str(format!("Bearer {}", token.access_token().secret()).as_str())
94                    .expect("A valid access token");
95            auth_value.set_sensitive(true);
96            headers.insert(AUTHORIZATION, auth_value);
97        }
98        headers
99    }
100
101    async fn handle_error(&mut self, error: reqwest::Error) -> Result<(), CalendarError> {
102        if let Some(status) = error.status() {
103            if status == 401 {
104                match self
105                    .token_store
106                    .get()
107                    .await
108                    .and_then(|t| t.refresh_token().cloned())
109                {
110                    Some(refresh_token) => {
111                        let mut token = self.flow.refresh_token_exchange(&refresh_token).await?;
112                        if token.refresh_token().is_none() {
113                            token.set_refresh_token(Some(refresh_token));
114                        }
115                        self.token_store.store(token).await?;
116                        return Ok(());
117                    }
118                    None => return Err(CalendarError::AuthRequired),
119                }
120            }
121            if status == 403 {
122                return Err(CalendarError::AuthRequired);
123            }
124        }
125        Err(CalendarError::Http(error))
126    }
127
128    async fn authorize(&mut self) -> Result<AuthorizeUrl, CalendarError> {
129        Ok(self.flow.authorize_url(self.scopes.clone()))
130    }
131
132    async fn ask_user(&mut self, authorize_url: AuthorizeUrl) -> Result<(), CalendarError> {
133        let token = self.flow.redirect(authorize_url).await?;
134        self.token_store.store(token).await?;
135        Ok(())
136    }
137}
138pub struct OAuth2Flow {
139    client: BasicClient,
140    redirect_port: u16,
141}
142
143impl OAuth2Flow {
144    pub fn new(
145        client_id: ClientId,
146        client_secret: ClientSecret,
147        auth_url: AuthUrl,
148        token_url: TokenUrl,
149        redirect_port: u16,
150    ) -> Self {
151        Self {
152            client: BasicClient::new(client_id, Some(client_secret), auth_url, Some(token_url))
153                .set_redirect_uri(
154                    RedirectUrl::new(format!("http://localhost:{redirect_port}").to_string())
155                        .expect("A valid redirect URL"),
156                ),
157            redirect_port,
158        }
159    }
160
161    pub fn authorize_url(&self, scopes: Vec<Scope>) -> AuthorizeUrl {
162        let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
163        let (authorize_url, csrf_token) = self
164            .client
165            .authorize_url(CsrfToken::new_random)
166            .add_scopes(scopes)
167            .set_pkce_challenge(pkce_code_challenge.clone())
168            .url();
169        AuthorizeUrl {
170            pkce_code_verifier,
171            url: authorize_url,
172            csrf_token,
173        }
174    }
175
176    pub async fn refresh_token_exchange(
177        &self,
178        token: &RefreshToken,
179    ) -> Result<OAuth2TokenResponse, CalendarError> {
180        self.client
181            .exchange_refresh_token(token)
182            .request_async(async_http_client)
183            .await
184            .map_err(|e| CalendarError::RequestToken(e.to_string()))
185    }
186
187    pub async fn redirect(
188        &self,
189        authorize_url: AuthorizeUrl,
190    ) -> Result<OAuth2TokenResponse, CalendarError> {
191        let client = self.client.clone();
192        let redirect_port = self.redirect_port;
193        let listener = TcpListener::bind(format!("127.0.0.1:{}", redirect_port)).await?;
194        let (mut stream, _) = listener.accept().await?;
195        let mut request_line = String::new();
196        let mut reader = BufReader::new(&mut stream);
197        reader.read_line(&mut request_line).await?;
198
199        let redirect_url = request_line
200            .split_whitespace()
201            .nth(1)
202            .ok_or(CalendarError::RequestToken("Invalid redirect url".into()))?;
203        let url = Url::parse(&("http://localhost".to_string() + redirect_url))
204            .map_err(|e| CalendarError::RequestToken(e.to_string()))?;
205
206        let (_, code_value) =
207            url.query_pairs()
208                .find(|(key, _)| key == "code")
209                .ok_or(CalendarError::RequestToken(
210                    "code query param is missing".into(),
211                ))?;
212        let code = AuthorizationCode::new(code_value.into_owned());
213        let (_, state_value) = url.query_pairs().find(|(key, _)| key == "state").ok_or(
214            CalendarError::RequestToken("state query param is missing".into()),
215        )?;
216        let state = CsrfToken::new(state_value.into_owned());
217        if state.secret() != authorize_url.csrf_token.secret() {
218            return Err(CalendarError::RequestToken(
219                "Received state and csrf token are different".to_string(),
220            ));
221        }
222
223        let message = "Now your i3status-rust calendar is authorized";
224        let response = format!(
225            "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
226            message.len(),
227            message
228        );
229        stream.write_all(response.as_bytes()).await?;
230
231        client
232            .exchange_code(code)
233            .set_pkce_verifier(authorize_url.pkce_code_verifier)
234            .request_async(async_http_client)
235            .await
236            .map_err(|e| CalendarError::RequestToken(e.to_string()))
237    }
238}
239
240#[derive(Debug)]
241pub enum Authorize {
242    Completed,
243    AskUser(AuthorizeUrl),
244}
245
246#[derive(Debug)]
247pub struct AuthorizeUrl {
248    pkce_code_verifier: PkceCodeVerifier,
249    pub url: Url,
250    csrf_token: CsrfToken,
251}
252
253type OAuth2TokenResponse = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
254
255#[derive(Debug)]
256pub struct TokenStore {
257    path: PathBuf,
258    token: Option<OAuth2TokenResponse>,
259}
260
261impl TokenStore {
262    pub fn new(path: &Path) -> Self {
263        Self {
264            path: path.into(),
265            token: None,
266        }
267    }
268
269    pub async fn store(&mut self, token: OAuth2TokenResponse) -> Result<(), TokenStoreError> {
270        let mut file = File::create(&self.path).await?;
271        let value = serde_json::to_string(&token)?;
272        file.write_all(value.as_bytes()).await?;
273        self.token = Some(token);
274        Ok(())
275    }
276
277    pub async fn get(&mut self) -> Option<OAuth2TokenResponse> {
278        if self.token.is_none() {
279            if let Ok(mut file) = File::open(&self.path).await {
280                let mut content = vec![];
281                file.read_to_end(&mut content).await.ok()?;
282                self.token = serde_json::from_slice(&content).ok();
283            }
284        }
285        self.token.clone()
286    }
287}
288
289#[derive(thiserror::Error, Debug)]
290pub enum TokenStoreError {
291    #[error(transparent)]
292    Io(#[from] std::io::Error),
293    #[error(transparent)]
294    Serde(#[from] serde_json::Error),
295}