// =============== 内核密钥管理 (kernel/rsa.h) ===============

#ifndef RSA_H
#define RSA_H

// RSA密钥结构（存储在内核空间）
struct rsa_keypair {
    uint32 n;      // 模数 n = p * q
    uint32 e;      // 公钥指数
    uint32 d;      // 私钥指数（仅内核可见）
    uint32 p;      // 素数p（用于密钥生成，仅内核可见）
    uint32 q;      // 素数q（用于密钥生成，仅内核可见）
    uint32 phi_n;  // 欧拉函数值 φ(n) = (p-1)(q-1)
    int valid;     // 密钥是否有效
};

// 小素数表（用于密钥生成）
static const uint16 small_primes[] = {
    61, 67, 71, 73, 79, 83, 89, 97, 101, 103,
    107, 109, 113, 127, 131, 137, 139, 149, 151, 157,
    163, 167, 173, 179, 181, 191, 193, 197, 199, 211
};

#endif

// =============== 进程结构体修改 (kernel/proc.h) ===============

// 在 struct proc 中添加：
struct proc {
    // ... 现有字段
    
    // RSA密钥对（每个进程独有）
    struct rsa_keypair rsa_keys;
    
    // 密钥访问权限控制
    int key_permissions;  // 0: 无权限, 1: 只读公钥, 2: 完全权限
};

// =============== 内核RSA实现 (kernel/rsa.c) ===============

#include "types.h"
#include "param.h"
#include "memlayout.h"
#include "riscv.h"
#include "spinlock.h"
#include "proc.h"
#include "defs.h"
#include "rsa.h"

// 简单的伪随机数生成器（基于系统时钟）
static uint32 rsa_rand_seed = 0;

uint32 rsa_random() {
    // 使用线性同余生成器 + 系统ticks
    rsa_rand_seed = (rsa_rand_seed * 1664525 + 1013904223) ^ ticks;
    return rsa_rand_seed;
}

// Miller-Rabin素性测试
int is_prime_miller_rabin(uint32 n, int k) {
    if (n < 2) return 0;
    if (n == 2 || n == 3) return 1;
    if (n % 2 == 0) return 0;
    
    // 将 n-1 表示为 d * 2^r
    uint32 d = n - 1;
    int r = 0;
    while (d % 2 == 0) {
        d /= 2;
        r++;
    }
    
    // 进行k轮测试
    for (int i = 0; i < k; i++) {
        uint32 a = 2 + (rsa_random() % (n - 3));
        uint32 x = mod_pow(a, d, n);
        
        if (x == 1 || x == n - 1) continue;
        
        int composite = 1;
        for (int j = 0; j < r - 1; j++) {
            x = mod_pow(x, 2, n);
            if (x == n - 1) {
                composite = 0;
                break;
            }
        }
        
        if (composite) return 0;
    }
    
    return 1;
}

// 生成素数
uint32 generate_prime(int bits) {
    uint32 min_val = 1 << (bits - 1);
    uint32 max_val = (1 << bits) - 1;
    
    while (1) {
        uint32 candidate = min_val + (rsa_random() % (max_val - min_val));
        candidate |= 1;  // 确保是奇数
        
        if (is_prime_miller_rabin(candidate, 10)) {
            return candidate;
        }
    }
}

// 扩展欧几里得算法
int extended_gcd(int a, int b, int *x, int *y) {
    if (a == 0) {
        *x = 0;
        *y = 1;
        return b;
    }
    
    int x1, y1;
    int gcd = extended_gcd(b % a, a, &x1, &y1);
    
    *x = y1 - (b / a) * x1;
    *y = x1;
    
    return gcd;
}

// 计算模逆元
uint32 mod_inverse(uint32 a, uint32 m) {
    int x, y;
    int gcd = extended_gcd(a, m, &x, &y);
    
    if (gcd != 1) return 0;  // 不存在模逆元
    
    return (x % m + m) % m;
}

// 快速模幂运算
uint32 mod_pow(uint32 base, uint32 exp, uint32 mod) {
    uint64 result = 1;
    uint64 b = base % mod;
    
    while (exp > 0) {
        if (exp & 1) {
            result = (result * b) % mod;
        }
        b = (b * b) % mod;
        exp >>= 1;
    }
    
    return (uint32)result;
}

