
Calculation of dipole-dipole time-delayed correlation function

1 Introduction

This tutorial gives a guideline for calculating the dipole-dipole time-delayed correlation function \(C(r,\tau)\), defined as

\[C(r,\tau)=\dfrac{1}{N-\tau+1}\sum_{t=0}^{N-\tau}\dfrac{1}{N_p(r)}\sum_{\substack{i,j\\ \|\bm{r}_{ij}\|=r}} \dfrac{\bm{P}_i(t)}{\|\bm{P}_i(t)\|} \cdot \dfrac{\bm{P}_j(t+\tau)}{\|\bm{P}_j(t+\tau)\|} \]

where \(\tau\) is the time delay, the MD frames are indexed \(t = 0, 1, \dots, N\) so that \(N\) is the length of the trajectory in timesteps (not to be confused with the number of unit cells \(N\) used in Section 2), \(\bm{P}_i(t)\) denotes the instantaneous polarization of unit cell \(i\) at timestep \(t\), and \(N_p(r)\) counts the number of pairs of unit cells \(i\) and \(j\) separated by distance \(\|\bm{r}_{ij}\| = r\). For simplicity, the distance between two nearest-neighbor unit cells is defined as \(1\).
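To make \(N_p(r)\) concrete: on a simple cubic grid of unit cells, every cell has the same number of partners at a given distance, so \(N_p(r)\) is that number times the number of unit cells (counting ordered pairs, and assuming the supercell is large enough that no periodic image is counted twice). The sketch below uses a hypothetical helper, `shell_sizes`, that counts integer displacement vectors by their squared length \(r^2\), which avoids any floating-point rounding.

```python
from collections import Counter
from itertools import product


def shell_sizes(max_shift):
    """Count integer displacement vectors (dx, dy, dz) != 0 with |d_i| <= max_shift,
    grouped by squared length r^2 (hypothetical helper for illustration)."""
    counts = Counter()
    for dx, dy, dz in product(range(-max_shift, max_shift + 1), repeat=3):
        if (dx, dy, dz) == (0, 0, 0):
            continue
        counts[dx * dx + dy * dy + dz * dz] += 1
    return counts


counts = shell_sizes(3)
for r2 in (1, 2, 3, 4, 9):
    print(f"r = sqrt({r2}): {counts[r2]} partners per unit cell")
```

Note that the shell \(r = 3\) pools both \((3,0,0)\)-type and \((2,2,1)\)-type vectors (6 + 24 = 30), so one distance can mix several directions; the correlation script in Section 3 likewise groups displacements by length only.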

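Physically, \(C(r,\tau)\) quantifies the spatiotemporal correlation between pairs of unit cells (dipoles) separated by distance \(r\) with time delay \(\tau\): \(C=1\) for perfectly parallel dipoles, \(C \approx 0\) for uncorrelated orientations, and \(C=-1\) for antiparallel ones. An effective angle between such dipoles can be obtained via \(\arccos{C(r,\tau)}\); strictly, this is the angle whose cosine equals the mean cosine, which is not identical to the mean angle.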
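Because the GPU script in Section 3 is fairly involved, a small CPU version of the formula can help check it on toy data. The sketch below is an illustration, not the tutorial's script: `correlation_reference` is a hypothetical helper, and it assumes the unit-cell ordering used by cal_spatial_correlation.py (cell \((x,y,z)\) has index \(x\,n_y n_z + y\,n_z + z\)), periodic boundaries, and no zero-length dipoles. Here the trajectory has \(T = N+1\) frames, so each delay averages over \(T-\tau = N-\tau+1\) time origins, as in the definition above.

```python
import numpy as np


def correlation_reference(P, grid, displacements, tau):
    """CPU reference for C(r, tau) on a periodic grid (hypothetical helper).

    P             : array (T, N, 3); unit cells ordered with z fastest,
                    index = x*ny*nz + y*nz + z
    grid          : (nx, ny, nz) supercell size, nx*ny*nz == N
    displacements : integer (dx, dy, dz) vectors, all of length r
    tau           : time delay in frames, 0 <= tau < T
    """
    nx, ny, nz = grid
    T = P.shape[0]
    u = P / np.linalg.norm(P, axis=2, keepdims=True)  # unit dipole vectors
    u = u.reshape(T, nx, ny, nz, 3)
    current = u[: T - tau]  # frames t = 0 .. T-1-tau
    future = u[tau:]        # frames t + tau
    values = []
    for dx, dy, dz in displacements:
        # shifted[t, x, y, z] = future[t, (x+dx)%nx, (y+dy)%ny, (z+dz)%nz]
        shifted = np.roll(future, shift=(-dx, -dy, -dz), axis=(1, 2, 3))
        values.append(np.sum(current * shifted, axis=-1).mean())
    # every displacement contributes the same number of (t, i) terms,
    # so the mean of the per-displacement means equals the overall mean
    return float(np.mean(values))


if __name__ == "__main__":
    # Toy check: two cells along x with antiparallel dipoles, constant in time
    P = np.zeros((3, 2, 3))
    P[:, 0, 2] = 1.0   # cell x = 0 points along +z
    P[:, 1, 2] = -1.0  # cell x = 1 points along -z
    C = correlation_reference(P, (2, 1, 1), [(1, 0, 0), (-1, 0, 0)], tau=1)
    angle = np.degrees(np.arccos(np.clip(C, -1.0, 1.0)))  # effective angle in degrees
    print(C, angle)
```

Clipping before `np.arccos` guards against values that round slightly outside \([-1, 1]\).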

2 Data preparation

Before calculating the correlation function, make sure you have finished the polarization calculation and saved the result in .npy format with shape \((T,N,3)\), where \(N\) is the number of unit cells, \(T\) is the total number of timesteps in the MD trajectory, and 3 corresponds to the polarization components along the three axes. Here is a quick tutorial for calculating the polarization:

Important: the following steps apply only to perovskite structures!

Given a configuration from MD simulations, the polarization for ABO3 perovskite systems can be estimated using the following formula,

\[\mathbf{P}^m(t)=\frac{1}{V_{\rm uc}}\left[\frac{1}{8} \mathbf{Z}_{A}^* \sum_{i=1}^8 \mathbf{r}_{A, i}^m(t)+\mathbf{Z}_{B}^* \mathbf{r}_{B}^m(t)+\frac{1}{2} \mathbf{Z}_{\mathrm{O}}^* \sum_{i=1}^6 \mathbf{r}_{\mathrm{O}, i}^m(t)\right] \]

where \(\mathbf{P}^m(t)\) is the polarization of unit cell \(m\) at time \(t\), \(V_{\rm uc}\) is the volume of the unit cell, \(\mathbf{Z}_{A}^*, \mathbf{Z}_{B}^*\), and \(\mathbf{Z}_{\mathrm{O}}^*\) are the average Born effective charges of the A-site, B-site and O atoms, and \(\mathbf{r}_{A, i}^m(t)\), \(\mathbf{r}_{B}^m(t)\) and \(\mathbf{r}_{\mathrm{O}, i}^m(t)\) are the instantaneous positions of the 8 A-site atoms at the corners of cell \(m\), its central B-site atom, and the 6 surrounding O atoms. The factors \(1/8\) and \(1/2\) account for each A-site atom being shared by 8 unit cells and each O atom by 2.
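As an illustration of the formula (not of how get_p implements it), the sketch below evaluates \(\mathbf{P}^m\) for a single cell. Assumptions, all mine: `cell_polarization` is a hypothetical helper, the Born effective charges are scalars rather than tensors, and the box is orthorhombic, so neighbors that were wrapped across a periodic boundary can be unwrapped with the minimum-image convention relative to the B site. Since the charge weights sum to \(Z_A^* + Z_B^* + 3Z_{\mathrm{O}}^*\), the result is independent of the coordinate origin only when the charges are neutral. The output is in charge per area of your input units (e.g. e/Å² for positions in Å; 1 e/Å² ≈ 16.02 C/m²).

```python
import numpy as np


def cell_polarization(r_A, r_B, r_O, z_A, z_B, z_O, box, v_uc):
    """Polarization of one ABO3 unit cell from the formula above (hypothetical helper).

    r_A  : (8, 3) positions of the 8 A-site atoms around the cell
    r_B  : (3,)   position of the central B-site atom
    r_O  : (6, 3) positions of the 6 surrounding O atoms
    z_A, z_B, z_O : scalar (isotropic) Born effective charges; an assumption,
                    since the formula itself allows full tensors
    box  : (3,) edge lengths of an orthorhombic periodic box
    v_uc : unit-cell volume
    """
    r_B = np.asarray(r_B, dtype=float)
    box = np.asarray(box, dtype=float)

    def unwrap(r):
        # bring each neighbor to its periodic image closest to the B site
        d = np.asarray(r, dtype=float) - r_B
        return r_B + d - box * np.round(d / box)

    total = (z_A / 8.0 * unwrap(r_A).sum(axis=0)
             + z_B * r_B
             + z_O / 2.0 * unwrap(r_O).sum(axis=0))
    return total / v_uc


if __name__ == "__main__":
    a = 4.0  # toy cubic cell with B displaced by 0.1 along z
    r_A = np.array([[x, y, z] for x in (0, a) for y in (0, a) for z in (0, a)])
    r_O = np.array([[0, 2, 2], [4, 2, 2], [2, 0, 2], [2, 4, 2], [2, 2, 0], [2, 2, 4]], float)
    print(cell_polarization(r_A, [2.0, 2.0, 2.1], r_O, 2.0, 4.0, -2.0, (8.0, 8.0, 8.0), a**3))
```

With the neutral toy charges used in the demo, only the B-site displacement survives, giving \(P_z = Z_B^* \cdot 0.1 / V_{\rm uc}\).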

To calculate the polarization on our group cluster, first add the package path to your environment:

export PYTHONPATH=/shared_storage/home/share/code_ldn/fdc/

Take Pb(In½Nb½)O3-Pb(Mg⅓Nb⅔)O3-PbTiO3 (PIN-PMN-PT) as an example, which has four different elements (In, Mg, Nb, Ti) on the B site. We use the following script get_neighbor.py to build the B-A, B-O and A-O neighbor lists.

#usage: python get_neighbor.py <trajectory file>
from ferrodispcalc import NeighborList
from ferrodispcalc.type_map import UniPero
import sys

#Center atom: B site, neighbour atom: A site
file_name = sys.argv[1]
nl_ba = NeighborList(file_name, format='lmp-dump', type_map=UniPero)
nl_ba.build(
    center_elements=['Ti','Mg','Nb','In'],
    neighbor_elements=['Pb'],
    neighbor_num=8,
    cutoff=5,
    defect=False
)
nl_ba.nl -= 1
nl_ba.write('BA.dat')

#Center atom: B site, neighbour atom: O
nl_bo = NeighborList(file_name, format='lmp-dump', type_map=UniPero)
nl_bo.build(
    center_elements=['Ti','Mg','Nb','In'],
    neighbor_elements=['O'],
    neighbor_num=6,
    cutoff=5,
    defect=False
)
nl_bo.nl -= 1
nl_bo.write('BO.dat')

#Center atom: A site, neighbour atom: O
nl_ao = NeighborList(file_name, format='lmp-dump', type_map=UniPero)
nl_ao.build(
    center_elements=['Pb'],
    neighbor_elements=['O'],
    neighbor_num=12,
    cutoff=5,
    defect=False
)
nl_ao.nl -= 1
nl_ao.write('AO.dat')

For different ABO3 structures, you have to change the center_elements and neighbor_elements lists in the script above. Moreover, the neighbor list must be recalculated for every configuration, even for configurations with the same ABO3 composition.

You may occasionally encounter an error such as:

ValueError: 4126 Ti has 7 neighbors, expected at least 8

This means the cutoff is too small to find the requested number of neighbors. Increase the cutoff value when building the neighbor list; the recommended range is 5 to 6.
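The error comes from the neighbor search itself: each center atom must have at least neighbor_num atoms of the requested species inside the cutoff. The sketch below reproduces that check with SciPy's periodic `cKDTree`; it is not the ferrodispcalc implementation, and `build_neighbor_list` is a hypothetical helper that assumes an orthorhombic box.

```python
import numpy as np
from scipy.spatial import cKDTree


def build_neighbor_list(centers, neighbors, box, neighbor_num, cutoff):
    """Indices of the `neighbor_num` nearest `neighbors` of each center under
    periodic boundaries (hypothetical helper, not the ferrodispcalc API).
    Raises ValueError when fewer than `neighbor_num` lie within `cutoff`."""
    box = np.asarray(box, dtype=float)
    # with `boxsize`, cKDTree treats the box as periodic; coordinates must lie in [0, box)
    tree = cKDTree(np.mod(neighbors, box), boxsize=box)
    dist, idx = tree.query(np.mod(centers, box), k=neighbor_num, distance_upper_bound=cutoff)
    dist = np.asarray(dist).reshape(len(centers), -1)
    for i, d in enumerate(dist):
        found = int(np.isfinite(d).sum())  # neighbors beyond the cutoff are reported as inf
        if found < neighbor_num:
            raise ValueError(f"center {i} has {found} neighbors, expected at least {neighbor_num}")
    return np.asarray(idx).reshape(len(centers), -1)


if __name__ == "__main__":
    # toy 3x3x3 perovskite-like grid: A on the corners, B in the cell centers (a = 4)
    cells = np.array([[i, j, k] for i in range(3) for j in range(3) for k in range(3)], float)
    A, B = cells * 4.0, cells * 4.0 + 2.0
    print(build_neighbor_list(B, A, (12.0, 12.0, 12.0), neighbor_num=8, cutoff=5.0).shape)
```

In this toy grid each B atom has its 8 A neighbors at \(2\sqrt{3} \approx 3.46\), so a cutoff of 3 raises the same kind of error as above, while 5 succeeds.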

Once the neighbor lists AO.dat, BO.dat and BA.dat are ready, the polarization can be calculated with:

get_p <trajectory file> <output file> BA.dat BO.dat type_map_file bec_file <total frames>

Here, type_map_file is the type map of your MD potential; for UniPero (version 2025.3), the type map is

Ba,Pb,Ca,Sr,Bi,K,Na,Hf,Ti,Zr,Nb,Mg,In,Zn,O

The bec_file contains the Born effective charge of each element, listed in the same order as type_map_file. These values can be taken from the literature or computed with DFT. Here is an example for the PIN-PMN-PT solid solution:

0,3.502,0,0,0,0,2.848,0,5.434,0,5.434,5.434,5.434,0,-2.978
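A quick sanity check on these two files (my suggestion, not part of get_p): because each A-site atom is shared by 8 cells and each O atom by 2, the charge weights in the polarization formula sum to \(Z_A^* + Z_B^* + 3Z_{\mathrm{O}}^*\), which should be close to zero; otherwise the result depends on the choice of origin. The sketch below pairs the two lists from the text above and checks this for each B-site species.

```python
# type map and Born effective charges for PIN-PMN-PT, copied from the text above
type_map = "Ba,Pb,Ca,Sr,Bi,K,Na,Hf,Ti,Zr,Nb,Mg,In,Zn,O".split(",")
bec = [float(v) for v in "0,3.502,0,0,0,0,2.848,0,5.434,0,5.434,5.434,5.434,0,-2.978".split(",")]

# both lists must have one entry per element, in the same order
assert len(type_map) == len(bec), "type_map_file and bec_file must have the same length"
z = dict(zip(type_map, bec))

# ABO3 charge neutrality: Z_A + Z_B + 3 Z_O should be close to zero for each B-site species
for b_site in ("Ti", "Mg", "Nb", "In"):
    residual = z["Pb"] + z[b_site] + 3 * z["O"]
    print(f"Pb-{b_site}-O3: Z_A + Z_B + 3 Z_O = {residual:+.3f}")
```

If the values live in files, read the first line of each file instead of the string literals.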

For further calculations, use the following script polar2npy.py to convert the polarization file to .npy format.

#usage: python polar2npy.py <polarization_file> <output_file (.npy format)> <total_unitcell_number>
import sys

import numpy as np

input_name = sys.argv[1]
output_name = sys.argv[2]
ncell = int(sys.argv[3])  # number of unit cells per frame

# Each row holds (Px, Py, Pz) of one unit cell; frames are stacked one after another
data = np.loadtxt(input_name)
ndata = data.shape[0]
if ndata % ncell != 0:
    sys.exit(f"Error: {ndata} rows is not a multiple of {ncell} unit cells")

nframe = ndata // ncell
data = data.reshape(nframe, ncell, 3)  # shape (T, N, 3)
np.save(output_name, data)
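Before submitting the GPU job, it can save time to check the saved array against the requirements of cal_spatial_correlation.py in Section 3. The helper below, `check_polarization_array`, is hypothetical; it checks the \((T,N,3)\) shape, that \(N\) equals uc_x·uc_y·uc_z, and the absence of NaN values (which the GPU script rejects), and it counts near-zero dipoles, which the GPU script treats as contributing 0.

```python
import numpy as np


def check_polarization_array(P, uc_x, uc_y, uc_z):
    """Sanity checks before running cal_spatial_correlation.py (hypothetical helper)."""
    if P.ndim != 3 or P.shape[2] != 3:
        raise ValueError(f"expected shape (T, N, 3), got {P.shape}")
    if P.shape[1] != uc_x * uc_y * uc_z:
        raise ValueError(f"N = {P.shape[1]} but uc_x*uc_y*uc_z = {uc_x * uc_y * uc_z}")
    if np.isnan(P).any():
        raise ValueError("array contains NaN values")
    norms = np.linalg.norm(P, axis=2)
    # the GPU script masks dipoles with norm <= 1e-6, so they contribute 0 to C(r, tau)
    n_small = int((norms <= 1e-6).sum())
    return {"frames": P.shape[0], "cells": P.shape[1], "near_zero_dipoles": n_small}


if __name__ == "__main__":
    demo = np.random.default_rng(0).normal(size=(10, 8, 3))  # stand-in for np.load(<your .npy>)
    print(check_polarization_array(demo, 2, 2, 2))
```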

I also highly recommend computing the time evolution of the supercell-averaged polarization with the following script pt.py, which is useful for many other analyses.

#usage: python pt.py <polarization_file> <output_file> <total_unitcell_number>
import numpy as np
import sys

input_file= sys.argv[1]
output_file = sys.argv[2]
frame_size = int(sys.argv[3])

def main():
    data = np.loadtxt(input_file)  
    nrows = data.shape[0]
    nframes = nrows // frame_size  

    with open(output_file, "w") as f_out:
        f_out.write("Frame Px_avg Py_avg Pz_avg\n") 

        for i in range(nframes):
            frame_data = data[i * frame_size : (i + 1) * frame_size]
            avg_px = np.mean(frame_data[:, 0])
            avg_py = np.mean(frame_data[:, 1])
            avg_pz = np.mean(frame_data[:, 2])
            f_out.write(f"{i + 1} {avg_px:.8f} {avg_py:.8f} {avg_pz:.8f}\n")

if __name__ == "__main__":
    main()
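The loop in pt.py can also be written as a single reshape, which is faster for long trajectories. The sketch below is an alternative I am suggesting, not the original script; `frame_average` and `write_pt` are hypothetical names, and the output keeps pt.py's header and number format.

```python
import numpy as np


def frame_average(data, n_cells):
    """Average polarization per frame: (T*N, 3) rows -> (T, 3) array.
    Drops an incomplete trailing frame, as pt.py does implicitly."""
    n_frames = data.shape[0] // n_cells
    return data[: n_frames * n_cells].reshape(n_frames, n_cells, 3).mean(axis=1)


def write_pt(path, avg):
    """Write the averages in the same layout as pt.py (1-based frame index)."""
    frames = np.arange(1, len(avg) + 1)
    np.savetxt(path, np.column_stack([frames, avg]),
               fmt=["%d", "%.8f", "%.8f", "%.8f"],
               header="Frame Px_avg Py_avg Pz_avg", comments="")


if __name__ == "__main__":
    rows = np.arange(24, dtype=float).reshape(8, 3)  # 2 frames x 4 unit cells
    print(frame_average(rows, 4))
```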

3 Correlation function calculation

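We use the following script cal_spatial_correlation.py on a GPU cluster to calculate the correlation function.

Important notes:
- The input file must be in .npy format.
- <uc_x>, <uc_y>, and <uc_z> represent the supercell dimensions (in unit cells) along the three axes; their product must equal the number of unit cells in the .npy file.
- The unit cells must be ordered so that cell (x, y, z) has index x·uc_y·uc_z + y·uc_z + z (z varies fastest), matching coord_to_index in the script.
- <max_tau> specifies the maximum time delay for the correlation function (the script computes values for \(\tau\) from 0 to <max_tau>).
- Distances \(r\) are in units of the nearest-neighbor unit-cell spacing; only distances \(r = i\), \(i\sqrt{2}\) or \(i\sqrt{3}\) (integer \(i\) up to half the largest supercell dimension) are evaluated.
- This script must be executed on a GPU cluster. You need to specify the number of GPUs in your sbatch submission script.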

#usage: python cal_spatial_correlation.py <input_file> <output_file> <uc_x> <uc_y> <uc_z> <max_tau>

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import multiprocessing
from multiprocessing import Process, Queue
import sys

def set_gpu(gpu_id):
    torch.cuda.set_device(gpu_id)
    return torch.device(f'cuda:{gpu_id}')

inputfile = sys.argv[1]
outputfile = sys.argv[2]
GRID_SIZE_X = int(sys.argv[3])
GRID_SIZE_Y = int(sys.argv[4])
GRID_SIZE_Z = int(sys.argv[5])
TAU_MAX = int(sys.argv[6])

def coord_to_index(x, y, z):
    return x * GRID_SIZE_Y * GRID_SIZE_Z + y * GRID_SIZE_Z + z

def index_to_coord(idx):
    x = idx // (GRID_SIZE_Y * GRID_SIZE_Z)
    y = (idx // GRID_SIZE_Z) % GRID_SIZE_Y
    z = idx % GRID_SIZE_Z
    return x, y, z

def generate_r_dict():
    max_dim = max(GRID_SIZE_X, GRID_SIZE_Y, GRID_SIZE_Z)
    search_range = max_dim // 2 
    r_values = []
    for i in range(1, search_range + 1):
        r_values.extend([i, i*np.sqrt(2), i*np.sqrt(3)])
    r_values = np.round(np.unique(r_values), 6)

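    # Search each axis only up to half its own length, so every displacement
    # corresponds to the minimum-image distance, also for non-cubic supercells
    half_x, half_y, half_z = GRID_SIZE_X // 2, GRID_SIZE_Y // 2, GRID_SIZE_Z // 2
    displacements = []
    for dx in range(-half_x, half_x + 1):
        for dy in range(-half_y, half_y + 1):
            for dz in range(-half_z, half_z + 1):
                if dx == 0 and dy == 0 and dz == 0:
                    continue
                displacements.append((dx, dy, dz))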

    r_dict = {}
    for dx, dy, dz in displacements:
        r = np.sqrt(dx**2 + dy**2 + dz**2)
        r_rounded = round(r, 6)
        if r_rounded in r_values:
            if r_rounded not in r_dict:
                r_dict[r_rounded] = []
            r_dict[r_rounded].append((dx, dy, dz))

    return r_dict, sorted(r_dict.keys())

def precompute_neighbors(r_dict):
    neighbor_cache = {}
    for r in r_dict:
        neighbor_cache[r] = {}
        for dx, dy, dz in r_dict[r]:
            key = (dx, dy, dz)
            indices = []
            for idx in range(GRID_SIZE_X * GRID_SIZE_Y * GRID_SIZE_Z):
                x, y, z = index_to_coord(idx)
                nx = (x + dx) % GRID_SIZE_X
                ny = (y + dy) % GRID_SIZE_Y
                nz = (z + dz) % GRID_SIZE_Z
                indices.append(coord_to_index(nx, ny, nz))
            neighbor_cache[r][key] = torch.LongTensor(indices)
    return neighbor_cache

def compute_correlation_gpu_subset(data, r_dict, sorted_r, neighbor_cache_cpu, start_tau, end_tau, gpu_id, result_queue):
    try:
        device = set_gpu(gpu_id)
        data_tensor = torch.from_numpy(data).float().to(device)
        if torch.isnan(data_tensor).any():
            raise ValueError("Input data contains NaN values")

        N = data_tensor.shape[0]
        neighbor_cache = {}

        for r in neighbor_cache_cpu:
            neighbor_cache[r] = {}
            for key in neighbor_cache_cpu[r]:
                neighbor_cache[r][key] = neighbor_cache_cpu[r][key].to(device)

        norms = torch.norm(data_tensor, dim=2, keepdim=True)
        valid_mask = (norms > 1e-6).float()  
        data_tensor = (data_tensor * valid_mask) / (norms + 1e-8)

        effective_taus = [tau for tau in range(start_tau, end_tau+1) if tau < N]
        if not effective_taus:
            raise ValueError(f"No valid taus in range [{start_tau}, {end_tau}]")

        progress_desc = f"GPU{gpu_id} τ({start_tau}-{end_tau})"
        pbar = tqdm(total=len(effective_taus), desc=progress_desc, position=gpu_id+1)

        C_part = np.full((len(sorted_r), len(effective_taus)), np.nan)

        with torch.no_grad():
            for local_idx, tau in enumerate(effective_taus):
                valid_t = N - tau
                current = data_tensor[:valid_t]
                future = data_tensor[tau:tau+valid_t]

                for r_idx, r in enumerate(sorted_r):
                    total = 0.0
                    count = 0

                    if r not in neighbor_cache:
                        continue

                    for (dx, dy, dz), neighbors in neighbor_cache[r].items():
                        future_shifted = future[:, neighbors, :]
                        dots = torch.sum(current * future_shifted, dim=2)

                        valid_dots = dots[~torch.isnan(dots)]
                        if valid_dots.numel() > 0:
                            total += valid_dots.sum().item()
                            count += valid_dots.numel()

                    if count > 0:
                        C_part[r_idx, local_idx] = total / count

                pbar.update()
                pbar.set_postfix_str(f"τ={tau}")

        pbar.close()
        result_queue.put({
            'start_tau': start_tau,
            'end_tau': end_tau,
            'C_part': C_part,
            'valid_taus': effective_taus
        })

    except Exception as e:
        print(f"\nGPU {gpu_id} Error: {str(e)}", flush=True)
        result_queue.put(None)

def split_tau_by_workload(TAU_MAX, N, num_gpus):
    effective_TAU_MAX = min(TAU_MAX, N-1)
    workloads = [N - tau for tau in range(effective_TAU_MAX + 1)]
    total_workload = sum(workloads)
    target_per_gpu = total_workload / num_gpus

    splits = []
    current_sum = 0
    start = 0
    for tau in range(effective_TAU_MAX + 1):
        current_sum += workloads[tau]
        if current_sum >= target_per_gpu or tau == effective_TAU_MAX:
            splits.append((start, tau))
            start = tau + 1
            current_sum = 0
            if len(splits) == num_gpus:
                break
    if start <= effective_TAU_MAX:
        splits.append((start, effective_TAU_MAX))
    return splits[:num_gpus]

if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')  

    data = np.load(inputfile)
    N = data.shape[0]
    print(f"Data shape: {data.shape}, N={N}")
    print(f"System size: X={GRID_SIZE_X}, Y={GRID_SIZE_Y}, Z={GRID_SIZE_Z}")

    r_dict, sorted_r = generate_r_dict()
    neighbor_cache_cpu = precompute_neighbors(r_dict)

    num_gpus = torch.cuda.device_count()
    print(f"Available GPUs: {num_gpus}")
    if num_gpus == 0:
        sys.exit("Error: no CUDA device found; run this script on a GPU node")

    splits = split_tau_by_workload(TAU_MAX, N, num_gpus)
    print(f"Task splits: {splits}")

    result_queue = Queue()
    processes = []

    for gpu_id in range(num_gpus):
        if gpu_id >= len(splits):
            break
        start, end = splits[gpu_id]
        p = Process(target=compute_correlation_gpu_subset,
                   args=(data, r_dict, sorted_r, neighbor_cache_cpu, 
                         start, end, gpu_id, result_queue))
        processes.append(p)
        p.start()

    main_pbar = tqdm(total=len(processes), desc="Overall Progress", position=0)

    C = np.full((len(sorted_r), TAU_MAX+1), np.nan)
    completed = 0
    while completed < len(processes):
        result = result_queue.get()
        if result is None:
            completed += 1
            continue

        for r_idx in range(len(sorted_r)):
            for local_idx, tau in enumerate(result['valid_taus']):
                if tau <= TAU_MAX:
                    C[r_idx, tau] = result['C_part'][r_idx, local_idx]

        completed += 1
        main_pbar.update()

    main_pbar.close()

    with open(outputfile, "w") as f:
        for r_idx, r in enumerate(sorted_r):
            for tau in range(TAU_MAX+1):
                f.write(f"{tau}\t{r}\t{C[r_idx, tau]}\n")

    print("Calculation completed successfully.")
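precompute_neighbors loops over every unit cell in Python for every displacement, which can become slow for large supercells. A possible vectorized alternative (my suggestion, not part of the original script) builds the same table with `np.roll` on a grid of cell indices; `neighbor_indices_roll` is a hypothetical name, and the loop version is a direct transcription of the script's inner loop, included only for comparison.

```python
import numpy as np


def neighbor_indices_roll(grid, shift):
    """Vectorized neighbor table for one displacement (hypothetical helper):
    result[idx(x, y, z)] == idx((x+dx) % nx, (y+dy) % ny, (z+dz) % nz),
    with idx(x, y, z) = x*ny*nz + y*nz + z as in coord_to_index."""
    nx, ny, nz = grid
    dx, dy, dz = shift
    index_grid = np.arange(nx * ny * nz).reshape(nx, ny, nz)
    # rolling by -d along an axis moves the entry at position (i + d) to position i
    return np.roll(index_grid, shift=(-dx, -dy, -dz), axis=(0, 1, 2)).ravel()


def neighbor_indices_loop(grid, shift):
    """Direct transcription of the inner loop of precompute_neighbors."""
    nx, ny, nz = grid
    dx, dy, dz = shift
    out = []
    for idx in range(nx * ny * nz):
        x, y, z = idx // (ny * nz), (idx // nz) % ny, idx % nz
        out.append(((x + dx) % nx) * ny * nz + ((y + dy) % ny) * nz + (z + dz) % nz)
    return np.array(out)


if __name__ == "__main__":
    grid = (4, 3, 5)
    for d in [(1, 0, 0), (-2, 1, 1), (2, 2, -1)]:
        assert np.array_equal(neighbor_indices_roll(grid, d), neighbor_indices_loop(grid, d))
    print("np.roll table matches the loop")
```

To use it inside the script, the table could be converted with `torch.from_numpy(table).long()` in place of `torch.LongTensor(indices)`.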

4 Postprocessing (optional)

You can use the following script to visualize the result at a given temperature:

import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import numpy as np
from scipy.interpolate import interp1d, UnivariateSpline
from matplotlib import rcParams

plt.style.use('nature.mplstyle')  # needs a local nature.mplstyle file; delete this line to use Matplotlib's default style
FIGURE_SIZE = (4, 3)          
FONT_SIZE = 15                
LINE_WIDTH = 3
AXIS_LABELS = {               
    'x': '$r$ (Å)', 
    'y': r'$C(r,\tau)$'
}
LEGEND_POS = 'upper right'

custom_colors = ['#403990', '#80A6E2', '#FBDD85', '#F46F43', '#CF3D3E']
color_dict = {tau: color for tau, color in zip([0, 10, 100, 500, 999], custom_colors)}

df = pd.read_csv('spatial_a_disp_correlation_standard.txt', sep=r'\s+', header=None, names=['tau', 'r', 'C'])  # output file of cal_spatial_correlation.py

target_taus = [0, 10, 100, 500, 999]
filtered_df = df[df['tau'].isin(target_taus)]

fig, ax = plt.subplots(figsize=FIGURE_SIZE)
ax2 = ax.twinx()  
angle_ticks = [45, 60, 75, 90]
cosine_vals = [np.cos(np.deg2rad(ang)) for ang in angle_ticks]
ax2.set_yticks(cosine_vals)
ax2.set_yticklabels([f'{ang}°' for ang in angle_ticks])
ax2.set_ylabel('Angle', fontsize=FONT_SIZE+1, labelpad=5)
ax2.set_ylim(-0.06, 0.72)  

def filter_close_points(r_values):
    filtered = []
    prev = None
    for i, r in enumerate(r_values):
        if prev is not None and abs(r - prev) < 5:
            continue
        filtered.append(i)
        prev = r
    return filtered

# Smooth C(r) with a spline; fall back to linear interpolation if the spline fit fails
def smooth_data(r_values, c_values, method='spline', k=5, s=0.001):
    x_fine = np.linspace(min(r_values), max(r_values), 500) 
    try:
        spline = UnivariateSpline(r_values, c_values, k=k, s=s)
        y_smooth = spline(x_fine)
        return x_fine, y_smooth
    except Exception:
        interpolator = interp1d(r_values, c_values, kind='linear', fill_value="extrapolate")
        return x_fine, interpolator(x_fine)


legend_handles = []
for tau in target_taus:
    tau_data = filtered_df[filtered_df['tau'] == tau].sort_values('r')
    if tau_data.empty:
        continue

    r_values = tau_data['r'].values * 4  # lattice units -> Å, assuming a lattice constant of about 4 Å
    c_values = tau_data['C'].values
    color = color_dict[tau]

    idx = filter_close_points(r_values)
    ax.scatter(
        r_values[idx], 
        c_values[idx],
        s=25,
        color=color,
        alpha=1.0,
        linewidths=1.0,
        marker='s',
        facecolors='none',
        edgecolors=color,
        zorder=3
    )

    x_smooth, y_smooth = smooth_data(r_values, c_values)
    curve, = ax.plot(
        x_smooth, 
        y_smooth,
        lw=LINE_WIDTH-1,
        color=color,
        alpha=0.85,
        label=fr'$\tau = {tau}$',
        zorder=2
    )
    legend_handles.append(curve)

ax.set_xlim(0, 120)
ax.set_xticks([0, 20, 40, 60, 80, 100, 120])
ax.set_ylim(-0.06, 0.72)
ax.set_yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
ax.set_yticklabels([f"{t:.1f}" for t in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]])

ax.set_xlabel(AXIS_LABELS['x'], fontsize=FONT_SIZE+1, labelpad=5)
ax.set_ylabel(AXIS_LABELS['y'], fontsize=FONT_SIZE+1, labelpad=5)
ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE-1, pad=4)
ax2.tick_params(axis='y', which='major', labelsize=FONT_SIZE-1, pad=4)

leg = ax.legend(
    handles=legend_handles,
    loc=LEGEND_POS,
    frameon=False,
    fontsize=FONT_SIZE-4,
    handlelength=2.5
)
leg.get_frame().set_edgecolor('0.8')

plt.tight_layout()
plt.subplots_adjust(right=0.85) 
plt.savefig('C_vs_r_p.png', dpi=300, bbox_inches='tight')
plt.show()

[Figure (C_vs_r_p.png): \(C(r,\tau)\) versus \(r\) for \(\tau\) = 0, 10, 100, 500 and 999]

Alternatively, you can average over all \(\tau\) values to obtain \(C(r)\), defined as the discrete time average

\[C(r)=\frac{1}{\tau_{\max}+1}\sum_{\tau=0}^{\tau_{\max}} C(r,\tau)\]
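In practice this average is a pandas `groupby` over \(r\), applied to the tab-separated tau/r/C file written by cal_spatial_correlation.py. The sketch below shows the step on a few made-up lines (illustrative numbers, not results). Note that `mean()` skips NaN entries, so delays without data (for example when <max_tau> exceeds the number of frames) shorten the effective \(\tau\) range instead of turning \(C(r)\) into NaN.

```python
import io

import pandas as pd

# Made-up lines in the format written by cal_spatial_correlation.py: tau <TAB> r <TAB> C
sample = "0\t1.0\t0.6\n1\t1.0\t0.4\n0\t2.0\t0.3\n1\t2.0\tnan\n"

df = pd.read_csv(io.StringIO(sample), sep="\t", header=None, names=["tau", "r", "C"])

# C(r) = mean over tau of C(r, tau); NaN entries are skipped
c_r = df.groupby("r")["C"].mean()
print(c_r)
```

The full script below applies the same averaging to each temperature directory and plots the results.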

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d, UnivariateSpline
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import rcParams
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

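plt.style.use('nature.mplstyle')  # needs a local nature.mplstyle file; delete this line to use Matplotlib's default style
INPUT_DIRS = ['100','200','250','300','400']  # one directory per temperature (K), each holding the correlation output
OUTPUT_FILE = "correlation_r_a.txt"  # tau-averaged C(r) for all temperatures is written here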

custom_colors = ['#403990', '#80A6E2', '#FBDD85', '#F46F43', '#CF3D3E']
color_dict = {int(temp): color for temp, color in zip(INPUT_DIRS, custom_colors)}

FIGURE_SIZE = (4, 3)          
FONT_SIZE = 15                
LINE_WIDTH = 3              
AXIS_LABELS = {               
    'x': '$r$ (Å)', 
    'y': r'$C(r,\overline{\tau})$'
}
LEGEND_POS = 'upper right'     

def process_data():
    all_data = []

    for temp_dir in INPUT_DIRS:            
        file_path = os.path.join(temp_dir, "spatial_a_disp_correlation_standard.txt")
        df = pd.read_csv(
                file_path, 
                sep='\t', 
                header=None, 
                names=['tau', 'r', 'C'],
                dtype={'r': float, 'C': float}
            )

        mean_df = df.groupby('r')['C'].mean().reset_index()
        mean_df['temp'] = int(temp_dir)

        all_data.append(mean_df)

    if not all_data:
        return pd.DataFrame()

    final_df = pd.concat(all_data, ignore_index=True)
    final_df.to_csv(OUTPUT_FILE, sep='\t', index=False, 
                        columns=['temp', 'r', 'C'],
                        header=['Temperature (K)', 'r', '⟨C⟩'])
    return final_df

def smooth_data(r_values, c_values, method='spline', k=5, s=0.001):
    x_fine = np.linspace(min(r_values), max(r_values), 500) 
    try:
        spline = UnivariateSpline(r_values, c_values, k=k, s=s)
        y_smooth = spline(x_fine)
        return x_fine, y_smooth
    except Exception:
        interpolator = interp1d(r_values, c_values, kind='linear', fill_value="extrapolate")
        return x_fine, interpolator(x_fine)

def filter_close_points(r_values):
    filtered = []
    prev = None
    for i, r in enumerate(r_values):
        if prev is not None and abs(r - prev) < 5:
            continue
        filtered.append(i)
        prev = r
    return filtered

def plot_results(df):
    plt.figure(figsize=FIGURE_SIZE)
    ax = plt.gca()
    ax2 = ax.twinx()
    angle_ticks = [45, 60, 75, 90]
    cosine_vals = [np.cos(np.deg2rad(ang)) for ang in angle_ticks]

    ax2.set_yticks(cosine_vals)
    ax2.set_yticklabels([f'{ang}°' for ang in angle_ticks])
    ax2.set_ylabel('Angle', fontsize=FONT_SIZE+1, labelpad=5)

    legend_handles = [] 
    for temp in np.sort(df['Temperature (K)'].unique()):
        temp_data = df[df['Temperature (K)'] == temp].sort_values('r')
        color = color_dict[temp]  
        r_values = temp_data['r'].values * 4  # lattice units -> Å, assuming a lattice constant of about 4 Å
        c_values = temp_data['⟨C⟩'].values

        idx = filter_close_points(r_values)

        plt.scatter(
            r_values[idx], 
            c_values[idx],
            s=25,
            color=color,
            alpha=1.0,
            linewidths=1.0,
            marker='s',
            facecolors='none',
            edgecolors=color,
            zorder=3
        )

        x_smooth, y_smooth = smooth_data(r_values, c_values, method='spline')

        curve, = plt.plot(
            x_smooth, 
            y_smooth,
            lw=LINE_WIDTH-1,
            color=color,
            alpha=0.85,
            label=f"{temp} K",
            zorder=2
        )
        legend_handles.append(curve)

    ax.set_xlim(0, 120)  
    x_ticks = [0, 20, 40, 60, 80, 100, 120]
    ax.set_xticks(x_ticks)
    ax.set_ylim(-0.06, 0.72)     
    ax.set_xlabel(AXIS_LABELS['x'], fontsize=FONT_SIZE+1, labelpad=5)
    ax.set_ylabel(r'$C^{\text{d}}(r)$', fontsize=FONT_SIZE+1, labelpad=5)
    y_ticks = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    ax.set_yticks(y_ticks)
    ax.set_yticklabels([f"{t:.1f}" for t in y_ticks]) 
    ax2.set_ylim(-0.06, 0.72)
    ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE-1, pad=4)
    ax2.tick_params(axis='y', which='major', labelsize=FONT_SIZE-1, pad=4)
    leg = ax.legend(
        handles=legend_handles,
        loc=LEGEND_POS, 
        frameon=False,
        fontsize=FONT_SIZE-4,
        handlelength=2.5
    )
    leg.get_frame().set_edgecolor('0.8')
    plt.tight_layout()
    plt.subplots_adjust(right=0.85) 
    plt.savefig("a_disp_correlation.png", dpi=300, bbox_inches='tight')
    plt.close()

if __name__ == "__main__":
    df = process_data()
    plot_df = pd.read_csv(OUTPUT_FILE, sep='\t')
    plot_results(plot_df)

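Here is the resulting figure (a_disp_correlation.png), showing the \(\tau\)-averaged \(C(r)\) versus \(r\) at 100, 200, 250, 300 and 400 K.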