#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <sys/types.h>
#include <fcntl.h>
#include <errno.h>
#include <unistd.h>
#include <inttypes.h>

// Debug switches; nothing in this source assigns them, so edit the
// initializers to enable. print_interrupts makes trap() log every guest
// syscall/interrupt to stderr; print_elf makes elf_load() print the
// computed guest address range.
static int print_interrupts = 0;
static int print_elf = 0;

/* Decode a 16-bit little-endian value: s[0] is the low byte. */
static uint16_t uint16_t_load(const unsigned char *s) {
  uint16_t lo = s[0];
  uint16_t hi = s[1];
  return (uint16_t) (lo | (hi << 8));
}

/* Decode a 16-bit big-endian value: s[0] is the high byte. */
static uint16_t uint16_t_load_bigendian(const unsigned char *s) {
  uint16_t hi = s[0];
  uint16_t lo = s[1];
  return (uint16_t) ((hi << 8) | lo);
}

/* Decode a 32-bit little-endian value: s[0] is least significant.
   Walk from the most significant byte (s[3]) down, shifting as we go. */
static uint32_t uint32_t_load(const unsigned char *s) {
  uint32_t value = 0;
  int i;
  for (i = 3; i >= 0; --i)
    value = (value << 8) | s[i];
  return value;
}

/* Decode a 32-bit big-endian value: s[0] is most significant. */
static uint32_t uint32_t_load_bigendian(const unsigned char *s) {
  uint32_t value = 0;
  int i;
  for (i = 0; i < 4; ++i)
    value = (value << 8) | s[i];
  return value;
}

/* Decode a 64-bit little-endian value: s[0] is least significant.
   Accumulate from the most significant byte (s[7]) downwards. */
static uint64_t uint64_t_load(const unsigned char *s) {
  uint64_t value = 0;
  int i;
  for (i = 7; i >= 0; --i)
    value = (value << 8) | s[i];
  return value;
}

/* Decode a 64-bit big-endian value: s[0] is most significant. */
static uint64_t uint64_t_load_bigendian(const unsigned char *s) {
  uint64_t value = 0;
  int i;
  for (i = 0; i < 8; ++i)
    value = (value << 8) | s[i];
  return value;
}

#include <time.h>

/*
 * Replacement for the C library's clock_gettime: every clock reads as
 * time zero and the call always succeeds. The requested clock is ignored.
 * NOTE(review): presumably defined here to make time observed by this
 * process (including the emulator library) deterministic -- confirm.
 */
int clock_gettime(clockid_t clockid, struct timespec *tp)
{
  (void) clockid;
  tp->tv_nsec = 0;
  tp->tv_sec = 0;
  return 0;
}

/*
 * Replacement for the C library's gettimeofday: always reports time zero
 * and succeeds. The timezone argument is ignored.
 * NOTE(review): presumably paired with the clock_gettime override to freeze
 * time for the whole process -- confirm.
 */
int gettimeofday(struct timeval *tv,void *tz)
{
  (void) tz;
  tv->tv_usec = 0;
  tv->tv_sec = 0;
  return 0;
}

// Granularity in bytes used to round guest addresses and mapping sizes.
#define PAGE 4096

// Values of the ELF class byte (file offset 4, checked by elf_load).
#define ELF_CLASS_32 1
#define ELF_CLASS_64 2

// ELF e_machine values (16-bit field at file offset 18, read by elf_load).
// Not all of these are supported yet: emulate() runs SPARC32, X86, AMD64,
// ARM32 and ARM64 and rejects every other machine as unsupported.
#define ELF_SPARC32 2
#define ELF_X86 3
#define ELF_MIPS32 8
#define ELF_SPARC32PLUS 18
#define ELF_PPC32 20
#define ELF_PPC64 21
#define ELF_ARM32 40
#define ELF_SPARC64 43
#define ELF_AMD64 62
#define ELF_ARM64 183

