diff --git a/blacklist.go b/blacklist.go index b705c0c..c5628ae 100644 --- a/blacklist.go +++ b/blacklist.go @@ -127,6 +127,7 @@ func extractIP(remoteAddr string, logger *zap.Logger) string { return host } +// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map. // LoadIPBlacklistFromFile loads IP addresses from a file into the provided map. func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error { bl.logger.Debug("Loading IP blacklist", zap.String("path", path)) @@ -141,6 +142,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[ scanner := bufio.NewScanner(file) validEntries := 0 totalLines := 0 + invalidEntries := 0 for scanner.Scan() { totalLines++ @@ -156,6 +158,9 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[ zap.Int("line", totalLines), zap.String("entry", line), ) + invalidEntries++ + // If you want the entire load to fail if any single IP entry is invalid, uncomment the line below + // return fmt.Errorf("failed to add IP entry %s : %w", line, err) } else { validEntries++ } @@ -169,6 +174,7 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[ bl.logger.Info("IP blacklist loaded", zap.String("path", path), zap.Int("valid_entries", validEntries), + zap.Int("invalid_entries", invalidEntries), zap.Int("total_lines", totalLines), ) return nil diff --git a/caddywaf.go b/caddywaf.go index 2dc99e8..15dad2a 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "net" "net/http" "os" "strings" @@ -29,6 +28,7 @@ var ( _ caddy.Provisioner = (*Middleware)(nil) _ caddyhttp.MiddlewareHandler = (*Middleware)(nil) _ caddyfile.Unmarshaler = (*Middleware)(nil) + _ caddy.Validator = (*Middleware)(nil) ) // ==================== Initialization and Setup ==================== @@ -209,19 +209,18 @@ func (m *Middleware) Provision(ctx caddy.Context) error { } // Load IP blacklist - m.ipBlacklist = NewCIDRTrie() - m.logger.Debug("ipBlacklist initialized in Provision", zap.Bool("isNil", m.ipBlacklist == nil)) if m.IPBlacklistFile != "" { - err = m.loadIPBlacklistIntoMap(m.IPBlacklistFile, m.ipBlacklist) + m.ipBlacklist = NewCIDRTrie() + err = m.loadIPBlacklist(m.IPBlacklistFile, m.ipBlacklist) if err != nil { return fmt.Errorf("failed to load IP blacklist: %w", err) } } // Load DNS blacklist - m.dnsBlacklist = make(map[string]struct{}) // Changed to map[string]struct{} if m.DNSBlacklistFile != "" { - err = m.blacklistLoader.LoadDNSBlacklistFromFile(m.DNSBlacklistFile, m.dnsBlacklist) + m.dnsBlacklist = make(map[string]struct{}) + err = m.loadDNSBlacklist(m.DNSBlacklistFile, m.dnsBlacklist) if err != nil { return fmt.Errorf("failed to load DNS blacklist: %w", err) } @@ -414,25 +413,21 @@ func (m *Middleware) ReloadConfig() error { defer m.mu.Unlock() m.logger.Info("Reloading WAF configuration") - - newIPBlacklist := NewCIDRTrie() if m.IPBlacklistFile != "" { - if err := m.loadIPBlacklistIntoMap(m.IPBlacklistFile, newIPBlacklist); err != nil { + newIPBlacklist := NewCIDRTrie() + if err := m.loadIPBlacklist(m.IPBlacklistFile, newIPBlacklist); err != nil { m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err)) return fmt.Errorf("failed to reload IP blacklist: %v", err) } - } else { - m.logger.Debug("No IP blacklist file specified, skipping reload") + m.ipBlacklist = newIPBlacklist } - - newDNSBlacklist := make(map[string]struct{}) if m.DNSBlacklistFile != "" { - if err := m.loadDNSBlacklistIntoMap(m.DNSBlacklistFile, newDNSBlacklist); err != nil { + newDNSBlacklist := make(map[string]struct{}) + if err := m.loadDNSBlacklist(m.DNSBlacklistFile, newDNSBlacklist); err != nil { m.logger.Error("Failed to reload DNS blacklist", zap.String("file", m.DNSBlacklistFile), zap.Error(err)) return fmt.Errorf("failed to reload DNS blacklist: %v", err) } - } else { - m.logger.Debug("No DNS blacklist file specified, skipping reload") + m.dnsBlacklist = newDNSBlacklist } // Call the external loadRules function @@ -441,62 +436,38 @@ func (m *Middleware) ReloadConfig() error { return fmt.Errorf("failed to reload rules: %v", err) } - m.ipBlacklist = newIPBlacklist - m.dnsBlacklist = newDNSBlacklist - m.logger.Info("WAF configuration reloaded successfully") - return nil } -func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap *CIDRTrie) error { - content, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("failed to read IP blacklist file: %v", err) +func (m *Middleware) loadIPBlacklist(path string, blacklistMap *CIDRTrie) error { + if _, err := os.Stat(path); os.IsNotExist(err) { + m.logger.Warn("Skipping IP blacklist load, file does not exist", zap.String("file", path)) + return nil } - lines := strings.Split(string(content), "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - if !strings.Contains(line, "/") { - // Handle single IP addresses - ip := net.ParseIP(line) - if ip == nil { - m.logger.Warn("Skipping invalid IP address format in blacklist", zap.String("address", line)) - continue - } - - if ip.To4() != nil { - line = line + "/32" - } else { - line = line + "/128" - } - } + blacklist := make(map[string]struct{}) + err := m.blacklistLoader.LoadIPBlacklistFromFile(path, blacklist) + if err != nil { + return fmt.Errorf("failed to load IP blacklist: %w", err) + } - if err := blacklistMap.Insert(line); err != nil { - m.logger.Warn("Failed to insert CIDR into trie", zap.String("cidr", line), zap.Error(err)) - } + // Convert the map to CIDRTrie + for ip := range blacklist { + blacklistMap.Insert(ip) } return nil } -func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[string]struct{}) error { - content, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("failed to read DNS blacklist file: %v", err) +func (m *Middleware) loadDNSBlacklist(path string, blacklistMap map[string]struct{}) error { + if _, err := os.Stat(path); os.IsNotExist(err) { + m.logger.Warn("Skipping DNS blacklist load, file does not exist", zap.String("file", path)) + return nil } - lines := strings.Split(string(content), "\n") - for _, line := range lines { - line = strings.ToLower(strings.TrimSpace(line)) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - blacklistMap[line] = struct{}{} // Changed to struct{}{} + err := m.blacklistLoader.LoadDNSBlacklistFromFile(path, blacklistMap) + if err != nil { + return fmt.Errorf("failed to load DNS blacklist: %w", err) } return nil } @@ -570,3 +541,11 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } return m.configLoader.UnmarshalCaddyfile(d, m) } + +// Validate implements caddy.Validator. +func (m *Middleware) Validate() error { + if m.logLevel == 0 { + m.logLevel = zapcore.InfoLevel // Default log level + } + return nil +} diff --git a/handler.go b/handler.go index 4ee9a35..467a732 100644 --- a/handler.go +++ b/handler.go @@ -11,6 +11,9 @@ import ( "go.uber.org/zap/zapcore" ) +type ContextKeyLogId string +type ContextKeyRule string + // ServeHTTP implements caddyhttp.Handler. func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { logID := uuid.New().String() @@ -41,7 +44,7 @@ func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next cadd err := next.ServeHTTP(recorder, r) // Phase 3: Response Header analysis - if m.isResponseHeaderPhaseBlocked(recorder, r, 3, state) { + if m.isPhaseBlocked(recorder, r, 3, state) { return nil // Request blocked in Phase 3, short-circuit } @@ -83,17 +86,6 @@ func (m *Middleware) isPhaseBlocked(w http.ResponseWriter, r *http.Request, phas return false } -// isResponseHeaderPhaseBlocked encapsulates the response header phase handling and blocking check logic. -func (m *Middleware) isResponseHeaderPhaseBlocked(recorder *responseRecorder, r *http.Request, phase int, state *WAFState) bool { - m.handlePhase(recorder, r, phase, state) - if state.Blocked { - m.incrementBlockedRequestsMetric() - recorder.WriteHeader(state.StatusCode) - return true - } - return false -} - // logRequestStart logs the start of WAF evaluation. func (m *Middleware) logRequestStart(r *http.Request, logID string) { m.logger.Info("WAF request evaluation started", @@ -126,15 +118,20 @@ func (m *Middleware) initializeWAFState() *WAFState { func (m *Middleware) handleResponseBodyPhase(recorder *responseRecorder, r *http.Request, state *WAFState) { // No need to check if recorder.body is nil here, it's always initialized in NewResponseRecorder body := recorder.BodyString() - m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", r.Context().Value(ContextKeyLogId("logID")).(string))) + logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string) + + if !ok { + m.logger.Error("Log ID not found in context") + return + } + m.logger.Debug("Response body captured for Phase 4 analysis", zap.String("log_id", logID)) for _, rule := range m.Rules[4] { if rule.regex.MatchString(body) { - m.processRuleMatch(recorder, r, &rule, body, state) - if state.Blocked { - m.incrementBlockedRequestsMetric() + if m.processRuleMatch(recorder, r, &rule, body, state) { return } + } } } @@ -191,9 +188,16 @@ func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecor } w.WriteHeader(recorder.StatusCode()) + logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string) + + if !ok { + m.logger.Error("Log ID not found in context") + return + } + _, err := w.Write(recorder.body.Bytes()) // Copy body from recorder to original writer if err != nil { - m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", r.Context().Value("logID").(string))) + m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", logID)) } } @@ -221,7 +225,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, zap.String("message", "Request blocked by country"), ) - m.logger.Debug("Country blocking phase completed - blocked by country") return } m.logger.Debug("Country blocking phase completed - not blocked") @@ -235,7 +238,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, zap.String("message", "Request blocked by rate limit"), ) - m.logger.Debug("Rate limiting phase completed - blocked by rate limit") return } m.logger.Debug("Rate limiting phase completed - not blocked") @@ -254,7 +256,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP, zap.String("message", "Request blocked by IP blacklist"), ) - m.logger.Debug("IP blacklist phase completed - blocked") return } } else { @@ -269,7 +270,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr, zap.String("message", "Request blocked by IP blacklist"), ) - m.logger.Debug("IP blacklist phase completed - blocked") return } } @@ -281,7 +281,6 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i zap.String("message", "Request blocked by DNS blacklist"), zap.String("host", r.Host), ) - m.logger.Debug("DNS blacklist phase completed - blocked") return } @@ -292,6 +291,7 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i } m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules))) + for _, rule := range rules { m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets))) @@ -338,16 +338,16 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i ) if phase == 3 || phase == 4 { if recorder, ok := w.(*responseRecorder); ok { - if !m.processRuleMatch(recorder, r, &rule, value, state) { + if m.processRuleMatch(recorder, r, &rule, value, state) { return // Stop processing if the rule match indicates blocking } } else { - if !m.processRuleMatch(w, r, &rule, value, state) { + if m.processRuleMatch(w, r, &rule, value, state) { return // Stop processing if the rule match indicates blocking } } } else { - if !m.processRuleMatch(w, r, &rule, value, state) { + if m.processRuleMatch(w, r, &rule, value, state) { return // Stop processing if the rule match indicates blocking } } diff --git a/logging.go b/logging.go index 908351f..13503ce 100644 --- a/logging.go +++ b/logging.go @@ -36,7 +36,7 @@ var sensitiveKeys = []string{ "routing", // Routing number "mfa", // Multi-factor authentication code "otp", // One-time password - "code", // Generic code + //"code", // Generic code <------ REMOVED THIS } var sensitiveKeysMutex sync.RWMutex // Add mutex for thread safety when modifying diff --git a/request.go b/request.go index f2911ff..e94a220 100644 --- a/request.go +++ b/request.go @@ -18,9 +18,6 @@ type RequestValueExtractor struct { redactSensitiveData bool // Add this field } -// Define a custom type for context keys -type ContextKeyLogId string - // Extraction Target Constants - Improved Readability and Maintainability const ( TargetMethod = "METHOD" diff --git a/response.go b/response.go index 49bf521..50c435c 100644 --- a/response.go +++ b/response.go @@ -29,14 +29,19 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state w.WriteHeader(resp.StatusCode) _, err := w.Write([]byte(resp.Body)) if err != nil { - m.logger.Error("Failed to write custom block response body", zap.Error(err), zap.Int("status_code", resp.StatusCode), zap.String("log_id", r.Context().Value("logID").(string))) + logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string) + if !ok { + m.logger.Error("Log ID not found in context, cannot log custom response error") + return + } + m.logger.Error("Failed to write custom block response body", zap.Error(err), zap.Int("status_code", resp.StatusCode), zap.String("log_id", logID)) } return } // Default blocking behavior logID := uuid.New().String() - if logIDCtx, ok := r.Context().Value("logID").(string); ok { + if logIDCtx, ok := r.Context().Value(ContextKeyLogId("logID")).(string); ok { logID = logIDCtx } @@ -70,7 +75,11 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state w.WriteHeader(statusCode) } else { // Debug log when response is already written, including log_id - logID, _ := r.Context().Value("logID").(string) + logID, ok := r.Context().Value(ContextKeyLogId("logID")).(string) + if !ok { + m.logger.Error("Log ID not found in context, cannot log blockRequest when response already written") + return + } m.logger.Debug("blockRequest called but response already written", zap.Int("intended_status_code", statusCode), diff --git a/types.go b/types.go index be57c83..8a11f0e 100644 --- a/types.go +++ b/types.go @@ -25,9 +25,6 @@ var ( _ caddyfile.Unmarshaler = (*Middleware)(nil) ) -// Define a custom type for context keys -type ContextKeyRule string - // Define custom types for rule hits type RuleID string type HitCount int