
"""
Gravitational-orbital-modeling
Newton-vs-Orbital_Newton.py

Newtonian 3 body orbit for comparison
Copyright 2025 Malcolm J. Macleod (malcolm@codingthecosmos.com)

* This software is free to use, modify, and distribute under Creative Commons 
 * https://creativecommons.org/licenses/by-nc-sa/4.0/ (CC BY-NC-SA 4.0). 
 * Any derivative works or redistributions must include appropriate credit to 
 * Malcolm Macleod and this permission notice shall
 * be included in all copies or substantial portions of the Software.
 
"""


import numpy as np
import matplotlib.pyplot as plt

class Body:
    """State of one body in the 2-D simulation.

    Holds a position and a velocity, each an independent length-2 float64
    NumPy array so they can be updated in place (``+=``) during integration.
    No mass is stored on the body.
    """

    def __init__(self, x, y, vx, vy):
        state = np.array([[x, y], [vx, vy]], dtype=np.float64)
        # Copy each row so position and velocity own separate buffers.
        self.position = state[0].copy()
        self.velocity = state[1].copy()

class ThreeBodySystem:
    """Planar Newtonian three-body system with equal unit masses.

    Each pair attracts with magnitude ``G / r**2``.  No masses are used, so
    the values returned by ``calculate_forces`` are also the accelerations.
    Separations are clamped from below by ``min_distance`` so that coincident
    bodies do not cause a division by zero.

    Args:
        bodies: exactly three objects exposing ``position`` and ``velocity``
            float arrays of length 2 (normally ``Body``).  They are updated
            in place by ``update_leapfrog``.
        G: gravitational constant used for forces and potential energy.

    Raises:
        ValueError: if ``bodies`` does not contain exactly three items.
    """

    def __init__(self, bodies, G):
        if len(bodies) != 3:
            raise ValueError("ThreeBodySystem requires exactly 3 Body objects")
        self.bodies = bodies
        self.G = G
        self.min_distance = 1e-10
        # (state_key, forces) from the most recent force evaluation done by
        # update_leapfrog; lets the next step skip a redundant evaluation.
        self._force_cache = None

    def calculate_forces(self):
        """Return the gravitational force (= acceleration) on each body.

        Returns:
            List of three length-2 arrays.  Entry i is the sum over j != i of
            ``G / r_ij**2`` along the unit vector pointing from body i to body j.
        """
        forces = [np.zeros(2) for _ in range(3)]
        for i in range(3):
            for j in range(i+1, 3):
                r_vec = self.bodies[j].position - self.bodies[i].position
                r_mag = max(np.linalg.norm(r_vec), self.min_distance)
                force_mag = self.G / r_mag**2
                force_dir = r_vec / r_mag
                forces[i] += force_mag * force_dir
                # Equal and opposite contribution on body j.
                forces[j] -= force_mag * force_dir
        return forces

    def calculate_energy(self):
        """Return total energy: sum of 0.5*|v|**2 plus pairwise -G/r (clamped r)."""
        kinetic = 0.5 * sum(np.sum(body.velocity**2) for body in self.bodies)
        potential = 0.0
        for i in range(3):
            for j in range(i+1, 3):
                r = max(np.linalg.norm(self.bodies[j].position - self.bodies[i].position), self.min_distance)
                potential -= self.G / r
        return kinetic + potential

    def calculate_angular_momentum(self):
        """Return the total z angular momentum about the origin, sum of x*vy - y*vx."""
        return sum(body.position[0]*body.velocity[1] - body.position[1]*body.velocity[0] for body in self.bodies)

    def _state_key(self):
        # Everything calculate_forces depends on; the bytes comparison is exact.
        return (self.G, self.min_distance, self.get_positions().tobytes())

    def _forces_for_current_state(self):
        """Return ``calculate_forces()`` for the current state, reusing the last
        result when positions, G and min_distance are exactly unchanged."""
        key = self._state_key()
        cache = getattr(self, "_force_cache", None)
        if cache is not None and cache[0] == key:
            return cache[1]
        forces = self.calculate_forces()
        self._force_cache = (key, forces)
        return forces

    def update_leapfrog(self, dt=1.0):
        """Advance one kick-drift-kick (velocity Verlet) step of size ``dt``.

        The forces used by the closing half-kick are exactly the forces needed
        by the opening half-kick of the next step.  They are therefore cached
        and reused, which halves the number of force evaluations.  Results are
        bit-identical to evaluating the forces twice.  The cache is bypassed
        automatically if positions, G or min_distance change between steps.

        Returns:
            New positions as a flat array [x1, y1, x2, y2, x3, y3].
        """
        forces = self._forces_for_current_state()
        for i, body in enumerate(self.bodies):
            body.velocity += 0.5 * forces[i] * dt
        for body in self.bodies:
            body.position += body.velocity * dt
        # Positions moved, so this is a fresh evaluation; it is cached for the
        # next step's opening kick.
        forces = self._forces_for_current_state()
        for i, body in enumerate(self.bodies):
            body.velocity += 0.5 * forces[i] * dt
        return self.get_positions()

    def get_positions(self):
        """Return all positions as a new flat array [x1, y1, x2, y2, x3, y3]."""
        return np.concatenate([body.position for body in self.bodies])