// State of the ELF binary being emulated; filled in by elf_load().
static char *fn = 0;                 // file name printed by load8..load64 error messages
static int fd = -1;                  // descriptor elf_load opens the binary on
static char *elf = 0;                // malloc'd copy of the first elf_len bytes of the file
static long long elf_len = 0;        // number of bytes in elf (count given on the command line)
static uint64_t elf_class = 0;       // class byte: ELF_CLASS_32 or ELF_CLASS_64
static uint64_t elf_endianness = 0;  // data byte: 2 selects big-endian loads, 1 little-endian
static uint64_t elf_type = 0;        // header type field; elf_load only accepts 2
static uint64_t elf_machine = 0;     // e_machine, compared against the ELF_* machine macros
static uint64_t elf_version = 0;     // header version field; elf_load only accepts 1
static uint64_t elf_entry = 0;       // guest address where emulation starts
static uint64_t elf_phoff = 0;       // file offset of the program-header table
static uint64_t elf_flags = 0;       // header flags (read but not otherwise used here)
static uint64_t elf_phentsize = 0;   // size of one program-header entry (56 or 32)
static uint64_t elf_phnum = 0;       // number of program-header entries
static uint64_t elf_minaddr = 0;     // lowest PT_LOAD address, rounded down to PAGE
static uint64_t elf_maxaddr = 0;     // end of the highest PT_LOAD segment, rounded up to PAGE

/*
 * Return the byte at offset pos of the loaded ELF image.
 * Exits with status 111 if pos lies outside [0, elf_len).
 */
static uint8_t load8(long long pos)
{
  // compare against elf_len-1 rather than computing pos+1, which would be
  // signed overflow (undefined behavior) for pos == LLONG_MAX
  if (pos < 0 || pos > elf_len-1) {
    fprintf(stderr,"%s misformatted, out-of-bounds read\n",fn);
    exit(111);
  }
  return (unsigned char) elf[pos];
}

/*
 * Return the 16-bit value at offset pos of the loaded ELF image, decoded
 * big-endian when elf_endianness == 2 and little-endian otherwise.
 * Exits with status 111 if [pos, pos+2) is not inside [0, elf_len).
 */
static uint16_t load16(long long pos)
{
  // pos comes from header fields (e.g. e_phoff), so it may be close to
  // LLONG_MAX; compare against elf_len-2 to avoid signed overflow in pos+2
  if (pos < 0 || pos > elf_len-2) {
    fprintf(stderr,"%s misformatted, out-of-bounds read\n",fn);
    exit(111);
  }
  if (elf_endianness == 2) return uint16_t_load_bigendian((unsigned char *) elf+pos);
  return uint16_t_load((unsigned char *) elf+pos);
}

/*
 * Return the 32-bit value at offset pos of the loaded ELF image, decoded
 * big-endian when elf_endianness == 2 and little-endian otherwise.
 * Exits with status 111 if [pos, pos+4) is not inside [0, elf_len).
 */
static uint32_t load32(long long pos)
{
  // pos comes from header fields, so it may be close to LLONG_MAX;
  // compare against elf_len-4 to avoid signed overflow in pos+4
  if (pos < 0 || pos > elf_len-4) {
    fprintf(stderr,"%s misformatted, out-of-bounds read\n",fn);
    exit(111);
  }
  if (elf_endianness == 2) return uint32_t_load_bigendian((unsigned char *) elf+pos);
  return uint32_t_load((unsigned char *) elf+pos);
}

/*
 * Return the 64-bit value at offset pos of the loaded ELF image, decoded
 * big-endian when elf_endianness == 2 and little-endian otherwise.
 * Exits with status 111 if [pos, pos+8) is not inside [0, elf_len).
 */
static uint64_t load64(long long pos)
{
  // pos comes from header fields, so it may be close to LLONG_MAX;
  // compare against elf_len-8 to avoid signed overflow in pos+8
  if (pos < 0 || pos > elf_len-8) {
    fprintf(stderr,"%s misformatted, out-of-bounds read\n",fn);
    exit(111);
  }
  if (elf_endianness == 2) return uint64_t_load_bigendian((unsigned char *) elf+pos);
  return uint64_t_load((unsigned char *) elf+pos);
}

/*
 * Parse s as an unsigned decimal number, arithmetic modulo 2^64.
 * Nothing is validated: there is no sign or whitespace handling, and every
 * character (digit or not) contributes (c - '0'). The empty string gives 0.
 */
