// Package stats provides session-level statistics tracking that is safe for
// concurrent use: counters are updated with atomics, and the smoothed rate is
// guarded by a mutex. TypeStats tracks per-vulnerability-type finding counts,
// while SessionStats provides aggregate metrics with rate and ETA calculation.
package stats

import (
	"sync"
	"sync/atomic"
	"time"
)

// TypeStats tracks finding counts by vulnerability type (e.g., ".git", ".env", "wp-config").
//
// Counters live in a sync.Map keyed by type name. Each key is stored once, on
// the first Record for that type, and is only loaded afterwards. That is the
// write-once/read-many pattern sync.Map is designed for. Increments then go
// through the per-key atomic counter without further map writes.
//
// The zero value is ready to use.
type TypeStats struct {
	counts sync.Map // map[string]*atomic.Int64
}

// NewTypeStats returns an empty TypeStats, ready for concurrent Record calls.
func NewTypeStats() *TypeStats {
	return new(TypeStats)
}

// Record increments the count for the given vulnerability type.
// Safe to call from multiple goroutines concurrently.
func (t *TypeStats) Record(vulnType string) {
	// Fast path: the type has been seen before, so just bump its counter.
	// Checking with Load first avoids allocating a throwaway counter on every
	// call, which calling LoadOrStore with a fresh counter would do.
	if c, ok := t.counts.Load(vulnType); ok {
		c.(*atomic.Int64).Add(1)
		return
	}

	// Slow path: first sighting of this type. LoadOrStore settles races between
	// goroutines creating the same key, so all of them end up incrementing the
	// one counter that was actually stored.
	actual, _ := t.counts.LoadOrStore(vulnType, new(atomic.Int64))
	actual.(*atomic.Int64).Add(1)
}

// Snapshot copies the current per-type counts into a new map owned by the
// caller; later Record calls do not modify it. Counters are read one key at a
// time, so under concurrent Record calls the values need not describe a single
// instant.
func (t *TypeStats) Snapshot() map[string]int64 {
	out := map[string]int64{}
	collect := func(k, v any) bool {
		name := k.(string)
		out[name] = v.(*atomic.Int64).Load()
		return true
	}
	t.counts.Range(collect)
	return out
}

// Reset removes every per-type counter so counting restarts from zero.
// Call it when starting a new scan. The removal walks the map key by key, so
// it is not atomic with respect to Record calls running at the same time.
func (t *TypeStats) Reset() {
	drop := func(k, _ any) bool {
		t.counts.Delete(k)
		return true
	}
	t.counts.Range(drop)
}

// SessionStats tracks aggregate session metrics with rate and ETA calculation.
//
// The exported counters are atomics and may be read or updated from any
// goroutine. The timing and rate fields are unexported and reached through
// methods.
type SessionStats struct {
	// Counters - atomic for concurrent updates.
	Scanned      atomic.Int64 // targets processed (see RecordScanned)
	Found        atomic.Int64 // findings recorded (see RecordFound)
	Errors       atomic.Int64 // errors recorded (see RecordError)
	TotalTargets atomic.Int64 // expected target count; 0 means unknown

	// Timing.
	startTime  time.Time // session start; rates are measured from here
	lastUpdate time.Time // last rate refresh; NOTE(review): never read in this file

	// smoothedRate is an exponential moving average of the cumulative scan
	// rate (Scanned / seconds since startTime), in domains per second.
	smoothedRate float64

	// mu guards smoothedRate. updateRate also holds it while reading
	// startTime and writing lastUpdate.
	mu sync.RWMutex
}

// NewSessionStats creates a new SessionStats instance for tracking aggregate scan metrics.
// Pass totalTargets as 0 if unknown (ETA and Progress will return -1).
// The session clock starts at the moment of the call.
func NewSessionStats(totalTargets int) *SessionStats {
	// Read the clock once so startTime and lastUpdate describe the same instant.
	now := time.Now()
	s := &SessionStats{
		startTime:  now,
		lastUpdate: now,
	}
	s.TotalTargets.Store(int64(totalTargets))
	return s
}

// RecordScanned counts one more processed target, then refreshes the smoothed
// rate so the rate calculation already sees the new count.
func (s *SessionStats) RecordScanned() {
	s.Scanned.Add(1)
	s.updateRate()
}

// RecordFound adds one to the Found counter; safe for concurrent use.
func (s *SessionStats) RecordFound() {
	s.Found.Add(1)
}

// RecordError adds one to the Errors counter; safe for concurrent use.
func (s *SessionStats) RecordError() {
	s.Errors.Add(1)
}