def load_data_and_initialize_bodies(filename, velocity_method='polynomial'):
    """Load the orbital data file and build the three initial Body objects.

    The file is read once with ``np.loadtxt`` (space-delimited).  Column 0 is
    skipped, and columns 1-6 are used as x1, y1, x2, y2, x3, y3.  Initial
    positions come from the first row.  Initial velocities come from a
    polynomial fit over the leading rows (``calculate_initial_velocities_polynomial``).

    Args:
        filename: path of the data file.
        velocity_method: velocity estimator; only ``'polynomial'`` is implemented.

    Returns:
        (bodies, data_orbital): a list of three Body objects and the full
        (n_rows, 6) position array.

    Raises:
        ValueError: if ``velocity_method`` is unsupported or the file has
            fewer than two data rows.
    """
    if velocity_method != 'polynomial':
        raise ValueError(
            f"Unsupported velocity_method {velocity_method!r}; only 'polynomial' is implemented."
        )

    # Read the (potentially very large) file only once.  ndmin=2 keeps a
    # single-row file 2-D, so the row-count check below is meaningful.
    data_orbital = np.loadtxt(filename, delimiter=' ', usecols=range(1, 7), ndmin=2)
    if len(data_orbital) < 2:
        raise ValueError("Data file must contain at least two rows for velocity calculation.")

    # Leading rows available to the velocity fit.
    data_initial = data_orbital[:20]
    vels = calculate_initial_velocities_polynomial(data_initial)

    first_row = data_initial[0]
    bodies = [
        Body(first_row[2 * k], first_row[2 * k + 1], vels[2 * k], vels[2 * k + 1])
        for k in range(3)
    ]
    return bodies, data_orbital

def calculate_initial_velocities_polynomial(data, poly_deg=2, num_points=10):
    """Estimate each column's velocity at the first sample by polynomial fit.

    Each column of ``data`` is treated as a coordinate sampled at unit time
    steps t = 0, 1, 2, ...  A least-squares polynomial is fitted to the first
    ``num_points`` samples, and its derivative is evaluated at t = 0.  The
    result is a velocity in units of "per row".

    The degree is capped at (points used - 1) so the fit is never
    under-determined.  With only two rows this reduces to a plain finite
    difference, instead of an arbitrary minimum-norm fit with a RankWarning.

    Args:
        data: 2-D array with one column per coordinate.
        poly_deg: requested polynomial degree.
        num_points: maximum number of leading rows to fit.

    Returns:
        1-D float array with one velocity per column.

    Raises:
        ValueError: if no rows are available for the fit.
    """
    num_points = min(num_points, len(data))
    if num_points < 1:
        raise ValueError("At least one data row is required to estimate velocities.")
    degree = max(0, min(poly_deg, num_points - 1))

    t = np.arange(num_points)  # time steps 0, 1, 2, ...
    vels = np.zeros(data.shape[1])
    for col in range(data.shape[1]):
        coeffs = np.polyfit(t, data[:num_points, col], degree)
        # Derivative of the fitted polynomial, evaluated at t = 0.
        vels[col] = np.polyval(np.polyder(coeffs), 0)
    return vels

