import re
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt


# ============================================================
# 1. Parameters to verify against your own LAMMPS input file
# ============================================================

LOG_FILE = "log.lammps"

# Which thermo block (1-based) to analyse:
#   block 1 -> run 30000, equilibration stage
#   block 2 -> run 100000, tensile stage
SELECT_BLOCK = 2

L0 = 119.295        # initial length in Å: x extent of your region box
DT = 0.005          # timestep in ps, as set in the input file
VPULL = 0.02        # right-end pull speed in Å/ps ("fix pull right move linear 0.02 0.0 0.0")

# LAMMPS reports pxx as a pressure, so tensile stress is usually taken as -pxx.
# If the plotted curve comes out with the wrong sign, change -1 to 1.
STRESS_SIGN = -1


# ============================================================
# 2. Read thermo data blocks from a LAMMPS log file
# ============================================================

def is_number(s):
    """Return True when ``float(s)`` succeeds, False when it raises ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    return True


def read_lammps_thermo_blocks(log_file):
    """
    Read every thermo data block from a LAMMPS log file.

    A block starts at a header line whose first token is ``Step`` followed by
    at least one more column name (e.g. ``Step Temp Pxx Pyy Pzz``). Every
    following non-blank line whose first ``len(columns)`` tokens all parse as
    floats is a data row. Blank lines inside a block are skipped. The first
    line that does not qualify ends the block. That line is then examined
    again as a possible header, so a header immediately after a block is not
    lost.

    Parameters
    ----------
    log_file : str or os.PathLike
        Path to the LAMMPS log file.

    Returns
    -------
    list of pandas.DataFrame
        One DataFrame per block that has at least one data row, in file order.
        The header tokens are the column names and all values are floats.

    Raises
    ------
    FileNotFoundError
        If ``log_file`` does not exist.
    """
    log_path = Path(log_file)

    if not log_path.exists():
        raise FileNotFoundError(f"找不到文件：{log_file}")

    lines = log_path.read_text(encoding="utf-8", errors="ignore").splitlines()
    header_re = re.compile(r"^Step\s+")

    blocks = []
    i = 0

    while i < len(lines):
        line = lines[i].strip()

        if not header_re.match(line):
            i += 1
            continue

        # Thermo header, e.g. "Step Temp Pxx Pyy Pzz".
        columns = line.split()
        ncols = len(columns)
        data = []
        i += 1

        while i < len(lines):
            parts = lines[i].split()

            if not parts:
                i += 1
                continue

            if len(parts) < ncols:
                break
            try:
                # Parse each token once; any non-numeric token ends the block.
                row = [float(x) for x in parts[:ncols]]
            except ValueError:
                break

            data.append(row)
            i += 1

        # Deliberately no `i += 1` here. The line that ended the block (if
        # any) may itself be the next "Step ..." header, and the outer loop
        # must re-examine it rather than skip it.
        if data:
            blocks.append(pd.DataFrame(data, columns=columns))

    return blocks


# ============================================================
# 3. Main program
# ============================================================

def _add_strain_and_stress(df, l0, dt, vpull, stress_sign):
    """
    Add strain and stress columns to a thermo DataFrame, in place.

    Parameters
    ----------
    df : pandas.DataFrame
        Thermo block with at least ``Step``, ``Pxx``, ``Pyy`` and ``Pzz``
        columns and at least one row.
    l0 : float
        Initial length (Å).
    dt : float
        Timestep (ps).
    vpull : float
        Pull speed of the moving end (Å/ps).
    stress_sign : int
        Multiplier applied to the pressure components to obtain stress.
    """
    # Steps elapsed since the first row of this block. The LAMMPS step
    # counter carries over from one run to the next unless the input resets
    # it (reset_timestep). Using the raw Step of the tensile run would
    # therefore add the whole preceding run to the strain. Assumes the pull
    # is active from the first thermo row of the selected block; confirm
    # against the input script.
    elapsed_steps = df["Step"] - df["Step"].iloc[0]

    # Engineering strain = v * t / L0, with t = elapsed_steps * dt.
    df["strain"] = vpull * elapsed_steps * dt / l0
    df["strain_percent"] = df["strain"] * 100.0

    # In metal units Pxx/Pyy/Pzz are in bar (1 GPa = 10000 bar). LAMMPS
    # reports pressure, so tensile stress is usually -Pxx (stress_sign = -1).
    components = ("xx", "yy", "zz")
    for comp in components:
        df[f"S{comp}_GPa"] = stress_sign * df[f"P{comp}"] / 10000.0
    # Also keep the unsigned pressure in GPa, for comparison.
    for comp in components:
        df[f"P{comp}_GPa_raw"] = df[f"P{comp}"] / 10000.0


def _plot_stress_curves(df, curves, ylabel, title, filename):
    """
    Plot stress columns against engineering strain, then save, show and close the figure.

    ``curves`` is a sequence of ``(column, label, linewidth)`` tuples. One
    line is drawn per tuple, using ``df["strain"]`` as the x values.
    """
    fig = plt.figure(figsize=(7, 5), dpi=150)
    for column, label, linewidth in curves:
        plt.plot(df["strain"], df[column], linewidth=linewidth, label=label)

    plt.xlabel("Engineering strain")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()

    plt.savefig(filename, dpi=300)
    plt.show()
    # Release the figure so repeated plots do not accumulate open figures.
    plt.close(fig)


def main():
    """
    Convert one thermo block of ``LOG_FILE`` into stress-strain data and plots.

    Selects block ``SELECT_BLOCK`` (1-based) and adds strain and stress
    columns. Writes ``stress_strain_data.csv``, ``stress_strain_Sxx.png`` and
    ``stress_components.png`` to the current working directory.

    Raises
    ------
    RuntimeError
        If ``SELECT_BLOCK`` is not a positive block number, the log has fewer
        blocks than requested, or a required column is missing.
    FileNotFoundError
        If ``LOG_FILE`` does not exist.
    """
    blocks = read_lammps_thermo_blocks(LOG_FILE)

    print(f"在 {LOG_FILE} 中共找到 {len(blocks)} 段 thermo 数据。")

    # SELECT_BLOCK is 1-based. With 0 or a negative value, negative indexing
    # would silently pick a block counted from the end.
    if SELECT_BLOCK < 1:
        raise RuntimeError(f"SELECT_BLOCK 必须是从 1 开始的正整数，当前为 {SELECT_BLOCK}。")

    if len(blocks) < SELECT_BLOCK:
        raise RuntimeError(
            f"你设置的是提取第 {SELECT_BLOCK} 段 thermo 数据，"
            f"但是 log 文件中只找到了 {len(blocks)} 段。"
        )

    # Copy so the derived columns are not written into the parsed block.
    df = blocks[SELECT_BLOCK - 1].copy()

    print("当前提取的数据列为：")
    print(df.columns.tolist())

    required_cols = ["Step", "Temp", "Pxx", "Pyy", "Pzz"]
    for col in required_cols:
        if col not in df.columns:
            raise RuntimeError(f"log 文件中没有找到列：{col}")

    _add_strain_and_stress(df, L0, DT, VPULL, STRESS_SIGN)

    output_csv = "stress_strain_data.csv"
    # utf-8-sig writes a BOM, presumably so spreadsheet tools open the CSV as UTF-8.
    df.to_csv(output_csv, index=False, encoding="utf-8-sig")

    print(f"数据已保存为：{output_csv}")
    print()
    print("前 5 行数据预览：")
    print(df[["Step", "Temp", "strain", "strain_percent", "Sxx_GPa", "Syy_GPa", "Szz_GPa"]].head())

    # Sxx stress-strain curve only.
    fig1 = "stress_strain_Sxx.png"
    _plot_stress_curves(
        df,
        [("Sxx_GPa", "Sxx", 1.8)],
        "Stress Sxx / GPa",
        "Stress-Strain Curve of Cu Tensile Simulation",
        fig1,
    )
    print(f"Sxx 应力-应变曲线已保存为：{fig1}")

    # All three normal stress components.
    fig2 = "stress_components.png"
    _plot_stress_curves(
        df,
        [("Sxx_GPa", "Sxx", 1.8), ("Syy_GPa", "Syy", 1.5), ("Szz_GPa", "Szz", 1.5)],
        "Stress / GPa",
        "Stress Components During Tensile Simulation",
        fig2,
    )
    print(f"三个方向应力曲线已保存为：{fig2}")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