// 为进程生成新的RSA密钥对
int generate_rsa_keypair(struct rsa_keypair *keypair) {
    // 初始化随机种子
    if (rsa_rand_seed == 0) {
        rsa_rand_seed = ticks ^ (uint32)((uint64)&keypair);
    }
    
    // 生成两个不同的素数（8-10位二进制数）
    uint32 p = generate_prime(9);
    uint32 q = generate_prime(10);
    
    // 确保p != q
    while (p == q) {
        q = generate_prime(10);
    }
    
    // 计算n和φ(n)
    keypair->p = p;
    keypair->q = q;
    keypair->n = p * q;
    keypair->phi_n = (p - 1) * (q - 1);
    
    // 选择公钥指数e（通常使用65537，但我们用小一点的）
    keypair->e = 17;  // 常用的小公钥指数
    
    // 确保gcd(e, φ(n)) = 1
    int x, y;
    while (extended_gcd(keypair->e, keypair->phi_n, &x, &y) != 1) {
        keypair->e += 2;  // 保持为奇数
    }
    
    // 计算私钥指数d
    keypair->d = mod_inverse(keypair->e, keypair->phi_n);
    
    if (keypair->d == 0) {
        return -1;  // 密钥生成失败
    }
    
    keypair->valid = 1;
    
    // 验证密钥对
    uint32 test_msg = 42;
    uint32 encrypted = mod_pow(test_msg, keypair->e, keypair->n);
    uint32 decrypted = mod_pow(encrypted, keypair->d, keypair->n);
    
    if (test_msg != decrypted) {
        keypair->valid = 0;
        return -1;
    }
    
    return 0;
}

// 初始化进程的RSA密钥（在fork时调用）
void init_proc_rsa(struct proc *p) {
    // 为新进程生成独立的密钥对
    if (generate_rsa_keypair(&p->rsa_keys) < 0) {
        // 如果生成失败，使用备用密钥
        p->rsa_keys.p = 61;
        p->rsa_keys.q = 53;
        p->rsa_keys.n = 3233;
        p->rsa_keys.e = 17;
        p->rsa_keys.d = 2753;
        p->rsa_keys.phi_n = 3120;
        p->rsa_keys.valid = 1;
    }
    
    // 默认权限：可以使用自己的密钥
    p->key_permissions = 2;
}

// 使用进程的公钥加密数据（任何人都可以调用）
int rsa_encrypt_with_proc_key(struct proc *p, char *data, int len) {
    if (!p->rsa_keys.valid) {
        return -1;
    }
    
    for (int i = 0; i < len; i++) {
        uint32 plaintext = (uint32)(unsigned char)data[i];
        uint32 ciphertext = mod_pow(plaintext, p->rsa_keys.e, p->rsa_keys.n);
        // 简化处理：只保留低字节（实际应该扩展存储）
        data[i] = (char)(ciphertext & 0xFF);
    }
    
    return len;
}

// 使用进程的私钥解密数据（仅内核可以调用）
int rsa_decrypt_with_proc_key(struct proc *p, char *data, int len) {
    if (!p->rsa_keys.valid) {
        return -1;
    }
    
    // 检查进程是否有权限使用私钥
    if (p->key_permissions < 2) {
        return -1;
    }
    
    for (int i = 0; i < len; i++) {
        uint32 ciphertext = (uint32)(unsigned char)data[i];
        uint32 plaintext = mod_pow(ciphertext, p->rsa_keys.d, p->rsa_keys.n);
        data[i] = (char)plaintext;
    }
    
    return len;
}

// =============== 系统调用实现 (kernel/sysfile.c) ===============

// 生成新的RSA密钥对（用户可调用）
uint64 sys_genrsakey(void) {
    struct proc *p = myproc();
    
    // 生成新密钥对
    if (generate_rsa_keypair(&p->rsa_keys) < 0) {
        return -1;
    }
    
    // 返回公钥信息给用户（但不返回私钥）
    printf("New RSA keypair generated for PID %d\n", p->pid);
    printf("Public key (n=%d, e=%d)\n", p->rsa_keys.n, p->rsa_keys.e);
    // 私钥d永远不会暴露给用户空间
    
    return 0;
}

// 获取当前进程的公钥
uint64 sys_getpubkey(void) {
    struct proc *p = myproc();
    uint64 addr;
    
    if (argaddr(0, &addr) < 0) {
        return -1;
    }
    
    if (!p->rsa_keys.valid) {
        return -1;
    }
    
    // 只返回公钥部分（n和e）
    struct {
        uint32 n;
        uint32 e;
    } pubkey;
    
    pubkey.n = p->rsa_keys.n;
    pubkey.e = p->rsa_keys.e;
    
    // 复制到用户空间
    if (copyout(p->pagetable, addr, (char*)&pubkey, sizeof(pubkey)) < 0) {
        return -1;
    }
    
    return 0;
}

