i3status_rs/blocks/calendar/
auth.rs1use base64::Engine as _;
2use oauth2::basic::{
3 BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse,
4 BasicTokenResponse,
5};
6use oauth2::{
7 AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EndpointNotSet,
8 EndpointSet, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope,
9 StandardRevocableToken, TokenResponse as _, TokenUrl,
10};
11use reqwest;
12use reqwest::Url;
13use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue};
14use std::path::{Path, PathBuf};
15use std::sync::LazyLock;
16use tokio::fs::File;
17use tokio::io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncWriteExt as _, BufReader};
18use tokio::net::TcpListener;
19
20use super::CalendarError;
21use crate::{APP_USER_AGENT, REQWEST_TIMEOUT};
22
23static REQWEST_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
24 reqwest::Client::builder()
25 .user_agent(APP_USER_AGENT)
26 .timeout(REQWEST_TIMEOUT)
27 .redirect(reqwest::redirect::Policy::none())
29 .build()
30 .unwrap()
31});
32
33type BasicClient<
34 HasAuthUrl = EndpointSet,
35 HasDeviceAuthUrl = EndpointNotSet,
36 HasIntrospectionUrl = EndpointNotSet,
37 HasRevocationUrl = EndpointNotSet,
38 HasTokenUrl = EndpointSet,
39> = Client<
40 BasicErrorResponse,
41 BasicTokenResponse,
42 BasicTokenIntrospectionResponse,
43 StandardRevocableToken,
44 BasicRevocationErrorResponse,
45 HasAuthUrl,
46 HasDeviceAuthUrl,
47 HasIntrospectionUrl,
48 HasRevocationUrl,
49 HasTokenUrl,
50>;
51
52pub enum Auth {
53 Unauthenticated,
54 Basic(Basic),
55 OAuth2(Box<OAuth2>),
56}
57
58impl Auth {
59 pub fn oauth2(flow: OAuth2Flow, token_store: TokenStore, scopes: Vec<Scope>) -> Self {
60 Self::OAuth2(Box::new(OAuth2 {
61 flow,
62 token_store,
63 scopes,
64 }))
65 }
66 pub fn basic(username: String, password: String) -> Self {
67 Self::Basic(Basic { username, password })
68 }
69 pub async fn headers(&mut self) -> HeaderMap {
70 match self {
71 Auth::Unauthenticated => HeaderMap::new(),
72 Auth::Basic(auth) => auth.headers().await,
73 Auth::OAuth2(auth) => auth.headers().await,
74 }
75 }
76
77 pub async fn handle_error(&mut self, error: reqwest::Error) -> Result<(), CalendarError> {
78 match self {
79 Auth::Unauthenticated | Auth::Basic(_) => Err(CalendarError::Http(error)),
80 Auth::OAuth2(auth) => auth.handle_error(error).await,
81 }
82 }
83
84 pub async fn authorize(&mut self) -> Result<Authorize, CalendarError> {
85 match self {
86 Auth::Unauthenticated | Auth::Basic(_) => Ok(Authorize::Completed),
87 Auth::OAuth2(auth) => Ok(Authorize::AskUser(auth.authorize().await?)),
88 }
89 }
90 pub async fn ask_user(&mut self, authorize_url: AuthorizeUrl) -> Result<(), CalendarError> {
91 match self {
92 Auth::Unauthenticated | Auth::Basic(_) => Ok(()),
93 Auth::OAuth2(auth) => auth.ask_user(authorize_url).await,
94 }
95 }
96}
97
98pub struct Basic {
99 username: String,
100 password: String,
101}
102
103impl Basic {
104 pub async fn headers(&mut self) -> HeaderMap {
105 let mut headers = HeaderMap::new();
106 let header =
107 base64::prelude::BASE64_STANDARD.encode(format!("{}:{}", self.username, self.password));
108 let mut header_value = HeaderValue::from_str(format!("Basic {header}").as_str())
109 .expect("A valid basic header");
110 header_value.set_sensitive(true);
111 headers.insert(AUTHORIZATION, header_value);
112 headers
113 }
114}
115
116pub struct OAuth2 {
117 flow: OAuth2Flow,
118 token_store: TokenStore,
119 scopes: Vec<Scope>,
120}
121
122impl OAuth2 {
123 pub async fn headers(&mut self) -> HeaderMap {
124 let mut headers = HeaderMap::new();
125 if let Some(token) = self.token_store.get().await {
126 let mut auth_value =
127 HeaderValue::from_str(format!("Bearer {}", token.access_token().secret()).as_str())
128 .expect("A valid access token");
129 auth_value.set_sensitive(true);
130 headers.insert(AUTHORIZATION, auth_value);
131 }
132 headers
133 }
134
135 async fn handle_error(&mut self, error: reqwest::Error) -> Result<(), CalendarError> {
136 if let Some(status) = error.status() {
137 if status == 401 {
138 match self
139 .token_store
140 .get()
141 .await
142 .and_then(|t| t.refresh_token().cloned())
143 {
144 Some(refresh_token) => {
145 let mut token = self.flow.refresh_token_exchange(&refresh_token).await?;
146 if token.refresh_token().is_none() {
147 token.set_refresh_token(Some(refresh_token));
148 }
149 self.token_store.store(token).await?;
150 return Ok(());
151 }
152 None => return Err(CalendarError::AuthRequired),
153 }
154 }
155 if status == 403 {
156 return Err(CalendarError::AuthRequired);
157 }
158 }
159 Err(CalendarError::Http(error))
160 }
161
162 async fn authorize(&mut self) -> Result<AuthorizeUrl, CalendarError> {
163 Ok(self.flow.authorize_url(self.scopes.clone()))
164 }
165
166 async fn ask_user(&mut self, authorize_url: AuthorizeUrl) -> Result<(), CalendarError> {
167 let token = self.flow.redirect(authorize_url).await?;
168 self.token_store.store(token).await?;
169 Ok(())
170 }
171}
172pub struct OAuth2Flow {
173 client: BasicClient,
174 redirect_port: u16,
175}
176
177impl OAuth2Flow {
178 pub fn new(
179 client_id: ClientId,
180 client_secret: ClientSecret,
181 auth_url: AuthUrl,
182 token_url: TokenUrl,
183 redirect_port: u16,
184 ) -> Self {
185 Self {
186 client: BasicClient::new(client_id)
187 .set_client_secret(client_secret)
188 .set_auth_uri(auth_url)
189 .set_token_uri(token_url)
190 .set_redirect_uri(
191 RedirectUrl::new(format!("http://localhost:{redirect_port}").to_string())
192 .expect("A valid redirect URL"),
193 ),
194 redirect_port,
195 }
196 }
197
198 pub fn authorize_url(&self, scopes: Vec<Scope>) -> AuthorizeUrl {
199 let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
200 let (authorize_url, csrf_token) = self
201 .client
202 .authorize_url(CsrfToken::new_random)
203 .add_scopes(scopes)
204 .set_pkce_challenge(pkce_code_challenge.clone())
205 .url();
206 AuthorizeUrl {
207 pkce_code_verifier,
208 url: authorize_url,
209 csrf_token,
210 }
211 }
212
213 pub async fn refresh_token_exchange(
214 &self,
215 token: &RefreshToken,
216 ) -> Result<BasicTokenResponse, CalendarError> {
217 self.client
218 .exchange_refresh_token(token)
219 .request_async(&*REQWEST_CLIENT)
220 .await
221 .map_err(|e| CalendarError::RequestToken(e.to_string()))
222 }
223
224 pub async fn redirect(
225 &self,
226 authorize_url: AuthorizeUrl,
227 ) -> Result<BasicTokenResponse, CalendarError> {
228 let client = self.client.clone();
229 let redirect_port = self.redirect_port;
230 let listener = TcpListener::bind(format!("127.0.0.1:{redirect_port}")).await?;
231 let (mut stream, _) = listener.accept().await?;
232 let mut request_line = String::new();
233 let mut reader = BufReader::new(&mut stream);
234 reader.read_line(&mut request_line).await?;
235
236 let redirect_url = request_line
237 .split_whitespace()
238 .nth(1)
239 .ok_or(CalendarError::RequestToken("Invalid redirect url".into()))?;
240 let url = Url::parse(&("http://localhost".to_string() + redirect_url))
241 .map_err(|e| CalendarError::RequestToken(e.to_string()))?;
242
243 let (_, code_value) =
244 url.query_pairs()
245 .find(|(key, _)| key == "code")
246 .ok_or(CalendarError::RequestToken(
247 "code query param is missing".into(),
248 ))?;
249 let code = AuthorizationCode::new(code_value.into_owned());
250 let (_, state_value) = url.query_pairs().find(|(key, _)| key == "state").ok_or(
251 CalendarError::RequestToken("state query param is missing".into()),
252 )?;
253 let state = CsrfToken::new(state_value.into_owned());
254 if state.secret() != authorize_url.csrf_token.secret() {
255 return Err(CalendarError::RequestToken(
256 "Received state and csrf token are different".to_string(),
257 ));
258 }
259
260 let message = "Now your i3status-rust calendar is authorized";
261 let response = format!(
262 "HTTP/1.1 200 OK\r\ncontent-length: {}\r\n\r\n{}",
263 message.len(),
264 message
265 );
266 stream.write_all(response.as_bytes()).await?;
267
268 client
269 .exchange_code(code)
270 .set_pkce_verifier(authorize_url.pkce_code_verifier)
271 .request_async(&*REQWEST_CLIENT)
272 .await
273 .map_err(|e| CalendarError::RequestToken(e.to_string()))
274 }
275}
276
277#[derive(Debug)]
278pub enum Authorize {
279 Completed,
280 AskUser(AuthorizeUrl),
281}
282
283#[derive(Debug)]
284pub struct AuthorizeUrl {
285 pkce_code_verifier: PkceCodeVerifier,
286 pub url: Url,
287 csrf_token: CsrfToken,
288}
289
290#[derive(Debug)]
291pub struct TokenStore {
292 path: PathBuf,
293 token: Option<BasicTokenResponse>,
294}
295
296impl TokenStore {
297 pub fn new(path: &Path) -> Self {
298 Self {
299 path: path.into(),
300 token: None,
301 }
302 }
303
304 pub async fn store(&mut self, token: BasicTokenResponse) -> Result<(), TokenStoreError> {
305 let mut file = File::create(&self.path).await?;
306 let value = serde_json::to_string(&token)?;
307 file.write_all(value.as_bytes()).await?;
308 self.token = Some(token);
309 Ok(())
310 }
311
312 pub async fn get(&mut self) -> Option<BasicTokenResponse> {
313 if self.token.is_none()
314 && let Ok(mut file) = File::open(&self.path).await
315 {
316 let mut content = vec![];
317 file.read_to_end(&mut content).await.ok()?;
318 self.token = serde_json::from_slice(&content).ok();
319 }
320 self.token.clone()
321 }
322}
323
324#[derive(thiserror::Error, Debug)]
325pub enum TokenStoreError {
326 #[error(transparent)]
327 Io(#[from] std::io::Error),
328 #[error(transparent)]
329 Serde(#[from] serde_json::Error),
330}