// updateRate folds a new rate sample into the exponential moving average held
// in smoothedRate.
//
// The sample is the cumulative rate since startTime (Scanned divided by the
// elapsed seconds), not the rate over the most recent interval. Alpha = 0.3
// weights the newest sample, and the first sample seeds the average directly.
// Samples within the first 100ms of the session are skipped because the tiny
// denominator makes them unstable; lastUpdate is then left unchanged.
func (s *SessionStats) updateRate() {
	s.mu.Lock()
	defer s.mu.Unlock()

	// One clock read serves both the elapsed computation and lastUpdate.
	now := time.Now()
	elapsed := now.Sub(s.startTime).Seconds()
	if elapsed < 0.1 {
		return // avoid division by zero or unstable early rates
	}

	avgRate := float64(s.Scanned.Load()) / elapsed

	// Exponential smoothing: new = alpha*sample + (1-alpha)*old.
	const alpha = 0.3
	if s.smoothedRate == 0 {
		s.smoothedRate = avgRate
	} else {
		s.smoothedRate = alpha*avgRate + (1-alpha)*s.smoothedRate
	}
	s.lastUpdate = now
}

// Rate reports the smoothed scan rate in domains per second. It stays zero
// until a rate sample has been recorded.
func (s *SessionStats) Rate() float64 {
	s.mu.RLock()
	r := s.smoothedRate
	s.mu.RUnlock()
	return r
}

// ETA returns the estimated time remaining, or -1 if it cannot be estimated.
//
// It returns -1 when the total target count is unknown (zero or negative) or
// when the smoothed rate is below 0.001/s. It returns 0 once every target has
// been scanned. Estimates too large for time.Duration saturate at the maximum
// Duration instead of overflowing.
func (s *SessionStats) ETA() time.Duration {
	// Largest representable Duration (~292 years).
	const maxDuration = time.Duration(1<<63 - 1)

	total := s.TotalTargets.Load()
	if total <= 0 {
		return -1
	}

	remaining := total - s.Scanned.Load()
	if remaining <= 0 {
		return 0
	}

	rate := s.Rate()
	if rate < 0.001 {
		return -1
	}

	nanos := float64(remaining) / rate * float64(time.Second)
	// Converting an out-of-range float to an integer is implementation-defined
	// in Go, so clamp before converting.
	if nanos >= float64(maxDuration) {
		return maxDuration
	}
	return time.Duration(nanos)
}

// Elapsed returns the time since the session started.
func (s *SessionStats) Elapsed() time.Duration {
	// startTime changes after construction (Reset assigns it), and mu is the
	// lock updateRate holds when reading it, so read it under mu as well.
	s.mu.RLock()
	start := s.startTime
	s.mu.RUnlock()
	return time.Since(start)
}

// Progress returns the completion percentage in the range 0-100, or -1 if the
// total target count is unknown (zero or negative). If more targets were
// scanned than the declared total, the result is capped at 100.
func (s *SessionStats) Progress() float64 {
	total := s.TotalTargets.Load()
	if total <= 0 {
		return -1
	}
	done := s.Scanned.Load()
	if done >= total {
		return 100
	}
	return float64(done) / float64(total) * 100
}

// SessionStatsSnapshot is a point-in-time copy of session statistics. Being a
// plain value, it does not change as the session progresses.
type SessionStatsSnapshot struct {
	// Counter values at the moment of capture.
	Scanned, Found, Errors, TotalTargets int64

	Rate     float64       // smoothed rate in domains/sec
	ETA      time.Duration // estimated time remaining; -1 if unknown
	Elapsed  time.Duration // time since the session started
	Progress float64       // completion percentage; -1 if unknown
}

// Snapshot captures the current counters and derived metrics in one value.
// Each field is read independently, so under concurrent updates the fields
// may come from slightly different moments.
func (s *SessionStats) Snapshot() SessionStatsSnapshot {
	var snap SessionStatsSnapshot
	snap.Scanned = s.Scanned.Load()
	snap.Found = s.Found.Load()
	snap.Errors = s.Errors.Load()
	snap.TotalTargets = s.TotalTargets.Load()
	snap.Rate = s.Rate()
	snap.ETA = s.ETA()
	snap.Elapsed = s.Elapsed()
	snap.Progress = s.Progress()
	return snap
}

// Reset resets all counters and the session clock for a new session.
//
// Every field is written while holding mu. updateRate holds the same lock
// while it reads Scanned and startTime and writes lastUpdate, so the two
// cannot race, and a concurrent rate update sees either the old session or the
// new one. RecordScanned, RecordFound and RecordError increment their counters
// without taking mu, so increments racing with Reset may land on either side
// of it.
func (s *SessionStats) Reset(totalTargets int) {
	// One clock read so startTime and lastUpdate describe the same instant.
	now := time.Now()

	s.mu.Lock()
	defer s.mu.Unlock()

	s.Scanned.Store(0)
	s.Found.Store(0)
	s.Errors.Store(0)
	s.TotalTargets.Store(int64(totalTargets))
	s.startTime = now
	s.lastUpdate = now
	s.smoothedRate = 0
}