def simulate_for_G(system, steps, dt=1.0):
    """Integrate ``system`` with leapfrog and record positions at every step.

    Row 0 holds the starting positions, and row i holds the positions after
    i steps of size ``dt``.  ``system`` is advanced in place.

    Args:
        system: ThreeBodySystem (anything with get_positions/update_leapfrog).
        steps: number of rows to return, including the initial state.
        dt: time step passed to ``update_leapfrog`` (default 1.0).

    Returns:
        (steps, 6) array of [x1, y1, x2, y2, x3, y3]; empty when ``steps`` is 0.
    """
    positions = np.zeros((steps, 6))
    if steps == 0:
        return positions
    positions[0] = system.get_positions()
    for i in range(1, steps):
        positions[i] = system.update_leapfrog(dt)
    return positions

def generate_newtonian_data(system, n_iterations, dt=1.0, save_csv=None):
    """Integrate ``system`` and record positions, energy and angular momentum.

    Entry 0 is the starting state, and entry i is the state after i leapfrog
    steps of size ``dt``.  ``system`` is advanced in place.

    Args:
        system: ThreeBodySystem (needs get_positions, update_leapfrog,
            calculate_energy and calculate_angular_momentum).
        n_iterations: number of samples to record, including the start.
        dt: time step passed to ``update_leapfrog`` (default 1.0).
        save_csv: optional output path.  When given, the positions are written
            as CSV with header ``time,x1,y1,x2,y2,x3,y3`` and a 1-based
            integer ``time`` column.

    Returns:
        (positions, energy, angular_momentum) with shapes (n, 6), (n,), (n,).
    """
    positions = np.zeros((n_iterations, 6))
    energy = np.zeros(n_iterations)
    angular_momentum = np.zeros(n_iterations)

    if n_iterations > 0:
        positions[0] = system.get_positions()
        energy[0] = system.calculate_energy()
        angular_momentum[0] = system.calculate_angular_momentum()
        for i in range(1, n_iterations):
            positions[i] = system.update_leapfrog(dt)
            energy[i] = system.calculate_energy()
            angular_momentum[i] = system.calculate_angular_momentum()

    if save_csv is not None:
        time_col = np.arange(1, n_iterations + 1)[:, None]
        np.savetxt(save_csv, np.hstack([time_col, positions]),
                   delimiter=',', header="time,x1,y1,x2,y2,x3,y3",
                   comments='', fmt=['%d'] + ['%.8f'] * 6)
    return positions, energy, angular_momentum

