import math
import time
from ulab import numpy as np # Requires ulab firmware
import epuck2
import gc

# --- Detection settings ---
SAMPLE_RATE = 16000          # used to convert an FFT bin index into a frequency
MIN_FREQ = 1500              # lowest peak frequency accepted as a valid target
MAX_FREQ = 2500              # highest peak frequency accepted as a valid target
MAGNITUDE_THRESHOLD = 10000  # minimum summed peak magnitude for a valid target

# --- Motor settings ---
BASE_SPEED = 300             # speed both wheels get before the steering offset
MAX_SPEED = 1000             # upper clamp for a commanded wheel speed
MIN_SPEED = -1000            # lower clamp for a commanded wheel speed

# --- FFT settings ---
FFT_LEN = 128                # power of 2 for ulab
FFT_HALF_LEN = FFT_LEN // 2  # number of bins kept from each FFT result
NUM_MICS = 4

# --- Capture buffer layout ---
# One capture is reshaped to (SAMPLES_PER_MIC, NUM_MICS) int16 values, so the
# raw buffer size follows from the layout instead of being a magic number.
SAMPLES_PER_MIC = 160
BYTES_PER_SAMPLE = 2         # int16
MIC_BUFFER_SIZE = SAMPLES_PER_MIC * NUM_MICS * BYTES_PER_SAMPLE  # 1280 bytes

# Preallocate arrays once to save memory; they are reused on every iteration.
gc.collect()
#print("Memory before allocations:", gc.mem_free())
total_mag = np.zeros(FFT_HALF_LEN)              # per-bin magnitude summed over mics
phases = np.zeros(NUM_MICS)                     # per-mic phase at the peak bin
phase_re = np.zeros((NUM_MICS, FFT_HALF_LEN))   # per-mic real part of the FFT
phase_im = np.zeros((NUM_MICS, FFT_HALF_LEN))   # per-mic imaginary part of the FFT
mic_buffer = bytearray(MIC_BUFFER_SIZE)         # raw bytes filled by epuck2.get_mic_data
#print("Memory after allocations:", gc.mem_free())

def get_mic_data():
    """Fetch one block of microphone samples from the robot.

    Asks ``epuck2.get_mic_data`` to fill the preallocated module-level
    ``mic_buffer`` and reinterprets its bytes as int16 samples.

    Returns:
        A (160, 4) ulab array with one row per sample and one column per
        mic channel (Right, Left, Back, Front), or None on failure.
    """
    if not epuck2.get_mic_data(mic_buffer):
        print("Failed to acquire mic data.")
        return None

    # ulab's frombuffer takes dtype=np.int16 (not numpy's '<i2' string form)
    # and interprets the bytes in native order.
    # NOTE(review): assumes the buffer holds little-endian samples, which
    # matches the ESP32 -- confirm against the epuck2 driver.
    return np.frombuffer(mic_buffer, dtype=np.int16).reshape((160, 4))