static uint64_t ato64(const char *s)
{
  uint64_t value;
  size_t i;
  for (value = 0, i = 0; s[i] != '\0'; i++)
    value = value*10 + (uint64_t) (s[i]-'0');
  return value;
}

/*
 * Read the first `bytes` bytes (a decimal count parsed by ato64) of the file
 * `fn` into `elf`, validate the ELF header, and fill in the elf_* globals.
 *
 * Accepts only files with header type 2, version 1, 32- or 64-bit class,
 * little- or big-endian data, a nonzero entry point, a program-header table
 * whose entry size matches the class, and at least one PT_LOAD segment.
 * elf_minaddr/elf_maxaddr become the PAGE-rounded bounds of all PT_LOAD
 * segments.
 *
 * Exits with 111 on open/allocation/read failure or a short file, and with
 * 100 when the file is rejected. Out-of-bounds header reads exit with 111
 * inside load8..load64.
 */
static void elf_load(const char *fn,const char *bytes)
{
  long long pos;
  int loadable = 0; // set once a PT_LOAD segment has been seen

  fd = open(fn,O_RDONLY);
  if (fd == -1) {
    fprintf(stderr,"unable to open %s: %s\n",fn,strerror(errno));
    exit(111);
  }
  elf_len = ato64(bytes);
  elf = malloc(elf_len);
  if (!elf) {
    fprintf(stderr,"unable to malloc %lld bytes: %s\n",elf_len,strerror(errno));
    exit(111);
  }
  for (pos = 0;pos < elf_len;) {
    long long r = read(fd,elf+pos,elf_len-pos);
    if (r < 0) {
      if (errno == EINTR) continue; // interrupted by a signal: retry
      fprintf(stderr,"unable to read %s: %s\n",fn,strerror(errno));
      exit(111);
    }
    if (r == 0) {
      fprintf(stderr,"%s has been truncated, stopping\n",fn);
      exit(111);
    }
    pos += r;
  }

  if (load8(0) != 127 || load8(1) != 'E' || load8(2) != 'L' || load8(3) != 'F') {
    fprintf(stderr,"%s is not an ELF file, stopping\n",fn);
    exit(100);
  }
  elf_class = load8(4);
  if (elf_class != ELF_CLASS_32 && elf_class != ELF_CLASS_64) {
    fprintf(stderr,"%s has unknown class %d, stopping\n",fn,(int) elf_class);
    exit(100);
  }
  elf_endianness = load8(5);
  if (elf_endianness < 1 || elf_endianness > 2) {
    fprintf(stderr,"%s has unknown endianness %d, stopping\n",fn,(int) elf_endianness);
    exit(100);
  }
  if (load8(6) != 1) {
    fprintf(stderr,"%s has unknown header version %d, stopping\n",fn,(int) load8(6));
    exit(100);
  }
  // XXX: check osabi in load8(7)
  // XXX: check abi in load8(8)

  elf_type = load16(16);
  if (elf_type != 2) {
    fprintf(stderr,"%s is not a static executable, stopping\n",fn);
    exit(100);
  }

  elf_machine = load16(18);

  elf_version = load32(20);
  if (elf_version != 1) {
    fprintf(stderr,"%s has unknown version %d, stopping\n",fn,(int) elf_version);
    exit(100);
  }
  // header field offsets differ between the 64-bit and 32-bit layouts
  if (elf_class == ELF_CLASS_64) {
    elf_entry = load64(24);
    elf_phoff = load64(32);
    elf_flags = load32(48);
    elf_phentsize = load16(54);
    elf_phnum = load16(56);
  } else {
    elf_entry = load32(24);
    elf_phoff = load32(28);
    elf_flags = load32(36);
    elf_phentsize = load16(42);
    elf_phnum = load16(44);
  }
  if (!elf_entry) {
    fprintf(stderr,"%s has no entry point, stopping\n",fn);
    exit(100);
  }
  if (!elf_phoff) {
    fprintf(stderr,"%s has no program header, stopping\n",fn);
    exit(100);
  }
  if (!elf_phnum) {
    fprintf(stderr,"%s has no program-header entries, stopping\n",fn);
    exit(100);
  }
  if (elf_phentsize != ((elf_class == ELF_CLASS_64) ? 56 : 32)) {
    fprintf(stderr,"%s has wrong program-header-entry size %lld, stopping\n",fn,(long long) elf_phentsize);
    exit(100);
  }

  elf_minaddr = 0;
  elf_minaddr -= 1; // UINT64_MAX, so any segment address is smaller

  for (pos = 0;pos < elf_phnum;++pos) {
    uint64_t h = elf_phoff+pos*elf_phentsize;
    uint64_t p_type = load32(h);
    uint64_t p_memsz = (elf_class == ELF_CLASS_64) ? load64(h+40) : load32(h+20);
    uint64_t p_vaddr = (elf_class == ELF_CLASS_64) ? load64(h+16) : load32(h+8);
    if (p_type == 1) { // PT_LOAD
      // p_vaddr+p_memsz must not wrap around 2^64, or elf_maxaddr is wrong
      if (p_memsz > UINT64_MAX-p_vaddr) {
        fprintf(stderr,"%s has a segment past the end of the address space, stopping\n",fn);
        exit(100);
      }
      loadable = 1;
      if (p_vaddr < elf_minaddr) elf_minaddr = p_vaddr;
      if (p_vaddr+p_memsz > elf_maxaddr) elf_maxaddr = p_vaddr+p_memsz;
    }
  }

  if (!loadable) {
    fprintf(stderr,"%s has no segments to load, stopping\n",fn);
    exit(100);
  }
  // rounding elf_maxaddr up to a page boundary below must not wrap to 0
  if (elf_maxaddr > UINT64_MAX-(PAGE-1)) {
    fprintf(stderr,"%s has a segment past the end of the address space, stopping\n",fn);
    exit(100);
  }

  elf_minaddr = elf_minaddr & ~(PAGE-1);
  elf_maxaddr = (elf_maxaddr+(PAGE-1)) & ~(PAGE-1);

  if (print_elf) {
    fprintf(stderr,"minaddr 0x%llx maxaddr 0x%llx\n",(unsigned long long) elf_minaddr,(unsigned long long) elf_maxaddr);
  }
}

