#!/usr/bin/env python3
"""
SMB Traffic Monitor for TrueNAS Core
Monitors SMB traffic on ports 445 and 139 using tcpdump and reports to Graphite
"""

import socket
import time
import argparse
from datetime import datetime
from collections import defaultdict
import threading
import sys
import subprocess
import re

class SMBTrafficMonitor:
    """Count SMB traffic seen by tcpdump on one interface and report rates to Graphite.

    A capture thread parses ``tcpdump -v`` output for the TCP ports in
    ``smb_ports`` and accumulates byte/packet counters. A reporter thread
    periodically turns those counters into per-second rates, sends them to
    Graphite as plaintext ``name value timestamp`` lines and prints them to
    the console.

    Args:
        interface: Name of the network interface handed to tcpdump/ifconfig.
        graphite_host: Hostname or IP address of the Graphite server.
        graphite_port: TCP port of the Graphite listener.
        interval: Seconds between two reports; must be greater than zero.
        debug: When true, print the first raw lines and parsed packets.

    Raises:
        ValueError: If ``interval`` is not greater than zero.
    """

    # First line of a packet in `tcpdump -v` output; captures the IP length.
    _IP_LENGTH_RE = re.compile(r'proto TCP \(6\), length (\d+)\)')
    # Indented second line of a packet: "src_ip.src_port > dst_ip.dst_port:".
    _ADDR_RE = re.compile(
        r'^\s+(\d+\.\d+\.\d+\.\d+)\.(\d+)\s*>\s*'
        r'(\d+\.\d+\.\d+\.\d+)\.(\d+):'
    )
    # Added to the IP length to approximate the on-wire frame size.
    _ETHERNET_HEADER_BYTES = 14

    def __init__(self, interface, graphite_host, graphite_port=2003, interval=10, debug=False):
        # A zero interval divides by zero in the reporter and a negative one
        # makes time.sleep() raise, so reject both before any thread starts.
        if interval <= 0:
            raise ValueError(
                f"interval must be a positive number of seconds, got {interval!r}"
            )

        self.interface = interface
        self.graphite_host = graphite_host
        self.graphite_port = graphite_port
        self.interval = interval
        self.running = False
        self.debug = debug

        # Counters for the current reporting window; guarded by self.lock.
        self.metrics = {
            'bytes_in': 0,
            'bytes_out': 0,
            'packets_in': 0,
            'packets_out': 0,
            'connections': defaultdict(int)
        }
        self.lock = threading.Lock()

        # SMB ports
        self.smb_ports = {445, 139}

        # IP length taken from the most recent header line, waiting for the
        # address line that follows it. Only touched by the capture thread.
        self._pending_ip_length = None
        # Running tcpdump process, kept so stop() can terminate it.
        self._process = None

    def get_local_ip(self):
        """Return the first ``inet`` address ifconfig prints for the interface.

        Returns:
            The second whitespace-separated field of the first line containing
            ``"inet "``, or ``"0.0.0.0"`` when ifconfig cannot be run or prints
            no such line.
        """
        try:
            result = subprocess.run(
                ['ifconfig', self.interface],
                capture_output=True,
                text=True
            )
        except (OSError, ValueError) as e:
            print(f"Error getting local IP: {e}")
            return "0.0.0.0"

        for line in result.stdout.splitlines():
            if 'inet ' in line:
                parts = line.split()
                if len(parts) >= 2:
                    return parts[1]

        return "0.0.0.0"

    def _process_line(self, line, local_ip):
        """Consume one line of ``tcpdump -v`` output and account a packet if complete.

        A packet is described by a header line (carrying the IP length)
        immediately followed by an address line. Header lines only store the
        length. An address line is accepted only when it directly follows a
        header line, so continuation lines and orphan address lines are never
        counted.

        Args:
            line: One raw output line from tcpdump.
            local_ip: Address of the monitored interface, used to decide the
                traffic direction.

        Returns:
            ``None`` when the line does not complete a packet. Otherwise a tuple
            ``(src_ip, src_port, dst_ip, dst_port, packet_size, direction)``
            where ``direction`` is ``'in'``, ``'out'`` or ``None`` (packet
            parsed but not SMB traffic of ``local_ip``; nothing was counted).
        """
        header = self._IP_LENGTH_RE.search(line)
        if header:
            self._pending_ip_length = int(header.group(1))
            return None

        # Take the pending length and clear it so it can be used at most once.
        ip_length = self._pending_ip_length
        self._pending_ip_length = None
        if ip_length is None:
            return None

        addr = self._ADDR_RE.search(line)
        if addr is None:
            return None

        src_ip, dst_ip = addr.group(1), addr.group(3)
        src_port, dst_port = int(addr.group(2)), int(addr.group(4))
        packet_size = ip_length + self._ETHERNET_HEADER_BYTES

        if dst_port in self.smb_ports and dst_ip == local_ip:
            # Incoming to SMB server
            direction = 'in'
            connection_key = f"{src_ip}:{src_port}"
        elif src_port in self.smb_ports and src_ip == local_ip:
            # Outgoing from SMB server
            direction = 'out'
            connection_key = f"{dst_ip}:{dst_port}"
        else:
            direction = None
            connection_key = None

        if direction is not None:
            with self.lock:
                if direction == 'in':
                    self.metrics['bytes_in'] += packet_size
                    self.metrics['packets_in'] += 1
                else:
                    self.metrics['bytes_out'] += packet_size
                    self.metrics['packets_out'] += 1
                self.metrics['connections'][connection_key] += 1

        return (src_ip, src_port, dst_ip, dst_port, packet_size, direction)

    def capture_packets(self):
        """Run tcpdump and feed its output into the counters until stopped.

        The loop ends when ``self.running`` becomes false, when tcpdump closes
        its output, or when reading fails. tcpdump is always terminated on the
        way out.

        Raises:
            SystemExit: With status 1 when tcpdump cannot be started.
        """
        local_ip = self.get_local_ip()
        print(f"Monitoring SMB traffic on {self.interface} (IP: {local_ip})")
        print(f"Reporting to Graphite at {self.graphite_host}:{self.graphite_port}")
        print(f"Debug mode: {self.debug}")

        # Build the capture filter from smb_ports so the two cannot drift apart.
        tcpdump_filter = " or ".join(
            f"tcp port {port}" for port in sorted(self.smb_ports, reverse=True)
        )

        cmd = [
            'tcpdump',
            '-i', self.interface,
            '-l',  # Line buffered
            '-n',  # Don't resolve hostnames
            '-v',  # Verbose - gives us IP length
            '-tt', # Print timestamp as seconds since epoch
            tcpdump_filter
        ]

        try:
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                universal_newlines=True,
                errors='replace',  # undecodable bytes must not abort the capture
                bufsize=1
            )
        except FileNotFoundError:
            print("Error: tcpdump not found. Please install tcpdump.")
            sys.exit(1)
        except PermissionError:
            print("Error: This script requires root privileges to run tcpdump")
            sys.exit(1)

        self._process = process
        self._pending_ip_length = None
        print("tcpdump started successfully...")

        line_count = 0
        parsed_count = 0

        try:
            while self.running:
                try:
                    line = process.stdout.readline()
                except (OSError, ValueError) as e:
                    # A failing pipe will not recover, so stop instead of looping.
                    if self.running:
                        print(f"Error reading tcpdump output: {e}")
                    break

                if not line:
                    # Empty read means end of file: tcpdump closed its output.
                    if self.running:
                        print("tcpdump process ended unexpectedly")
                    break

                line_count += 1
                verbose = self.debug and line_count <= 10

                if verbose:
                    print(f"Raw line {line_count}: {line.strip()}")

                packet = self._process_line(line, local_ip)
                if packet is None:
                    # A length is pending only right after a header line.
                    if verbose and self._pending_ip_length is not None:
                        print(f"  -> Found IP length: {self._pending_ip_length}")
                    continue

                src_ip, src_port, dst_ip, dst_port, packet_size, direction = packet
                parsed_count += 1

                if self.debug and parsed_count <= 10:
                    print(f"Parsed packet {parsed_count}: {src_ip}:{src_port} -> {dst_ip}:{dst_port} ({packet_size} bytes on wire)")

                if line_count % 1000 == 0:
                    print(f"Processed {line_count} lines, {parsed_count} packets parsed successfully")

                if verbose and direction == 'in':
                    print(f"  -> SMB INCOMING")
                elif verbose and direction == 'out':
                    print(f"  -> SMB OUTGOING")
        finally:
            # Cleanup runs on every exit path, including unexpected errors.
            self._process = None
            if process.poll() is None:
                process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
            process.stdout.close()

    def _send_payload(self, payload):
        """Send already formatted plaintext metric lines over one TCP connection.

        Returns:
            True when the payload was sent, False when connecting or sending
            failed (the error is printed).
        """
        try:
            # The context manager closes the socket even when sending fails.
            with socket.create_connection(
                (self.graphite_host, self.graphite_port), timeout=5
            ) as sock:
                sock.sendall(payload.encode())
        except (OSError, OverflowError, ValueError) as e:
            print(f"Error sending to Graphite: {e}")
            return False
        return True

    def send_to_graphite(self, metric_name, value, timestamp):
        """Send a single metric to Graphite.

        Args:
            metric_name: Dotted metric path.
            value: Metric value.
            timestamp: Sample time in seconds since the epoch.

        Returns:
            True on success, False if the metric could not be delivered.
        """
        return self._send_payload(f"{metric_name} {value} {timestamp}\n")

    def report_metrics(self):
        """Report per-second rates every ``interval`` seconds while running.

        The counters are copied and reset while holding the lock. Sending and
        printing happen afterwards so a slow or unreachable Graphite server
        never blocks the capture thread.
        """
        base_metric = "truenas.smb"
        window_start = time.monotonic()

        while self.running:
            time.sleep(self.interval)

            timestamp = int(time.time())

            with self.lock:
                window_end = time.monotonic()
                bytes_in = self.metrics['bytes_in']
                bytes_out = self.metrics['bytes_out']
                packets_in = self.metrics['packets_in']
                packets_out = self.metrics['packets_out']
                active_connections = len(self.metrics['connections'])

                # Reset counters
                self.metrics['bytes_in'] = 0
                self.metrics['bytes_out'] = 0
                self.metrics['packets_in'] = 0
                self.metrics['packets_out'] = 0
                self.metrics['connections'].clear()

            # Divide by the real window length: it includes the time the
            # previous report spent sending, not just the nominal interval.
            elapsed = window_end - window_start
            window_start = window_end
            if elapsed <= 0:
                elapsed = self.interval

            # Calculate rates (per second)
            bytes_in_rate = bytes_in / elapsed
            bytes_out_rate = bytes_out / elapsed
            packets_in_rate = packets_in / elapsed
            packets_out_rate = packets_out / elapsed

            samples = (
                ("bytes_in", bytes_in_rate),
                ("bytes_out", bytes_out_rate),
                ("packets_in", packets_in_rate),
                ("packets_out", packets_out_rate),
                ("active_connections", active_connections),
                ("total_bandwidth", bytes_in_rate + bytes_out_rate),
            )
            # One connection for the whole batch instead of one per metric.
            self._send_payload("".join(
                f"{base_metric}.{name} {value} {timestamp}\n"
                for name, value in samples
            ))

            # Console output
            print(f"\n[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Metrics:")
            print(f"  Bytes In:  {bytes_in_rate:.2f} B/s ({bytes_in_rate/1024:.2f} KB/s)")
            print(f"  Bytes Out: {bytes_out_rate:.2f} B/s ({bytes_out_rate/1024:.2f} KB/s)")
            print(f"  Packets In:  {packets_in_rate:.2f} pkt/s")
            print(f"  Packets Out: {packets_out_rate:.2f} pkt/s")
            print(f"  Active Connections: {active_connections}")

    def stop(self):
        """Ask both worker loops to finish and terminate tcpdump if it is running."""
        self.running = False
        process = self._process
        if process is not None and process.poll() is None:
            # Terminating tcpdump unblocks the capture thread's readline().
            process.terminate()

    def start(self):
        """Start the capture and reporter threads and block until shutdown.

        Returns normally after Ctrl-C.

        Raises:
            SystemExit: With status 1 when the capture thread ends on its own
                (tcpdump missing, not permitted, or exited), because nothing
                would be measured any more.
        """
        self.running = True

        # Start packet capture thread
        capture_thread = threading.Thread(target=self.capture_packets)
        capture_thread.daemon = True
        capture_thread.start()

        # Start reporting thread
        report_thread = threading.Thread(target=self.report_metrics)
        report_thread.daemon = True
        report_thread.start()

        try:
            # Keep the main thread alive for as long as packets are captured.
            while capture_thread.is_alive():
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nStopping monitor...")
            self.stop()
            capture_thread.join(timeout=2)
            return

        # sys.exit() inside the capture thread only ends that thread, so the
        # failure has to be turned into a process exit here.
        self.running = False
        print("Packet capture stopped; shutting down.")
        sys.exit(1)

def _positive_int(text):
    """Convert an argparse value to an int and require it to be greater than zero."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main(argv=None):
    """Parse the command line and run the SMB traffic monitor.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use
            ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 2 when the arguments are invalid, including a
            reporting interval that is not a positive integer.
    """
    parser = argparse.ArgumentParser(description='Monitor SMB traffic and report to Graphite')
    parser.add_argument('-i', '--interface', required=True, help='Network interface to monitor (e.g., em0, igb0, vmx0)')
    parser.add_argument('-g', '--graphite-host', required=True, help='Graphite server hostname or IP')
    parser.add_argument('-p', '--graphite-port', type=int, default=2003, help='Graphite port (default: 2003)')
    # The interval is used as a sleep duration and as a divisor for the
    # reported rates, so zero and negative values are rejected up front.
    parser.add_argument('-t', '--interval', type=_positive_int, default=10, help='Reporting interval in seconds (default: 10)')
    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug output')

    args = parser.parse_args(argv)

    monitor = SMBTrafficMonitor(
        interface=args.interface,
        graphite_host=args.graphite_host,
        graphite_port=args.graphite_port,
        interval=args.interval,
        debug=args.debug
    )

    monitor.start()

# Run the monitor only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()