import serial
import struct
import numpy as np
import math
import time

# --- CONFIGURATION ---
PORT = '/dev/rfcomm1'  # Alternative port used previously: '/dev/ttyACM2'
BAUD = 115200
SAMPLE_RATE = 16000     # Mic sample rate (Hz); used to label FFT bins, so it must match the mic hardware
MIN_FREQ = 1500          # Lowest accepted peak frequency (Hz)
MAX_FREQ = 2500          # Highest accepted peak frequency (Hz)
BASE_SPEED = 300        # Forward speed (0 to 1000)
MAX_SPEED = 1000        # Safety limit on wheel-speed magnitude
MAGNITUDE_THRESHOLD = 10000  # Minimum FFT magnitude at the peak bin (summed across all mics) to accept a detection

# --- SERIAL SETUP ---
# timeout=0 is risky when reading fixed-size blocks; a small timeout (0.1 s)
# lets each read wait briefly for data instead of returning immediately.
ser = serial.Serial(PORT, BAUD, timeout=0.1)

def get_mic_data():
    """
    Request one block of microphone samples from the robot and de-interleave it.

    Sends the -'U' request, then reads 1280 bytes: 160 samples x 4 channels
    x 2 bytes, decoded as interleaved little-endian signed 16-bit integers and
    widened to int32 to avoid 16-bit overflow in later arithmetic.

    Returns:
        (160, 4) int32 numpy array (rows = time, columns = channels in the
        order Right, Left, Back, Front), or None if the full block did not
        arrive within 500 ms.
    """
    # Flush stale bytes so the reply we read belongs to this request.
    ser.reset_input_buffer()

    # 1. Request data (command: -'U', 0), each packed as a signed byte.
    ser.write(struct.pack(">bb", -ord('U'), 0))

    # 2. Read 160 samples * 4 channels * 2 bytes.
    expected_bytes = 160 * 4 * 2
    data = bytearray()
    # monotonic() is immune to wall-clock adjustments (NTP, manual changes)
    # that could make a time.time()-based timeout fire early or never.
    deadline = time.monotonic() + 0.5  # 500 ms timeout

    while len(data) < expected_bytes:
        if time.monotonic() > deadline:
            return None
        # read(n) returns at most n bytes (possibly none on port timeout),
        # so the loop exits with exactly expected_bytes collected.
        data += ser.read(expected_bytes - len(data))

    # 3. Decode: '<i2' = little-endian 16-bit signed, widened to int32.
    raw = np.frombuffer(data, dtype='<i2').astype(np.int32)
    # De-interleave: rows = time, columns = channel (R, L, B, F).
    return raw.reshape(-1, 4)

def analyze_sound_max_peak(mics):
    """
    Estimate the bearing of the dominant in-band tone from multi-mic samples.

    Args:
        mics: (N, C) array of samples, rows = time, with at least 4 columns
            ordered (Right, Left, Back, Front) as produced by get_mic_data().

    Returns:
        (angle_degrees, is_valid). angle_degrees is atan2 of the Left-vs-Right
        and Front-vs-Back phase differences at the peak frequency bin, so 0
        lies along the Front direction and the sign follows the Left-Right
        phase difference. is_valid is False (with angle 0) when there is no
        non-DC bin, when the strongest non-DC bin lies outside
        [MIN_FREQ, MAX_FREQ], or when its magnitude summed across mics is
        below MAGNITUDE_THRESHOLD.
    """
    if len(mics) < 2:
        # Fewer than 2 samples gives no frequency bin besides DC.
        return 0, False

    # 1. FFT each channel (axis 0 = time) to find the dominant frequency.
    fft_res = np.fft.rfft(mics, axis=0)
    freqs = np.fft.rfftfreq(len(mics), d=1/SAMPLE_RATE)

    # Sum magnitudes across all mics to find the global peak. Skip the DC bin
    # (index 0): a constant offset in the raw samples is not sound, and if it
    # were the largest bin it would hide any real tone and make every frame
    # "invalid".
    total_mag = np.abs(fft_res).sum(axis=1)
    peak_idx = 1 + int(np.argmax(total_mag[1:]))
    peak_freq = freqs[peak_idx]
    peak_mag = total_mag[peak_idx]
    print("Peak Frequency:", peak_freq)
    print("peak magnitude:", peak_mag)

    # 2. Reject peaks outside the band of interest or too weak to trust.
    if not (MIN_FREQ <= peak_freq <= MAX_FREQ) or peak_mag < MAGNITUDE_THRESHOLD:
        return 0, False

    # 3. Phase differences at the peak bin.
    # Indices: 0:Right, 1:Left, 2:Back, 3:Front
    phases = np.angle(fft_res[peak_idx])
    # Re-wrapping through a unit phasor keeps each difference in (-pi, pi].
    y_strength = np.angle(np.exp(1j * (phases[1] - phases[0])))  # Left vs Right
    x_strength = np.angle(np.exp(1j * (phases[3] - phases[2])))  # Front vs Back

    # Calculate angle (0 degrees is straight ahead/Front)
    angle_rad = math.atan2(y_strength, x_strength)
    return math.degrees(angle_rad), True

def send_speed(left, right):
    """
    Send a differential-drive speed command to the robot.

    Packet (5 bytes): [-D][LEFT_LSB][LEFT_MSB][RIGHT_LSB][RIGHT_MSB], i.e. a
    signed command byte followed by two little-endian signed 16-bit speeds.

    Args:
        left, right: wheel speeds. They are truncated to int and clamped to
            +/-min(MAX_SPEED, 1000); the protocol range is -1000 to 1000.
    """
    # Honor the configured MAX_SPEED safety limit, but never exceed the
    # protocol's own +/-1000 range even if MAX_SPEED is set higher.
    limit = min(int(MAX_SPEED), 1000)
    left = max(min(int(left), limit), -limit)
    right = max(min(int(right), limit), -limit)

    # '<' = little-endian with no padding; 'b' = signed char, 'h' = int16.
    ser.write(struct.pack('<bhh', -ord('D'), left, right))

# --- MAIN LOOP ---
print("Starting Robot Listener...")
try:
    while True:
        mics = get_mic_data()

        if mics is not None:
            angle, valid = analyze_sound_max_peak(mics)
            #angle, valid = analyze_sound_multi_peaks(mics)
            #angle, valid = analyze_sound_hps(mics)

            if valid:
                print(f"Target at {angle:.1f}°")

                # --- STEERING LOGIC ---
                # Simple proportional controller on the bearing.
                # If angle is +90 (Left), we want: Right > Left
                turn_factor = angle * 2.5  # Gain (Tune this!)

                send_speed(BASE_SPEED - turn_factor, BASE_SPEED + turn_factor)
            else:
                # No valid in-band tone this frame: stop.
                send_speed(0, 0)

        # Small sleep to prevent serial bus choking
        time.sleep(0.05)

except KeyboardInterrupt:
    print("Stopping.")
finally:
    # Runs on Ctrl-C AND on any other exception (e.g. a serial error), so the
    # robot is not left driving on its last command. The nested finally
    # releases the port even if sending the stop command itself fails.
    try:
        send_speed(0, 0)
    finally:
        ser.close()