// =====

#include <unicorn/unicorn.h>

static uc_engine *uc;               // unicorn engine created by emulate()
static long long handling_trap;     // why emulation stopped: 0 none, 1 syscall hook, 2 interrupt hook
static uint32_t handling_trap_num;  // interrupt number recorded by hook_interrupt (0 for syscalls)
static uc_hook trace_interrupt;     // handle of the UC_HOOK_INTR hook (non-amd64 guests)
static uc_hook trace_syscall;       // handle of the SYSCALL instruction hook (amd64 guests)

static long long sim_brk;      // guest address just past the loaded image (elf_maxaddr)
static long long sim_sp;       // guest stack pointer; starts 1 MiB above sim_brk
static long long sim_maps;     // next guest address handed out by sim_mmap
static char sim_iobuf[65536];  // bounce buffer for sim_read/sim_write; caps each transfer

/*
 * UC_HOOK_INSN callback for the x86-64 SYSCALL instruction: mark a pending
 * syscall (handling_trap = 1, no interrupt number) and stop the emulator so
 * the loop in emulate() can hand control to trap().
 */
static void hook_syscall(uc_engine *uc,void *user_data)
{
  handling_trap_num = 0;
  handling_trap = 1;
  uc_emu_stop(uc);
}

/*
 * UC_HOOK_INTR callback: remember the interrupt number, mark a pending trap
 * (handling_trap = 2) and stop the emulator so emulate() calls trap().
 */
static void hook_interrupt(uc_engine *uc,int32_t num,void *user_data)
{
  handling_trap_num = num;
  handling_trap = 2;
  uc_emu_stop(uc);
}

/*
 * Emulated read(): fill guest memory at buf with up to len bytes from the
 * host's stdin (at most sizeof sim_iobuf per call) and return the number of
 * bytes actually read. XXX: fd is ignored; every read comes from stdin.
 */