// 加密文件（使用目标进程的公钥）
uint64 sys_encryptfile(void) {
    char path[MAXPATH];
    int target_pid;
    struct proc *target_proc;
    struct inode *ip;
    
    if (argstr(0, path, MAXPATH) < 0 || argint(1, &target_pid) < 0) {
        return -1;
    }
    
    // 查找目标进程
    target_proc = 0;
    for (struct proc *p = proc; p < &proc[NPROC]; p++) {
        acquire(&p->lock);
        if (p->pid == target_pid && p->state != UNUSED) {
            target_proc = p;
            release(&p->lock);
            break;
        }
        release(&p->lock);
    }
    
    if (!target_proc || !target_proc->rsa_keys.valid) {
        return -1;
    }
    
    begin_op();
    
    if ((ip = namei(path)) == 0) {
        end_op();
        return -1;
    }
    
    ilock(ip);
    
    if (ip->type != T_FILE) {
        iunlockput(ip);
        end_op();
        return -1;
    }
    
    // 读取文件内容
    char buf[512];
    int n = readi(ip, 0, (uint64)buf, 0, sizeof(buf));
    
    if (n > 0) {
        // 使用目标进程的公钥加密
        rsa_encrypt_with_proc_key(target_proc, buf, n);
        
        // 写回加密内容
        writei(ip, 0, (uint64)buf, 0, n);
        
        // 记录加密元数据（可选：在文件属性中标记）
        printf("File %s encrypted for PID %d\n", path, target_pid);
    }
    
    iunlockput(ip);
    end_op();
    
    return n;
}

// 解密文件（仅文件所有者可以解密）
uint64 sys_decryptfile(void) {
    char path[MAXPATH];
    struct proc *p = myproc();
    struct inode *ip;
    
    if (argstr(0, path, MAXPATH) < 0) {
        return -1;
    }
    
    // 检查进程是否有权限使用私钥
    if (!p->rsa_keys.valid || p->key_permissions < 2) {
        printf("No permission to use private key\n");
        return -1;
    }
    
    begin_op();
    
    if ((ip = namei(path)) == 0) {
        end_op();
        return -1;
    }
    
    ilock(ip);
    
    if (ip->type != T_FILE) {
        iunlockput(ip);
        end_op();
        return -1;
    }
    
    char buf[512];
    int n = readi(ip, 0, (uint64)buf, 0, sizeof(buf));
    
    if (n > 0) {
        // 使用当前进程的私钥解密（私钥永远不离开内核）
        if (rsa_decrypt_with_proc_key(p, buf, n) < 0) {
            printf("Decryption failed - wrong key?\n");
            iunlockput(ip);
            end_op();
            return -1;
        }
        
        // 写回解密内容
        writei(ip, 0, (uint64)buf, 0, n);
        
        printf("File %s decrypted by PID %d\n", path, p->pid);
    }
    
    iunlockput(ip);
    end_op();
    
    return n;
}

// 撤销进程的私钥访问权限（安全特性）
uint64 sys_revoke_key(void) {
    struct proc *p = myproc();
    p->key_permissions = 1;  // 降级为只能使用公钥
    printf("Private key access revoked for PID %d\n", p->pid);
    return 0;
}

// =============== 进程初始化修改 (kernel/proc.c) ===============

// 在 allocproc() 函数中添加：
static struct proc* allocproc(void) {
    struct proc *p;
    
    // ... 现有代码
    
    // 初始化RSA密钥
    init_proc_rsa(p);
    
    return p;
}

// 在 fork() 中决定是否继承父进程密钥
int fork(void) {
    // ... 现有代码
    
    // 子进程生成新的密钥对（不继承父进程的私钥）
    init_proc_rsa(np);
    
    // ... 现有代码
}

// =============== 用户空间测试程序 (user/rsatest.c) ===============

#include "kernel/types.h"
#include "kernel/stat.h"
#include "user/user.h"
#include "kernel/fcntl.h"

struct pubkey {
    uint n;
    uint e;
};