def calculate_orbital_metrics(orbital_data, G):
    """Compute per-row energy and angular momentum of the recorded orbital data.

    Velocities are finite differences between consecutive rows (one row = one
    time unit).  Every row except the last uses a forward difference; the last
    row uses a backward difference.  With a single row the velocity is zero.
    Masses are unit.  The potential is -G * (1/r12 + 1/r13 + 1/r23) and,
    unlike ThreeBodySystem, distances are not clamped here.

    Vectorized over all rows, using the same per-element arithmetic as a
    row-by-row loop.

    Args:
        orbital_data: (n_steps, 6) array of [x1, y1, x2, y2, x3, y3].
        G: gravitational constant.

    Returns:
        (energy, angular_momentum): two 1-D float arrays of length n_steps.
    """
    data = np.asarray(orbital_data, dtype=np.float64)
    n_steps = len(data)
    if n_steps == 0:
        return np.zeros(0), np.zeros(0)

    vel = np.zeros_like(data)
    if n_steps > 1:
        vel[:-1] = data[1:] - data[:-1]   # forward differences
        vel[-1] = data[-1] - data[-2]     # backward difference for the last row

    x1, y1, x2, y2, x3, y3 = data.T
    vx1, vy1, vx2, vy2, vx3, vy3 = vel.T

    kinetic = 0.5 * (vx1**2 + vy1**2 + vx2**2 + vy2**2 + vx3**2 + vy3**2)
    r12 = np.hypot(x2 - x1, y2 - y1)
    r13 = np.hypot(x3 - x1, y3 - y1)
    r23 = np.hypot(x3 - x2, y3 - y2)
    potential = -G * (1.0 / r12 + 1.0 / r13 + 1.0 / r23)
    energy = kinetic + potential

    angular_momentum = (x1*vy1 - y1*vx1) + (x2*vy2 - y2*vx2) + (x3*vy3 - y3*vx3)
    return energy, angular_momentum