uint64_t sim_read(uint64_t fd,uint64_t buf,uint64_t len)
{
  size_t want = (len > sizeof sim_iobuf) ? sizeof sim_iobuf : len;
  size_t got = fread(sim_iobuf,1,want,stdin);
  uc_mem_write(uc,buf,sim_iobuf,got);
  return got;
}

/*
 * Emulated write(): copy up to len bytes (at most sizeof sim_iobuf) from
 * guest memory at buf to the host's stdout, flush, and report that many
 * bytes as written. XXX: fd is ignored; everything goes to stdout.
 */
uint64_t sim_write(uint64_t fd,uint64_t buf,uint64_t len)
{
  uint64_t count = (len > sizeof sim_iobuf) ? sizeof sim_iobuf : len;
  uc_mem_read(uc,buf,sim_iobuf,count);
  fwrite(sim_iobuf,count,1,stdout);
  fflush(stdout);
  return count;
}

// XXX: extend interface+implementation for flags etc.
/*
 * Minimal anonymous-mmap emulation: map a fresh RWX guest region of at least
 * `size` bytes at the bump pointer sim_maps, advance the pointer past it,
 * and return the region's start address. Requested address, protection,
 * flags and fd are not considered.
 *
 * Exits with 111 for requests above 1 GiB or if unicorn refuses the mapping.
 */
uint64_t sim_mmap(uint64_t size)
{
  uint64_t result;
  if (size > 1073741824) { // XXX: configure
    fprintf(stderr,"huge mmap, aborting\n");
    exit(111);
  }
  // uc_mem_map needs a page-multiple size: round the request up to whole
  // pages (keeping only the low bits would lose most of a large request)
  size = (size+(PAGE-1)) & ~(uint64_t) (PAGE-1);
  if (!size) size = PAGE; // zero-length request still gets one page
  uc_err err = uc_mem_map(uc,sim_maps,size,UC_PROT_ALL);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_mem_map failed: %s\n",uc_strerror(err));
    exit(111);
  }
  result = sim_maps;
  sim_maps += size; // stays page-aligned because size is a page multiple
  return result;
}

/*
 * Service the guest system call that stopped emulation.
 *
 * handling_trap is 1 when hook_syscall fired (x86-64 SYSCALL instruction)
 * and 2 when hook_interrupt fired, with handling_trap_num holding the
 * interrupt number. For each supported machine the syscall number and
 * first arguments are read from that architecture's registers. A small set
 * of calls is emulated via sim_read/sim_write/sim_mmap: exit/exit_group,
 * read, write, and on some machines mmap. The result is written back to
 * the return register; unrecognised syscall numbers return 0.
 *
 * exit/exit_group terminate the host process with the guest's status.
 * An interrupt that is not the machine's syscall trap, or an unsupported
 * machine, exits with 111.
 */