int main(int argc, char *argv[]) {
    printf("=== xv6 Secure RSA System Test ===\n\n");
    
    // 1. 生成当前进程的RSA密钥对
    printf("Step 1: Generating RSA keypair for current process...\n");
    if (genrsakey() < 0) {
        printf("Failed to generate RSA keypair\n");
        exit(1);
    }
    
    // 2. 获取公钥（私钥在内核中，用户无法访问）
    struct pubkey my_pubkey;
    if (getpubkey(&my_pubkey) < 0) {
        printf("Failed to get public key\n");
        exit(1);
    }
    printf("My public key: n=%d, e=%d\n", my_pubkey.n, my_pubkey.e);
    printf("(Private key is securely stored in kernel space)\n\n");
    
    // 3. 创建测试文件
    printf("Step 2: Creating test file...\n");
    char *filename = "secret.txt";
    int fd = open(filename, O_CREATE | O_WRONLY);
    if (fd < 0) {
        printf("Cannot create file\n");
        exit(1);
    }
    
    char *secret_data = "This is a secret message!";
    write(fd, secret_data, strlen(secret_data));
    close(fd);
    printf("Created file with content: %s\n\n", secret_data);
    
    // 4. Fork一个子进程来演示进程间加密
    printf("Step 3: Forking child process...\n");
    int pid = fork();
    
    if (pid == 0) {
        // 子进程：生成自己的密钥
        printf("[Child] Generating my own RSA keypair...\n");
        genrsakey();
        
        struct pubkey child_pubkey;
        getpubkey(&child_pubkey);
        printf("[Child] My public key: n=%d, e=%d\n", child_pubkey.n, child_pubkey.e);
        
        // 等待父进程加密文件
        sleep(10);
        
        // 尝试解密（应该失败，因为使用不同的密钥）
        printf("[Child] Attempting to decrypt parent's file...\n");
        if (decryptfile(filename) < 0) {
            printf("[Child] ✓ Cannot decrypt - different private key!\n");
        } else {
            printf("[Child] ✗ Unexpected: decryption succeeded\n");
        }
        
        exit(0);
    }
    
    // 5. 父进程：加密文件（使用自己的公钥）
    printf("[Parent] Encrypting file for myself (PID %d)...\n", getpid());
    if (encryptfile(filename, getpid()) < 0) {
        printf("Encryption failed\n");
        exit(1);
    }
    
    // 6. 查看加密后的内容
    fd = open(filename, O_RDONLY);
    char encrypted[512];
    int n = read(fd, encrypted, sizeof(encrypted));
    close(fd);
    
    printf("[Parent] Encrypted content (hex): ");
    for (int i = 0; i < n && i < 10; i++) {
        printf("%02x ", (unsigned char)encrypted[i]);
    }
    printf("...\n\n");
    
    // 7. 解密文件（使用自己的私钥 - 在内核中）
    printf("[Parent] Decrypting file with my private key...\n");
    if (decryptfile(filename) < 0) {
        printf("Decryption failed\n");
        exit(1);
    }
    
    // 8. 验证解密结果
    fd = open(filename, O_RDONLY);
    char decrypted[512];
    n = read(fd, decrypted, sizeof(decrypted));
    close(fd);
    decrypted[n] = '\0';
    
    printf("[Parent] Decrypted content: %s\n", decrypted);
    
    if (strcmp(secret_data, decrypted) == 0) {
        printf("\n✓ RSA encryption/decryption successful!\n");
        printf("✓ Private key never left kernel space!\n");
    } else {
        printf("\n✗ Decryption verification failed\n");
    }
    
    // 9. 测试密钥撤销
    printf("\nStep 4: Testing key revocation...\n");
    revokekey();
    printf("Private key access revoked\n");
    
    if (decryptfile(filename) < 0) {
        printf("✓ Cannot decrypt after key revocation\n");
    }
    
    // 等待子进程
    wait(0);
    
    printf("\n=== Test Complete ===\n");
    exit(0);
}

// =============== 系统调用声明 ===============

// user/user.h
int genrsakey(void);
int getpubkey(void*);
int encryptfile(char*, int);
int decryptfile(char*);
int revokekey(void);

// kernel/syscall.h
#define SYS_genrsakey   22
#define SYS_getpubkey   23
#define SYS_encryptfile 24
#define SYS_decryptfile 25
#define SYS_revoke_key  26

// user/usys.pl
entry("genrsakey");
entry("getpubkey");
entry("encryptfile");
entry("decryptfile");
entry("revokekey");
