1use std::{collections::HashMap, sync::Arc};
76
77use serde::{Deserialize, Serialize};
78use tokio::sync::RwLock;
79
80use crate::{AirError, Result};
81
82#[derive(Debug, Clone)]
84pub struct TraceGenerator {
85 trace_spans:Arc<RwLock<HashMap<String, TraceSpan>>>,
86 sampling_config:Arc<RwLock<SamplingConfig>>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct SamplingConfig {
92 pub sample_rate:f64,
94 pub critical_sample_rate:f64,
96 pub max_spans_per_trace:usize,
98 pub trace_ttl_ms:u64,
100}
101
102impl Default for SamplingConfig {
103 fn default() -> Self {
104 Self {
105 sample_rate:0.1, critical_sample_rate:1.0, max_spans_per_trace:1000,
108 trace_ttl_ms:3600000, }
110 }
111}
112
113impl SamplingConfig {
114 pub fn validate(&self) -> Result<()> {
116 if self.sample_rate < 0.0 || self.sample_rate > 1.0 {
117 return Err(crate::AirError::Internal("sample_rate must be between 0.0 and 1.0".to_string()));
118 }
119 if self.critical_sample_rate < 0.0 || self.critical_sample_rate > 1.0 {
120 return Err(crate::AirError::Internal(
121 "critical_sample_rate must be between 0.0 and 1.0".to_string(),
122 ));
123 }
124 if self.max_spans_per_trace == 0 {
125 return Err(crate::AirError::Internal(
126 "max_spans_per_trace must be greater than 0".to_string(),
127 ));
128 }
129 if self.trace_ttl_ms == 0 {
130 return Err(crate::AirError::Internal("trace_ttl_ms must be greater than 0".to_string()));
131 }
132 Ok(())
133 }
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct TraceSpan {
139 pub span_id:String,
140 pub trace_id:String,
141 pub parent_span_id:Option<String>,
142 pub operation_name:String,
143 pub start_time:u64,
144 pub end_time:Option<u64>,
145 pub status:SpanStatus,
146 pub attributes:HashMap<String, String>,
147 pub events:Vec<SpanEvent>,
148 pub error:Option<String>,
149 pub duration_ms:Option<u64>,
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
154pub enum SpanStatus {
155 Started,
156 Active,
157 Completed,
158 Failed,
159 Cancelled,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct SpanEvent {
165 pub timestamp:u64,
166 pub name:String,
167 pub attributes:HashMap<String, String>,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct TraceMetadata {
173 pub trace_id:String,
174 pub root_span_id:String,
175 pub total_spans:usize,
176 pub root_operation:String,
177 pub start_time:u64,
178 pub end_time:Option<u64>,
179 pub total_duration_ms:Option<u64>,
180 pub status:TraceStatus,
181}
182
183#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
185pub enum TraceStatus {
186 InProgress,
187 Completed,
188 Failed,
189 Cancelled,
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct PropagationContext {
195 pub TraceId:String,
196 pub SpanId:String,
197 pub CorrelationId:String,
198 pub ParentSpanId:Option<String>,
199}
200
201impl TraceGenerator {
202 pub fn new() -> Self {
204 Self {
205 trace_spans:Arc::new(RwLock::new(HashMap::new())),
206 sampling_config:Arc::new(RwLock::new(SamplingConfig::default())),
207 }
208 }
209
210 pub fn with_sampling(sampling_config:SamplingConfig) -> Result<Self> {
212 sampling_config
213 .validate()
214 .map_err(|e| AirError::Internal(format!("Invalid sampling config: {}", e)))?;
215
216 Ok(Self {
217 trace_spans:Arc::new(RwLock::new(HashMap::new())),
218 sampling_config:Arc::new(RwLock::new(sampling_config)),
219 })
220 }
221
222 pub fn generate_trace_id() -> String {
224 std::panic::catch_unwind(|| uuid::Uuid::new_v4().to_string()).unwrap_or_else(|e| {
225 log::error!("[Tracing] Panic in generate_trace_id, using fallback: {:?}", e);
226 format!("{:x}", rand::random::<u64>())
227 })
228 }
229
230 pub fn generate_span_id() -> String {
232 std::panic::catch_unwind(|| uuid::Uuid::new_v4().to_string()).unwrap_or_else(|e| {
233 log::error!("[Tracing] Panic in generate_span_id, using fallback: {:?}", e);
234 format!("{:x}", rand::random::<u64>())
235 })
236 }
237
238 pub async fn should_sample(&self, is_critical:bool) -> bool {
240 let config = self.sampling_config.read().await;
241 let rate = if is_critical { config.critical_sample_rate } else { config.sample_rate };
242
243 rand::random::<f64>() < rate
244 }
245
246 pub fn parse_trace_context(header:&str) -> Result<PropagationContext> {
248 let parts:Vec<&str> = header.split(';').collect();
249
250 let mut trace_id = String::new();
251 let mut parent_span_id = None;
252
253 for part in parts {
254 let kv:Vec<&str> = part.split('=').collect();
255 if kv.len() != 2 {
256 continue;
257 }
258
259 match kv[0].trim() {
260 "traceparent" => {
261 let trace_parent:Vec<&str> = kv[1].trim().split('-').collect();
262 if trace_parent.len() >= 2 {
263 trace_id = trace_parent[1].to_string();
264 if trace_parent.len() >= 3 {
265 parent_span_id = Some(trace_parent[2].to_string());
266 }
267 }
268 },
269 _ => {},
270 }
271 }
272
273 if trace_id.is_empty() {
274 return Err(AirError::Internal("Invalid trace context header".to_string()));
275 }
276
277 Ok(PropagationContext {
278 TraceId:trace_id,
279 SpanId:Self::generate_span_id(),
280 CorrelationId:crate::Utility::GenerateRequestId(),
281 ParentSpanId:parent_span_id,
282 })
283 }
284
285 pub async fn create_span(
287 &self,
288 trace_id:String,
289 operation_name:impl Into<String>,
290 parent_span_id:Option<String>,
291 attributes:Option<HashMap<String, String>>,
292 ) -> Result<TraceSpan> {
293 let span_id = Self::generate_span_id();
294 let operation_name = operation_name.into();
295
296 let span = TraceSpan {
297 span_id:span_id.clone(),
298 trace_id:trace_id.clone(),
299 parent_span_id:parent_span_id.clone(),
300 operation_name:operation_name.clone(),
301 start_time:crate::Utility::CurrentTimestamp(),
302 end_time:None,
303 status:SpanStatus::Started,
304 attributes:attributes.unwrap_or_default(),
305 events:Vec::new(),
306 error:None,
307 duration_ms:None,
308 };
309
310 let mut spans = self.trace_spans.write().await;
311
312 let trace_span_count = spans.values().filter(|s| s.trace_id == trace_id).count();
314
315 let config = self.sampling_config.read().await;
316 if trace_span_count >= config.max_spans_per_trace {
317 log::warn!(
318 "[Tracing] Trace {} exceeds max spans ({}), dropping span {}",
319 trace_id,
320 config.max_spans_per_trace,
321 span_id
322 );
323 return Err(AirError::Internal("Max spans per trace exceeded".to_string()));
324 }
325
326 spans.insert(span_id.clone(), span.clone());
327
328 Ok(span)
329 }
330
331 pub async fn add_span_event(
333 &self,
334 span_id:&str,
335 event_name:impl Into<String>,
336 attributes:HashMap<String, String>,
337 ) -> Result<()> {
338 let event = SpanEvent {
339 timestamp:crate::Utility::CurrentTimestamp(),
340 name:event_name.into(),
341 attributes:self.sanitize_attributes(attributes),
342 };
343
344 let mut spans = self.trace_spans.write().await;
345 if let Some(span) = spans.get_mut(span_id) {
346 span.events.push(event);
347 Ok(())
348 } else {
349 Err(AirError::Internal(format!("Span not found: {}", span_id)))
350 }
351 }
352
353 pub async fn mark_span_active(&self, span_id:&str) -> Result<()> {
355 let mut spans = self.trace_spans.write().await;
356 if let Some(span) = spans.get_mut(span_id) {
357 span.status = SpanStatus::Active;
358 Ok(())
359 } else {
360 Err(AirError::Internal(format!("Span not found: {}", span_id)))
361 }
362 }
363
364 pub async fn complete_span(&self, span_id:&str, error:Option<String>) -> Result<u64> {
366 let Now = crate::Utility::CurrentTimestamp();
367 let mut spans = self.trace_spans.write().await;
368
369 if let Some(span) = spans.get_mut(span_id) {
370 span.end_time = Some(Now);
371 span.duration_ms = Some(Now.saturating_sub(span.start_time));
372 span.status = if error.is_some() { SpanStatus::Failed } else { SpanStatus::Completed };
373 span.error = error.map(|e| self.sanitize_error_message(&e));
374 Ok(span.duration_ms.unwrap_or(0))
375 } else {
376 Err(AirError::Internal(format!("Span not found: {}", span_id)))
377 }
378 }
379
380 pub async fn add_span_attribute(&self, span_id:&str, key:String, value:String) -> Result<()> {
382 self.add_span_attributes(span_id, HashMap::from([(key, value)])).await
383 }
384
385 pub async fn add_span_attributes(&self, span_id:&str, attributes:HashMap<String, String>) -> Result<()> {
387 let sanitized = self.sanitize_attributes(attributes);
388 let mut spans = self.trace_spans.write().await;
389
390 if let Some(span) = spans.get_mut(span_id) {
391 for (key, value) in sanitized {
392 span.attributes.insert(key, value);
393 }
394 Ok(())
395 } else {
396 Err(AirError::Internal(format!("Span not found: {}", span_id)))
397 }
398 }
399
400 pub async fn get_span(&self, span_id:&str) -> Result<TraceSpan> {
402 let spans = self.trace_spans.read().await;
403 spans
404 .get(span_id)
405 .cloned()
406 .ok_or_else(|| AirError::Internal(format!("Span not found: {}", span_id)))
407 }
408
409 pub async fn get_trace_spans(&self, trace_id:&str) -> Result<Vec<TraceSpan>> {
411 let spans = self.trace_spans.read().await;
412 Ok(spans.values().filter(|span| span.trace_id == trace_id).cloned().collect())
413 }
414
415 pub async fn get_trace_metadata(&self, trace_id:&str) -> Result<TraceMetadata> {
417 let trace_spans = self.get_trace_spans(trace_id).await?;
418
419 if trace_spans.is_empty() {
420 return Err(AirError::Internal(format!("Trace not found: {}", trace_id)));
421 }
422
423 let root_span = trace_spans
424 .iter()
425 .find(|s| s.parent_span_id.is_none())
426 .ok_or_else(|| AirError::Internal("No root span found".to_string()))?;
427
428 let total_duration_ms = trace_spans.iter().filter_map(|s| s.duration_ms).max();
429
430 let status = if trace_spans.iter().any(|s| s.status == SpanStatus::Failed) {
431 TraceStatus::Failed
432 } else if trace_spans
433 .iter()
434 .all(|s| s.status == SpanStatus::Completed || s.status == SpanStatus::Failed)
435 {
436 TraceStatus::Completed
437 } else {
438 TraceStatus::InProgress
439 };
440
441 let end_time = trace_spans.iter().filter_map(|s| s.end_time).max();
442
443 Ok(TraceMetadata {
444 trace_id:trace_id.to_string(),
445 root_span_id:root_span.span_id.clone(),
446 total_spans:trace_spans.len(),
447 root_operation:root_span.operation_name.clone(),
448 start_time:root_span.start_time,
449 end_time,
450 total_duration_ms,
451 status,
452 })
453 }
454
455 pub async fn export_trace(&self, trace_id:&str) -> Result<String> {
457 let spans = self.get_trace_spans(trace_id).await?;
458 let metadata = self.get_trace_metadata(trace_id).await?;
459
460 let export = serde_json::json!({
461 "metadata": metadata,
462 "spans": spans,
463 });
464
465 serde_json::to_string_pretty(&export)
466 .map_err(|e| AirError::Serialization(format!("Failed to export trace: {}", e)))
467 }
468
469 pub async fn cleanup_old_spans(&self, older_than_ms:Option<u64>) -> Result<usize> {
471 let Now = crate::Utility::CurrentTimestamp();
472 let ttl = older_than_ms.unwrap_or_else(|| {
473 tokio::task::block_in_place(|| {
474 tokio::runtime::Handle::current().block_on(async { self.sampling_config.read().await.trace_ttl_ms })
475 })
476 });
477
478 let mut spans = self.trace_spans.write().await;
479 let original_len = spans.len();
480
481 spans.retain(|_, span| span.end_time.map_or(true, |end| Now.saturating_sub(end) < ttl));
482
483 Ok(original_len.saturating_sub(spans.len()))
484 }
485
486 pub async fn get_statistics(&self) -> TraceStatistics {
488 let spans = self.trace_spans.read().await;
489
490 let total_traces = spans
491 .values()
492 .map(|s| s.trace_id.clone())
493 .collect::<std::collections::HashSet<_>>()
494 .len();
495
496 let completed_spans = spans.values().filter(|s| s.status == SpanStatus::Completed).count();
497
498 let failed_spans = spans.values().filter(|s| s.status == SpanStatus::Failed).count();
499
500 let in_progress_spans = spans
501 .values()
502 .filter(|s| s.status == SpanStatus::Started || s.status == SpanStatus::Active)
503 .count();
504
505 TraceStatistics {
506 total_traces:total_traces as u64,
507 total_spans:spans.len() as u64,
508 completed_spans:completed_spans as u64,
509 failed_spans:failed_spans as u64,
510 in_progress_spans:in_progress_spans as u64,
511 }
512 }
513
514 fn sanitize_attributes(&self, mut attributes:HashMap<String, String>) -> HashMap<String, String> {
516 let sensitive_keys = vec![
517 "password",
518 "token",
519 "secret",
520 "api_key",
521 "authorization",
522 "credential",
523 "auth",
524 "private_key",
525 "session_token",
526 ];
527
528 let attr_keys:Vec<String> = attributes.keys().cloned().collect();
530
531 for key in sensitive_keys {
532 let key_lower = key.to_lowercase();
533 for attr_key in &attr_keys {
534 if attr_key.to_lowercase().contains(&key_lower) {
535 attributes.insert(attr_key.clone(), "[REDACTED]".to_string());
536 }
537 }
538 }
539
540 attributes
541 }
542
543 fn sanitize_error_message(&self, message:&str) -> String {
545 let mut sanitized = message.to_string();
546
547 let patterns = vec![
548 (r"(?i)password[=:]\S+", "password=[REDACTED]"),
549 (r"(?i)token[=:]\S+", "token=[REDACTED]"),
550 (r"(?i)secret[=:]\S+", "secret=[REDACTED]"),
551 (r"(?i)(api|private)[_-]?key[=:]\S+", "api_key=[REDACTED]"),
552 (
553 r"(?i)authorization[=[:space:]]+Bearer[[:space:]]+\S+",
554 "Authorization: Bearer [REDACTED]",
555 ),
556 ];
557
558 for (pattern, replacement) in patterns {
559 if let Ok(re) = regex::Regex::new(pattern) {
560 sanitized = re.replace_all(&sanitized, replacement).to_string();
561 }
562 }
563
564 sanitized
565 }
566}
567
568#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct TraceStatistics {
571 pub total_traces:u64,
572 pub total_spans:u64,
573 pub completed_spans:u64,
574 pub failed_spans:u64,
575 pub in_progress_spans:u64,
576}
577
578impl Default for TraceGenerator {
579 fn default() -> Self { Self::new() }
580}
581
582static TRACE_GENERATOR:std::sync::OnceLock<TraceGenerator> = std::sync::OnceLock::new();
584
585pub fn get_trace_generator() -> &'static TraceGenerator { TRACE_GENERATOR.get_or_init(TraceGenerator::new) }
587
588pub fn initialize_tracing(sampling_config:Option<SamplingConfig>) -> Result<()> {
590 let generator = if let Some(config) = sampling_config {
591 TraceGenerator::with_sampling(config)?
592 } else {
593 TraceGenerator::new()
594 };
595
596 let _old = TRACE_GENERATOR.set(generator);
597 log::info!("[Tracing] Trace generator initialized with tracing");
598 Ok(())
599}
600
601thread_local! {
602 static PROPAGATION_CONTEXT: std::cell::RefCell<Option<PropagationContext>> = std::cell::RefCell::new(None);
603}
604
605pub fn set_propagation_context(context:PropagationContext) {
607 PROPAGATION_CONTEXT.with(|ctx| {
608 *ctx.borrow_mut() = Some(context);
609 });
610}
611
612pub fn get_propagation_context() -> Option<PropagationContext> { PROPAGATION_CONTEXT.with(|ctx| ctx.borrow().clone()) }
614
615pub async fn create_propagation_context(TraceId:String, ParentSpanId:Option<String>) -> PropagationContext {
617 let SpanId = TraceGenerator::generate_span_id();
618 let CorrelationId = crate::Utility::GenerateRequestId();
619
620 PropagationContext { TraceId, SpanId, CorrelationId, ParentSpanId }
621}
622
623pub fn create_trace_context_header(context:&PropagationContext) -> String {
625 format!("traceparent=00-{}-{}-01", context.TraceId, context.SpanId)
626}