static void trap(void)
{
  if (handling_trap == 1 && elf_machine == ELF_AMD64) {
    // number in rax, arguments in rdi/rsi/rdx, result back into rax
    uint64_t rax,rdi,rsi,rdx,result = 0;
    uc_reg_read(uc,UC_X86_REG_RAX,&rax);
    uc_reg_read(uc,UC_X86_REG_RDI,&rdi);
    uc_reg_read(uc,UC_X86_REG_RSI,&rsi);
    uc_reg_read(uc,UC_X86_REG_RDX,&rdx);
    if (print_interrupts)
      fprintf(stderr,"syscall rax 0x%llx rdi 0x%llx rsi 0x%llx rdx 0x%llx\n",(long long) rax,(long long) rdi,(long long) rsi,(long long) rdx);

    switch(rax) {
      case 0x3c: // exit code
      case 0xe7: // exit_group code
        exit(rdi);
      case 0x00: // read fd, buf, count
        result = sim_read(rdi,rsi,rdx);
        break;
      case 0x01: // write fd, buf, count
        result = sim_write(rdi,rsi,rdx);
        break;
    }
    uc_reg_write(uc,UC_X86_REG_RAX,&result);
  } else if (handling_trap == 2 && elf_machine == ELF_X86) {
    // number in eax, arguments in ebx/ecx/edx, result back into eax
    uint32_t eax,ebx,ecx,edx,result = 0;
    uc_reg_read(uc,UC_X86_REG_EAX,&eax);
    uc_reg_read(uc,UC_X86_REG_EBX,&ebx);
    uc_reg_read(uc,UC_X86_REG_ECX,&ecx);
    uc_reg_read(uc,UC_X86_REG_EDX,&edx);
    if (print_interrupts)
      fprintf(stderr,"interrupt %lld eax 0x%llx ebx 0x%llx ecx 0x%llx edx 0x%llx\n",(long long) handling_trap_num, (long long) eax,(long long) ebx,(long long) ecx,(long long) edx);

    if (handling_trap_num != 128) exit(111); // not a syscall

    switch(eax) {
      case 0x01: // exit code
      case 0xfc: // exit_group code
        exit(ebx);
      case 0x03: // read fd, buf, count
        result = sim_read(ebx,ecx,edx);
        break;
      case 0x04: // write fd, buf, count
        result = sim_write(ebx,ecx,edx);
        break;
      case 0xc0: // mmap2 addr, length, prot, flags, fd, pgoffset
        result = sim_mmap(ecx); // the length is the second argument
        break;
    }
    uc_reg_write(uc,UC_X86_REG_EAX,&result);
  } else if (handling_trap == 2 && elf_machine == ELF_ARM64) {
    // number in x8, arguments in x0/x1/x2, result back into x0
    uint64_t x8,x0,x1,x2,result = 0;
    uc_reg_read(uc,UC_ARM64_REG_X0,&x0);
    uc_reg_read(uc,UC_ARM64_REG_X1,&x1);
    uc_reg_read(uc,UC_ARM64_REG_X2,&x2);
    uc_reg_read(uc,UC_ARM64_REG_X8,&x8);
    if (print_interrupts)
      fprintf(stderr,"interrupt %lld x8 0x%llx x0 0x%llx x1 0x%llx x2 0x%llx\n",(long long) handling_trap_num,(long long) x8,(long long) x0,(long long) x1,(long long) x2);

    if (handling_trap_num != 2) exit(111); // not a syscall

    switch(x8) {
      case 0x5d: // exit code
      case 0x5e: // exit_group code
        exit(x0);
      case 0x3f: // read fd, buf, count
        result = sim_read(x0,x1,x2);
        break;
      case 0x40: // write fd, buf, count
        result = sim_write(x0,x1,x2);
        break;
      case 0xde: // mmap addr, length, ...
        result = sim_mmap(x1);
        break;
    }
    uc_reg_write(uc,UC_ARM64_REG_X0,&result);
  } else if (handling_trap == 2 && elf_machine == ELF_ARM32) {
    // number in r7, arguments in r0/r1/r2, result back into r0
    uint32_t r7,r0,r1,r2,result = 0;
    uc_reg_read(uc,UC_ARM_REG_R0,&r0);
    uc_reg_read(uc,UC_ARM_REG_R1,&r1);
    uc_reg_read(uc,UC_ARM_REG_R2,&r2);
    uc_reg_read(uc,UC_ARM_REG_R7,&r7);

    if (print_interrupts)
      fprintf(stderr,"interrupt %lld r7 0x%llx r0 0x%llx r1 0x%llx r2 0x%llx\n",(long long) handling_trap_num,(long long) r7,(long long) r0,(long long) r1,(long long) r2);

    if (handling_trap_num != 2) exit(111); // not a syscall

    switch(r7) {
      case 0x01: // exit code
      case 0xf8: // exit_group code
        exit(r0);
      case 0x03: // read fd, buf, count
        result = sim_read(r0,r1,r2);
        break;
      case 0x04: // write fd, buf, count
        result = sim_write(r0,r1,r2);
        break;
    }

    uc_reg_write(uc,UC_ARM_REG_R0,&result);
  } else if (handling_trap == 2 && elf_machine == ELF_SPARC32) {
    // number in g1, arguments in o0/o1/o2, result back into o0
    uint32_t g1,o0,o1,o2,reg,result = 0;
    uc_reg_read(uc,UC_SPARC_REG_G1,&g1);
    uc_reg_read(uc,UC_SPARC_REG_O0,&o0);
    uc_reg_read(uc,UC_SPARC_REG_O1,&o1);
    uc_reg_read(uc,UC_SPARC_REG_O2,&o2);

    if (print_interrupts)
      fprintf(stderr,"interrupt %lld g1 0x%llx o0 0x%llx o1 0x%llx o2 0x%llx\n",(long long) handling_trap_num,(long long) g1,(long long) o0,(long long) o1,(long long) o2);

    if (handling_trap_num != 144) exit(111); // not a syscall

    switch (g1) {
      case 0x01: // exit
        exit(o0);
      case 0x03: // read fd, buf, count
        result = sim_read(o0,o1,o2);
        break;
      case 0x04: // write fd, buf, count
        result = sim_write(o0,o1,o2);
        break;
      case 0x47: // mmap addr, length, ...
        result = sim_mmap(o1);
        break;
    }

    uc_reg_read(uc,UC_SPARC_REG_PSR,&reg);
    reg &= ~0x00100000; // carry bit of PSR would indicate error
    uc_reg_write(uc,UC_SPARC_REG_PSR,&reg);
    uc_reg_read(uc,UC_SPARC_REG_PC,&reg);
    reg += 4; // move past the trap instruction
    uc_reg_write(uc,UC_SPARC_REG_PC,&reg);

    uc_reg_write(uc,UC_SPARC_REG_O0,&result);
  } else {
    if (print_interrupts) {
      if (handling_trap == 1)
        fprintf(stderr,"syscall, unhandled machine type\n");
      else
        fprintf(stderr,"interrupt %lld, unhandled machine type\n",(long long) handling_trap_num);
    }
    exit(111);
  }
}