def analyze_sound_max_peak(mics):
    """Estimate the direction of the dominant tone from the 4 mic channels.

    Memory-efficient: works in the preallocated module-level arrays
    ``total_mag``, ``phases``, ``phase_re`` and ``phase_im``.

    Steps:
      1. FFT the first FFT_LEN samples of every channel and sum the
         magnitudes of the first FFT_HALF_LEN bins across channels.
      2. Take the bin with the largest summed magnitude as the peak.
      3. Reject the frame if the peak frequency is outside
         [MIN_FREQ, MAX_FREQ] or its magnitude is below MAGNITUDE_THRESHOLD.
      4. Otherwise combine the Left-Right and Front-Back phase differences
         at the peak bin with atan2 to get an angle.

    Args:
        mics: 2-D array indexed as [sample, channel] with at least FFT_LEN
            rows and NUM_MICS columns (0=Right, 1=Left, 2=Back, 3=Front).

    Returns:
        (angle_degrees, is_valid). ``(0, False)`` when the frame is rejected.
    """
    # total_mag is rebound by the augmented assignment below, so it must be
    # declared global; the other arrays are only written through indexing.
    global total_mag

    # --- Reset total magnitude array ---
    total_mag[:] = 0

    # --- Compute FFT per channel ---
    for ch in range(NUM_MICS):
        gc.collect()  # free the previous channel's temporaries before the FFT
        fft_result = np.fft.fft(mics[:FFT_LEN, ch])

        # Extract each half-spectrum once and reuse it, instead of
        # re-evaluating .real / .imag for every use.
        real_half = fft_result.real[:FFT_HALF_LEN]
        imag_half = fft_result.imag[:FFT_HALF_LEN]

        # Keep the components so the phase at the peak can be read later.
        phase_re[ch, :] = real_half
        phase_im[ch, :] = imag_half

        # Sum magnitudes on the fly (only positive frequencies).
        total_mag += np.sqrt((real_half * real_half) + (imag_half * imag_half))

    # --- Find peak frequency ---
    # NOTE(review): the search covers every bin, including bin 0 (DC) and
    # out-of-band bins; if one of those is the strongest the frame is
    # rejected below even when an in-band tone is present -- confirm that
    # this is the intended behavior.
    peak_idx = np.argmax(total_mag)
    peak_mag = total_mag[peak_idx]

    peak_freq = peak_idx * (SAMPLE_RATE / FFT_LEN)
    print("Freq:", peak_freq, "Mag:", peak_mag)

    # Reject if frequency out of range or too weak
    if not (MIN_FREQ <= peak_freq <= MAX_FREQ) or peak_mag < MAGNITUDE_THRESHOLD:
        return 0, False

    # --- Compute phases at peak ---
    for ch in range(NUM_MICS):
        phases[ch] = math.atan2(phase_im[ch, peak_idx], phase_re[ch, peak_idx])

    # Mic layout: 0=Right, 1=Left, 2=Back, 3=Front
    dy_angle = phases[1] - phases[0]  # Left - Right
    dx_angle = phases[3] - phases[2]  # Front - Back

    # Wrap to [-pi, pi] so a difference that crosses the +/-pi boundary does
    # not show up as a jump of almost 2*pi.
    dy_angle = (dy_angle + math.pi) % (2 * math.pi) - math.pi
    dx_angle = (dx_angle + math.pi) % (2 * math.pi) - math.pi

    # Compute angle in degrees
    angle_rad = math.atan2(dy_angle, dx_angle)
    return math.degrees(angle_rad), True



# --- MAIN LOOP ---
# Repeatedly capture audio, estimate the bearing of the target tone and steer
# towards it; the motors are stopped whenever no valid target is detected.

# Proportional steering gain: wheel-speed offset applied per degree of angle.
TURN_GAIN = 2.5

print("Starting ESP32 Robot Listener...")

try:
    while True:
        mics = get_mic_data()

        if mics is not None:
            # We catch errors here to prevent crashing on math anomalies
            try:
                angle, valid = analyze_sound_max_peak(mics)

                if valid:
                    print("Target: {:.1f}".format(angle))

                    turn_factor = angle * TURN_GAIN

                    # Differential drive: slow one wheel and speed up the
                    # other by the same amount, then clamp to the speed range.
                    left_speed = int(BASE_SPEED - turn_factor)
                    right_speed = int(BASE_SPEED + turn_factor)
                    left_speed = max(MIN_SPEED, min(MAX_SPEED, left_speed))
                    right_speed = max(MIN_SPEED, min(MAX_SPEED, right_speed))

                    epuck2.set_motors_speed(left_speed, right_speed)
                else:
                    epuck2.set_motors_speed(0, 0)
            except Exception as e:
                print("Error:", e)
                gc.collect()

        time.sleep(0.05)

except KeyboardInterrupt:
    print("Stopping.")
finally:
    # Stop the motors on every exit path, not only on Ctrl-C, so an
    # unexpected exception cannot leave the robot driving at its last speed.
    epuck2.set_motors_speed(0, 0)