def plot_symmetry_check(orbital_data, newton_data, indices):
    """Plot x2−x3 and y2+y3 for the orbital data and the Newtonian run.

    Both quantities are zero when bodies 2 and 3 are mirror images across the
    x-axis, so any departure from zero shows the symmetry breaking.  Samples
    are taken at ``indices`` and drawn against ``indices + 1``.  The figure is
    saved as 'symmetry_check.png'.
    """
    x_axis = indices + 1
    # (data set, style/label for x2-x3, style/label for y2+y3), in legend order.
    curve_specs = (
        (orbital_data, 'b-', 'Orbital x2−x3', 'g-', 'Orbital y2+y3'),
        (newton_data, 'r--', 'Newtonian x2−x3', 'm--', 'Newtonian y2+y3'),
    )

    plt.figure(figsize=(12, 6))
    for source, diff_style, diff_label, sum_style, sum_label in curve_specs:
        x_gap = source[indices, 2] - source[indices, 4]
        y_total = source[indices, 3] + source[indices, 5]
        plt.plot(x_axis, x_gap, diff_style, lw=1, label=diff_label)
        plt.plot(x_axis, y_total, sum_style, lw=1, label=sum_label)

    plt.title('Symmetry Check: x2−x3 and y2+y3 Deviations', fontsize=14)
    plt.xlabel('Iteration', fontsize=12)
    plt.ylabel('Deviation Magnitude', fontsize=12)
    plt.legend(loc='upper right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('symmetry_check.png')
    plt.close()

def calculate_cumulative_distance(data):
    """Return each body's cumulative path length along a position history.

    Args:
        data: (steps, 6) array of [x1, y1, x2, y2, x3, y3] per step.

    Returns:
        (steps, 3) array.  Row i, column b is the total Euclidean distance
        body b+1 has travelled from row 0 up to row i (row 0 is all zeros).

    Vectorized: per-step displacements are computed for all rows at once and
    accumulated with a sequential cumulative sum, instead of a Python double
    loop over every row and body.
    """
    data = np.asarray(data)
    steps = len(data)
    cum_distances = np.zeros((steps, 3))  # columns: bodies 1, 2, 3
    if steps < 2:
        return cum_distances

    # (steps-1, 3, 2): per-step (dx, dy) for each of the three bodies.
    displacements = (data[1:, :6] - data[:-1, :6]).reshape(steps - 1, 3, 2)
    step_lengths = np.linalg.norm(displacements, axis=-1)
    cum_distances[1:] = np.cumsum(step_lengths, axis=0)
    return cum_distances

def plot_cumulative_distances(cum_data, title, filename):
    """Plot each body's cumulative path length against iteration and save it.

    Args:
        cum_data: array whose columns 0, 1 and 2 hold the running distance of
            bodies 1, 2 and 3.
        title: figure title.
        filename: output image path passed to ``plt.savefig``.
    """
    line_styles = zip(('Body 1', 'Body 2', 'Body 3'),
                      ('tab:blue', 'tab:orange', 'tab:green'))

    plt.figure(figsize=(12, 6))
    for column, (name, colour) in enumerate(line_styles):
        plt.plot(cum_data[:, column], label=name, color=colour, alpha=0.8)

    plt.title(title, fontsize=14)
    plt.xlabel('Iteration', fontsize=12)
    plt.ylabel('Total Distance Traveled', fontsize=12)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

def plot_trajectory_comparison(orbital_data, newton_data, indices, include_body3=False):
    """Overlay the sampled x-y paths of the orbital data and the Newtonian run.

    Bodies 1 and 2 are always drawn.  Body 3 is added for both data sets when
    ``include_body3`` is true.  Orbital curves are solid and semi-transparent;
    Newtonian curves are dashed.  The figure is saved as
    'trajectory_comparison.png'.

    Args:
        orbital_data, newton_data: (n, 6) arrays of [x1, y1, x2, y2, x3, y3].
        indices: row indices to sample from both arrays.
        include_body3: also plot body 3 (default False, the original output).
    """
    orbital_styles = [('b-', 'Orbital Body 1'), ('g-', 'Orbital Body 2'),
                      ('k-', 'Orbital Body 3')]
    newton_styles = [('r--', 'Newtonian Body 1'), ('m--', 'Newtonian Body 2'),
                     ('c--', 'Newtonian Body 3')]
    n_bodies = 3 if include_body3 else 2

    plt.figure(figsize=(12, 12))
    for b in range(n_bodies):
        fmt, label = orbital_styles[b]
        plt.plot(orbital_data[indices, 2 * b], orbital_data[indices, 2 * b + 1],
                 fmt, label=label, alpha=0.7)
    for b in range(n_bodies):
        fmt, label = newton_styles[b]
        plt.plot(newton_data[indices, 2 * b], newton_data[indices, 2 * b + 1],
                 fmt, label=label)
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.title('Trajectory Comparison')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('trajectory_comparison.png', bbox_inches='tight')
    plt.close()

def plot_distance_comparison(orbital_data, newton_data, indices, include_m1m3=False):
    """Plot inter-body separations for the orbital data and the Newtonian run.

    The m1-m2 separation is always drawn.  The m1-m3 separation is drawn (and
    computed) only when ``include_m1m3`` is true.  Samples at ``indices`` are
    drawn against ``indices + 1``, and the figure is saved as
    'distance_comparison.png'.

    Args:
        orbital_data, newton_data: (n, 6) arrays of [x1, y1, x2, y2, x3, y3].
        indices: row indices to sample from both arrays.
        include_m1m3: also plot the m1-m3 separation (default False).
    """
    def separation(data, a, b):
        # Euclidean distance between body a and body b (0-based) at `indices`.
        return np.hypot(data[indices, 2 * b] - data[indices, 2 * a],
                        data[indices, 2 * b + 1] - data[indices, 2 * a + 1])

    x_axis = indices + 1
    plt.figure(figsize=(14, 7))
    plt.plot(x_axis, separation(orbital_data, 0, 1), 'b-', lw=1, label='Orbital m1-m2')
    if include_m1m3:
        plt.plot(x_axis, separation(orbital_data, 0, 2), 'g-', lw=1, label='Orbital m1-m3')
    plt.plot(x_axis, separation(newton_data, 0, 1), 'r--', lw=1, label='Newtonian m1-m2')
    if include_m1m3:
        plt.plot(x_axis, separation(newton_data, 0, 2), 'm--', lw=1, label='Newtonian m1-m3')
    plt.title('Inter-body Distance Comparison', fontsize=14)
    plt.xlabel('Iteration', fontsize=12)
    plt.ylabel('Distance', fontsize=12)
    plt.legend(loc='upper right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('distance_comparison.png')
    plt.close()

def plot_barycenter_distances_m2(orbital_data, newton_data, indices):
    """Plot the distance of body 2 (m2) from the three-body centroid over time.

    The centroid is the unweighted mean of the three positions, which is the
    barycenter when all masses are equal.  The orbital and Newtonian curves
    are drawn against ``indices + 1`` and saved as
    'barycenter_distance_comparison_m2.png'.
    """
    def m2_distance(data):
        # Centroid of the three bodies at the sampled rows.
        bary_x = (data[indices, 0] + data[indices, 2] + data[indices, 4]) / 3
        bary_y = (data[indices, 1] + data[indices, 3] + data[indices, 5]) / 3
        return np.hypot(data[indices, 2] - bary_x, data[indices, 3] - bary_y)

    plt.figure(figsize=(14, 7))
    plt.plot(indices+1, m2_distance(orbital_data), 'b-', lw=1, label='Orbital m2-Barycenter')
    plt.plot(indices+1, m2_distance(newton_data), 'r--', lw=1, label='Newtonian m2-Barycenter')

    plt.title('Distance of m2 from Barycenter Over Time', fontsize=14)
    plt.xlabel('Iteration', fontsize=12)
    plt.ylabel('Distance from Barycenter', fontsize=12)
    plt.legend(loc='upper right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('barycenter_distance_comparison_m2.png')
    plt.close()
    
    
def plot_barycenter_distances_m3(orbital_data, newton_data, indices):
    """Plot the distance of body 3 (m3) from the three-body centroid over time.

    The centroid is the unweighted mean of the three positions, which is the
    barycenter when all masses are equal.  The orbital and Newtonian curves
    are drawn against ``indices + 1`` and saved as
    'barycenter_distance_comparison_m3.png'.
    """
    def m3_distance(data):
        # Centroid of the three bodies at the sampled rows.
        bary_x = (data[indices, 0] + data[indices, 2] + data[indices, 4]) / 3
        bary_y = (data[indices, 1] + data[indices, 3] + data[indices, 5]) / 3
        return np.hypot(data[indices, 4] - bary_x, data[indices, 5] - bary_y)

    plt.figure(figsize=(14, 7))
    plt.plot(indices+1, m3_distance(orbital_data), 'g-', lw=1, label='Orbital m3-Barycenter')
    plt.plot(indices+1, m3_distance(newton_data), 'm--', lw=1, label='Newtonian m3-Barycenter')

    plt.title('Distance of m3 from Barycenter Over Time', fontsize=14)
    plt.xlabel('Iteration', fontsize=12)
    plt.ylabel('Distance from Barycenter', fontsize=12)
    plt.legend(loc='upper right', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('barycenter_distance_comparison_m3.png')
    plt.close()

    
    
def _optimize_G(bodies, target_iteration, initial_pos,
                coarse_range=(0.47, 0.87), n_coarse=200, delta=0.01, n_fine=40):
    """Grid-search G so body 1 ends closest to ``initial_pos`` after
    ``target_iteration`` leapfrog steps.

    A coarse grid over ``coarse_range`` is refined by a fine grid spanning
    ±``delta`` around the coarse optimum.  Every candidate integrates a fresh
    copy of ``bodies`` for the full ``target_iteration`` steps, so this is
    very slow.

    Returns:
        (best_G, error): the fine-grid optimum and its position error.
    """
    def return_error(G):
        # Copy the bodies so every candidate starts from the same state.
        system = ThreeBodySystem(
            [Body(b.position[0], b.position[1], b.velocity[0], b.velocity[1]) for b in bodies], G)
        final_pos = simulate_for_G(system, target_iteration)[-1, :2]
        return np.linalg.norm(final_pos - initial_pos)

    print("Performing coarse G optimization...")
    best_G_coarse, min_error_coarse = None, float('inf')
    for G in np.linspace(coarse_range[0], coarse_range[1], n_coarse):
        error = return_error(G)
        if error < min_error_coarse:
            best_G_coarse, min_error_coarse = G, error
            print(f"Coarse Optimal G: {best_G_coarse:.6f}, Error: {min_error_coarse:.6f}")
    print(f"Coarse Optimal G: {best_G_coarse:.6f}, Error: {min_error_coarse:.6f}")

    print("Performing high-precision G optimization...")
    best_G_fine, min_error_fine = None, float('inf')
    for G in np.linspace(best_G_coarse - delta, best_G_coarse + delta, n_fine):
        error = return_error(G)
        if error < min_error_fine:
            best_G_fine, min_error_fine = G, error
    print(f"High-precision Optimal G: {best_G_fine:.8f}, Error: {min_error_fine:.6f}")
    return best_G_fine, min_error_fine


def main():
    """Compare a Newtonian leapfrog simulation against the recorded orbital data.

    Loads FILENAME and seeds a ThreeBodySystem from its first rows.  It then
    integrates the system with a fixed G (or one found by ``_optimize_G`` when
    RUN_G_OPTIMIZATION is set).  It prints the energy and angular-momentum
    drift for both data sets and saves the comparison plots as PNG files in
    the working directory.
    """
    # Data file generated by Newton-vs-Orbital_Orbital.py
    FILENAME = "data-3b-long.txt"
    TARGET_ITERATION = 1118213
    # Set to True to re-run the (very slow) grid search for G instead of
    # using the previously chosen FIXED_G.
    RUN_G_OPTIMIZATION = False
    FIXED_G = 0.497

    bodies, orbital_data = load_data_and_initialize_bodies(FILENAME)

    if RUN_G_OPTIMIZATION:
        initial_pos = orbital_data[0, :2]
        best_G, _ = _optimize_G(bodies, TARGET_ITERATION, initial_pos)
    else:
        best_G = FIXED_G

    # The plots sample both data sets with the same indices, so never simulate
    # (or sample) beyond the recorded orbital data.
    n_iterations = min(TARGET_ITERATION, len(orbital_data))

    # Simulate a copy so `bodies` keeps the initial conditions.
    system = ThreeBodySystem(
        [Body(b.position[0], b.position[1], b.velocity[0], b.velocity[1]) for b in bodies],
        best_G,
    )
    newton_positions, newton_energy, newton_am = generate_newtonian_data(system, n_iterations)
    orbital_energy, orbital_am = calculate_orbital_metrics(orbital_data, best_G)

    # Conservation summary: maximum deviation from the initial value.
    for label, energy, am in (("Newtonian", newton_energy, newton_am),
                              ("Orbital", orbital_energy, orbital_am)):
        print(f"{label}: max |E - E0| = {np.max(np.abs(energy - energy[0])):.6e}, "
              f"max |L - L0| = {np.max(np.abs(am - am[0])):.6e}")

    sample_rate = max(1, n_iterations // 1000)
    indices = np.arange(0, n_iterations, sample_rate)

    plot_trajectory_comparison(orbital_data, newton_positions, indices)
    plot_distance_comparison(orbital_data, newton_positions, indices)
    plot_symmetry_check(orbital_data, newton_positions, indices)

    orb_cum = calculate_cumulative_distance(orbital_data)
    new_cum = calculate_cumulative_distance(newton_positions)
    plot_cumulative_distances(orb_cum, "Orbital Data - Cumulative Distances", "orbital_cum_distances.png")
    plot_cumulative_distances(new_cum, "Newtonian Simulation - Cumulative Distances", "newton_cum_distances.png")
    plot_barycenter_distances_m2(orbital_data, newton_positions, indices)
    plot_barycenter_distances_m3(orbital_data, newton_positions, indices)


if __name__ == "__main__":
    main()