static void emulate(void)
{
  uc_err err;
  uint64_t startinsn;
  uint32_t startinsn32;
  long long pos;

  switch(elf_machine) {
    case ELF_SPARC32:
      err = uc_open(UC_ARCH_SPARC,UC_MODE_SPARC32|UC_MODE_BIG_ENDIAN,&uc); break;
    case ELF_X86:
      err = uc_open(UC_ARCH_X86,UC_MODE_32,&uc); break;
    case ELF_AMD64:
      err = uc_open(UC_ARCH_X86,UC_MODE_64,&uc); break;
    case ELF_ARM32:
      err = uc_open(UC_ARCH_ARM,UC_MODE_ARM,&uc); break;
    case ELF_ARM64:
      err = uc_open(UC_ARCH_ARM64,UC_MODE_ARM,&uc); break;
    case ELF_PPC64:
      // XXX: as of 2025-02, unicorn does not support PPC64
      // err = uc_open(UC_ARCH_PPC,UC_MODE_PPC64,&uc); break;
    default:
      fprintf(stderr,"unsupported machine type %lld\n",(long long) elf_machine);
      exit(100);
  }
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_open failed: %s\n",uc_strerror(err));
    exit(111);
  }

  sim_brk = elf_maxaddr;
  sim_sp = sim_brk+1048576; // XXX: allow configuration
  sim_maps = sim_sp+1048576; // XXX: allow configuration

  if (elf_minaddr >= PAGE) {
    err = uc_mem_map(uc,0,PAGE,UC_PROT_ALL);
    if (err != UC_ERR_OK) {
      fprintf(stderr,"uc_mem_map failed: %s\n",uc_strerror(err));
      exit(111);
    }
  }

  err = uc_mem_map(uc,elf_minaddr,sim_maps-elf_minaddr,UC_PROT_ALL);
  if (err != UC_ERR_OK) {
    fprintf(stderr,"uc_mem_map failed: %s\n",uc_strerror(err));
    exit(111);
  }

  for (pos = 0;pos < elf_phnum;++pos) {
    uint64_t h = elf_phoff+pos*elf_phentsize;
    uint64_t p_type = load32(h);
    uint64_t p_offset = (elf_class == ELF_CLASS_64) ? load64(h+8) : load32(h+4);
    uint64_t p_vaddr = (elf_class == ELF_CLASS_64) ? load64(h+16) : load32(h+8);
    uint64_t p_filesz = (elf_class == ELF_CLASS_64) ? load64(h+32) : load32(h+16);
    if (p_type == 1) { // PT_LOAD
      if (p_offset > elf_len || p_filesz > elf_len-p_offset) {
        fprintf(stderr,"segment outside binary, stopping\n");
        exit(111);
      }
      err = uc_mem_write(uc,p_vaddr,elf+p_offset,p_filesz);
      if (err != UC_ERR_OK) {
        fprintf(stderr,"uc_mem_write failed: %s\n",uc_strerror(err));
        exit(111);
      }
    }
  }

  // XXX: could pass args along here
  if (elf_class == ELF_CLASS_64) {
    pos = 0;
    uc_mem_write(uc,sim_sp-8,&pos,8);
    uc_mem_write(uc,sim_sp-16,&pos,8);
    uc_mem_write(uc,sim_sp-32,&pos,8);
    pos = sim_sp-8;
    uc_mem_write(uc,sim_sp-24,&pos,8);
    uc_mem_write(uc,sim_sp-40,&pos,8);
    pos = 1;
    uc_mem_write(uc,sim_sp-48,&pos,8);
    sim_sp -= 48;
  } else {
    pos = 0;
    uc_mem_write(uc,sim_sp-4,&pos,4);
    uc_mem_write(uc,sim_sp-8,&pos,4);
    uc_mem_write(uc,sim_sp-16,&pos,4);
    pos = sim_sp-4;
    uc_mem_write(uc,sim_sp-12,&pos,4);
    uc_mem_write(uc,sim_sp-20,&pos,4);
    pos = 1;
    uc_mem_write(uc,sim_sp-24,&pos,4);
    sim_sp -= 24;
  }

  uint32_t sp32 = sim_sp;
  uint32_t zero = 0;

  switch(elf_machine) {
    case ELF_X86:
      uc_reg_write(uc,UC_X86_REG_ESP,&sim_sp); break;
    case ELF_AMD64:
      uc_reg_write(uc,UC_X86_REG_RSP,&sim_sp); break;
    case ELF_ARM32:
      uc_reg_write(uc,UC_ARM_REG_SP,&sp32); break;
    case ELF_ARM64:
      uc_reg_write(uc,UC_ARM64_REG_SP,&sim_sp); break;
    case ELF_SPARC32:
      uc_reg_write(uc,UC_SPARC_REG_SP,&sp32); break;
      uc_reg_write(uc,UC_SPARC_REG_L7,&zero); break;
  }

  if (elf_machine == ELF_AMD64) {
    uc_hook_add(uc,&trace_syscall,UC_HOOK_INSN,(void *)hook_syscall,0,1,0,UC_X86_INS_SYSCALL);
  } else {
    uc_hook_add(uc,&trace_interrupt,UC_HOOK_INTR,(void *)hook_interrupt,0,1,0);
  }

  startinsn = elf_entry;
  
  for (;;) {
    handling_trap = 0;
    handling_trap_num = 0;
    err = uc_emu_start(uc,startinsn,elf_maxaddr,0,0);
    if (err != UC_ERR_OK) {
      fprintf(stderr,"uc_emu_start failed: %s\n",uc_strerror(err));
      exit(111);
    }
    if (!handling_trap) break;
    trap();
    switch(elf_machine) {
      case ELF_X86:
        uc_reg_read(uc,UC_X86_REG_EIP,&startinsn32);
        startinsn = startinsn32; break;
      case ELF_AMD64:
        uc_reg_read(uc,UC_X86_REG_RIP,&startinsn); break;
      case ELF_ARM32:
        uc_reg_read(uc,UC_ARM_REG_PC,&startinsn32);
        startinsn = startinsn32; break;
      case ELF_ARM64:
        uc_reg_read(uc,UC_ARM64_REG_PC,&startinsn); break;
      case ELF_SPARC32:
        uc_reg_read(uc,UC_SPARC_REG_PC,&startinsn32);
        startinsn = startinsn32; break;
    }
  }
}

int main(int argc,char **argv)
{
  if (!argv[1] || !argv[2]) {
    fprintf(stderr,"need path to binary and number of bytes to read\n");
    exit(100);
  }
  elf_load(argv[1],argv[2]);
  emulate();
  return 0;
}