/**
 ** sem.c
 **
 ** Copyright 1995 by Kurt Konolige
 **
 ** The author hereby grants to SRI permission to use this software.
 ** The author also grants to SRI permission to distribute this software
 ** to schools for non-commercial educational use only.
 **
 ** The author hereby grants to other individuals or organizations
 ** permission to use this software for non-commercial
 ** educational use only.  This software may not be distributed to others
 ** except by SRI, under the conditions above.
 **
 ** Other than these cases, no part of this software may be used or
 ** distributed without written permission of the author.
 **
 ** Neither the author nor SRI make any representations about the 
 ** suitability of this software for any purpose.  It is provided 
 ** "as is" without express or implied warranty.
 **
 ** Kurt Konolige
 ** Senior Computer Scientist
 ** SRI International
 ** 333 Ravenswood Avenue
 ** Menlo Park, CA 94025
 ** E-mail:  konolige@ai.sri.com
 **
 **/


/**************************************************************************
 *   These functions construct a semantic structure
 *    representing the syntactic constructs of C.
 *   
 *   Reductions on the semantic structure yield one
 *    type of optimization.
 *
 *   The semantic structure can be coded into either PCODES
 *    or native operations.
 **************************************************************************/


#include "core.h"

/* Nonzero turns on printf tracing in the semantic routines of this file. */
int sem_debug = 0;

/*
 * Map an operator code to a printable name for diagnostics and tracing.
 * Operators without an entry (including any added later) map to "NONE".
 * The returned string is a literal and must not be modified or freed.
 */
char *
name_of_eop(eop op)
{
  char *name;

  switch (op)
    {
    case plus:           name = "+";              break;
    case minus:          name = "-";              break;
    case times:          name = "*";              break;
    case divide:         name = "/";              break;
    case modulo:         name = "%";              break;
    case neg:            name = "-";              break;
    case equal:          name = "=";              break;
    case plusequal:      name = "+=";             break;
    case minusequal:     name = "-=";             break;
    case timesequal:     name = "*=";             break;
    case divideequal:    name = "/=";             break;
    case bitandequal:    name = "&=";             break;
    case bitorequal:     name = "|=";             break;
    case bitxorequal:    name = "^=";             break;
    case lshiftequal:    name = "<<=";            break;
    case rshiftequal:    name = ">>=";            break;
    case eequal:         name = "==";             break;
    case nequal:         name = "!=";             break;
    case gtequal:        name = ">=";             break;
    case ltequal:        name = "<=";             break;
    case lessthan:       name = "<";              break;
    case greaterthan:    name = ">";              break;
    case logand:         name = "&&";             break;
    case logior:         name = "||";             break;
    case lognot:         name = "!";              break;
    case bitior:         name = "|";              break;
    case bitxor:         name = "^";              break;
    case bitand:         name = "&";              break;
    case bitcomp:        name = "~";              break;
    case lshift:         name = "<<";             break;
    case rshift:         name = ">>";             break;
    case float_to_int:   name = "float -> int";   break;
    case float_to_long:  name = "float -> long";  break;
    case int_to_float:   name = "int -> float";   break;
    case int_to_long:    name = "int -> long";    break;
    case int_to_pointer: name = "int -> pointer"; break;
    case long_to_int:    name = "long -> int";    break;
    case long_to_float:  name = "long -> float";  break;
    case pointer_to_int: name = "pointer -> int"; break;
    case ssin:           name = "sin";            break;
    case scos:           name = "cos";            break;
    case stan:           name = "tan";            break;
    case ssqrt:          name = "sqrt";           break;
    case satan:          name = "atan";           break;
    case slog10:         name = "log10";          break;
    case sloge:          name = "loge";           break;
    case sexp10:         name = "exp10";          break;
    case sexpe:          name = "exp";            break;
    case prntf:          name = "printf";         break;
    default:             name = "NONE";           break;
    }
  return name;
}


/* stack of s-expressions */

/* Fixed-capacity (20-entry) stack of s-expressions, manipulated by the
   push/pop/top/switch routines that follow. */
sexp * sem_stack[20];		/* holds pushed s-expressions */
int sem_stack_n = 0;		/* number of entries; also index of the next free slot */

/* Report whether the semantic stack holds no entries (1 if empty, else 0). */
Int
sexp_stack_empty()
{
  return !sem_stack_n;
}

/*
 * Push EXP onto the semantic stack.
 *
 * The stack has a fixed capacity (the declared size of sem_stack).
 * Overflowing it is treated as a fatal internal error instead of silently
 * writing past the end of the array.  If die() ever returns, the push is
 * dropped rather than corrupting memory.
 */
void
push_sexp(sexp *exp)
{
  if (sem_stack_n >= (int)(sizeof sem_stack / sizeof sem_stack[0]))
    {
      die(("Semantic stack overflow in semantic routines\n"));
      return;
    }
  sem_stack[sem_stack_n++] = exp;
}

/*
 * Remove and return the top entry of the semantic stack.  Popping an empty
 * stack indicates unbalanced semantic actions and is reported via die().
 */
sexp *
pop_sexp()
{
  if (!sem_stack_n)
    die(("Too many pop's in semantic routines\n"));
  sem_stack_n--;
  return sem_stack[sem_stack_n];
}

/* Exchange the two topmost stack entries; a no-op with fewer than two. */
void
sexp_stack_switch()
{
  int top = sem_stack_n - 1;
  sexp *tmp;

  if (sem_stack_n <= 1)
    return;

  tmp = sem_stack[top];
  sem_stack[top] = sem_stack[top - 1];
  sem_stack[top - 1] = tmp;
}

/*
 * Peek at the stack without popping.  N = 0 is the top entry, N = 1 the
 * one beneath it, and so on.  Returns NULL when N does not name a live
 * entry: either N is at or beyond the depth, or N is negative (which
 * previously indexed at or above the stack top).
 */
sexp *
top_sexp(Int n)
{
  if (n < 0 || n >= sem_stack_n)
    return NULL;
  return sem_stack[sem_stack_n - n - 1];
}


/* argument coercion */

/*
 * Wrap expression A in a unary conversion node, in place.
 *
 * A copy of A's current contents becomes the operand, and A itself is
 * rewritten as a `unop` with operator OP and result type TYPE.  Because
 * the rewrite happens in place, every existing pointer to A now sees the
 * converted expression.  The wrapper's arg2 is cleared: a unop has only
 * one operand, and leaving it set would alias the copied node's arg2.
 */
void
coerce_replace(sexp *a, eop op, Type *type)
{
  sexp *inner = (sexp *)malloc(sizeof(sexp));

  if (inner == NULL)
    {
      die(("Out of memory in semantic routines\n"));
      return;
    }
  *inner = *a;			/* operand: the unconverted expression */
  a->exp = unop;
  a->op = op;
  a->type = type;
  a->arg1 = inner;
  a->arg2 = NULL;
}


/*
 * Helpers for coerce_type().  Each converts expression A to the named
 * target type; A is neither a `ref` nor pointer-typed (coerce_type handles
 * those first).  Constants are converted in place by rewriting their value
 * and type.  Other expressions are wrapped in a conversion node by
 * coerce_replace().  Each helper returns 1 on success and 0 when A's type
 * has no conversion to the target.
 */

/* Coerce A to int; also used for char targets, which are produced as int. */
static int
sem_coerce_to_int(sexp *a)
{
  switch (type_id(a->type))
    {
    case char_id:
      if (a->exp == cnst)	/* constants are integers */
	a->type = type_Int();
      return 1;

    case int_id:
      return 1;

    case long_id:
      if (a->exp == cnst)	/* do it inline */
	{
	  a->type = type_Int();
	  a->val.i = (int)a->val.l;
	}
      else
	coerce_replace(a, long_to_int, type_Int());
      return 1;

    case float_id:
      if (a->exp == cnst)	/* do it inline */
	{
	  a->type = type_Int();
	  a->val.i = (int)a->val.d;
	}
      else
	coerce_replace(a, float_to_int, type_Int());
      return 1;

    default:
      return 0;
    }
}

/* Coerce A to long. */
static int
sem_coerce_to_long(sexp *a)
{
  switch (type_id(a->type))
    {
    case long_id:
      return 1;

    case char_id:
    case int_id:
      if (a->exp == cnst)	/* do it inline */
	{
	  a->type = type_Long();
	  a->val.l = (long)a->val.i;
	}
      else
	coerce_replace(a, int_to_long, type_Long());
      return 1;

    case float_id:
      if (a->exp == cnst)	/* do it inline */
	{
	  a->type = type_Long();
	  a->val.l = (long)a->val.d;
	}
      else
	coerce_replace(a, float_to_long, type_Long());
      return 1;

    default:
      return 0;
    }
}

/* Coerce A to float. */
static int
sem_coerce_to_float(sexp *a)
{
  switch (type_id(a->type))
    {
    case float_id:
      return 1;

    case char_id:
    case int_id:
      if (a->exp == cnst)	/* do it inline */
	{
	  a->type = type_Float();
	  a->val.d = (double)a->val.i;
	}
      else
	coerce_replace(a, int_to_float, type_Float());
      return 1;

    case long_id:
      if (a->exp == cnst)	/* do it inline */
	{
	  a->type = type_Float();
	  a->val.d = (double)a->val.l;
	}
      else
	coerce_replace(a, long_to_float, type_Float());
      return 1;

    default:
      return 0;
    }
}

/* Accept char/int where a pointer is expected; A is left unchanged. */
static int
sem_coerce_to_pointer(sexp *a)
{
  switch (type_id(a->type))
    {
    case char_id:
    case int_id:
      return 1;			/* say it's ok, Joe */
    default:
      return 0;
    }
}

/*
 * Coerce expression A to TYPE, possibly rewriting A in place.
 * Returns 1 if A is (now) acceptable as TYPE, 0 if the coercion is illegal.
 *
 * - When pass == 0, an undefined source or target type is tolerated.
 * - Refs and pointer-typed expressions are treated like ints.  They are
 *   accepted only for int or pointer targets and are never rewritten.  The
 *   pointed-to type is not checked.
 * - Otherwise dispatch on the target type; char targets share the int path.
 */
int
coerce_type(sexp *a, Type *type)
{
  Type *t = a->type;

  if (pass == 0 && (type_id(type) == undef_id || type_id(t) == undef_id))
    return 1;

  if (a->exp == ref || type_id(t) == pointer_id)
    return type_id(type) == int_id || type_id(type) == pointer_id;

  switch (type_id(type))
    {
    case char_id:
    case int_id:
      return sem_coerce_to_int(a);
    case long_id:
      return sem_coerce_to_long(a);
    case float_id:
      return sem_coerce_to_float(a);
    case pointer_id:
      return sem_coerce_to_pointer(a);
    default:
      return 0;
    }
}


/*
 * Bring the two operands of a binary operator to a common type by promoting
 * the "smaller" one (char < int < long < float), rewriting it in place via
 * coerce_type().  Returns 1 on success, 0 if the types are incompatible.
 *
 * When pass == 0, an undefined operand type is tolerated.  Previously a long
 * left operand with a float right operand was rejected, although
 * coerce_type() supports long -> float.  Such a pair now promotes the long
 * to float, mirroring the int-with-float case.
 */
int
coerce_args(sexp *a1, sexp *a2)
{
  Type *t1 = a1->type;
  Type *t2 = a2->type;

  if (type_id(t1) == type_id(t2))
    return 1;
  if (pass == 0 && (type_id(t1) == undef_id || type_id(t2) == undef_id))
    return 1;

  switch (type_id(t1))
    {
    case char_id:
    case int_id:
      if (type_id(t2) == char_id)
	return coerce_type(a2, t1);	/* widen the char side */
      return coerce_type(a1, t2);	/* widen a1 to a2's type */

    case long_id:
      if (type_id(t2) == int_id || type_id(t2) == char_id)
	return coerce_type(a2, t1);
      if (type_id(t2) == float_id)
	return coerce_type(a1, t2);	/* long op float -> float */
      return 0;

    case float_id:
      return coerce_type(a2, t1);

    default:
      return 0;
    }
}


/* semantic reductions */

/*
 * Constant folding for binary operators.
 *
 * If A1 and A2 are both constants and OP is one of + - * /, compute the
 * result at compile time.  The result type is A1's type (int, long or
 * float), since the caller, bop_sexp, runs coerce_args first.  The result
 * is pushed as a new constant sexp and the function returns 1.  Otherwise
 * nothing is pushed and it returns 0, leaving the caller to build a
 * run-time node.  The operands are never freed here.
 *
 * Integer division by 0 (undefined in C) and by -1 (overflows for the most
 * negative value) is not folded, nor is float division by 0.  These are
 * left to run time.  No node is allocated until folding is certain, so the
 * "can't fold" paths no longer leak.
 */
int
bop_reduce(sexp *a1, sexp *a2, eop op)
{
  eval v;			/* folded result; only the member for a1's type is set */
  sexp *a;

  if (a1->exp != cnst || a2->exp != cnst)
    return 0;

  switch (op)
    {
    case plus:
      switch (type_id(a1->type))
	{
	case int_id:   v.i = a1->val.i + a2->val.i; break;
	case long_id:  v.l = a1->val.l + a2->val.l; break;
	case float_id: v.d = a1->val.d + a2->val.d; break;
	default:       return 0;
	}
      break;

    case minus:
      switch (type_id(a1->type))
	{
	case int_id:   v.i = a1->val.i - a2->val.i; break;
	case long_id:  v.l = a1->val.l - a2->val.l; break;
	case float_id: v.d = a1->val.d - a2->val.d; break;
	default:       return 0;
	}
      break;

    case times:
      switch (type_id(a1->type))
	{
	case int_id:   v.i = a1->val.i * a2->val.i; break;
	case long_id:  v.l = a1->val.l * a2->val.l; break;
	case float_id: v.d = a1->val.d * a2->val.d; break;
	default:       return 0;
	}
      break;

    case divide:
      switch (type_id(a1->type))
	{
	case int_id:
	  if (a2->val.i == 0 || a2->val.i == -1)
	    return 0;
	  v.i = a1->val.i / a2->val.i;
	  break;
	case long_id:
	  if (a2->val.l == 0 || a2->val.l == -1)
	    return 0;
	  v.l = a1->val.l / a2->val.l;
	  break;
	case float_id:
	  if (a2->val.d == 0.0)
	    return 0;
	  v.d = a1->val.d / a2->val.d;
	  break;
	default:
	  return 0;
	}
      break;

    default:
      return 0;
    }

  a = (sexp *)malloc(sizeof(sexp));
  if (a == NULL)
    {
      die(("Out of memory in semantic routines\n"));
      return 0;
    }
  a->exp = cnst;
  a->type = a1->type;
  a->val = v;
  a->op = minus;		/* folded constant: no separate storage needed */
  push_sexp(a);
  return 1;
}


/*
 * Constant folding for unary operators.  Only arithmetic negation of an
 * int, long or float constant is folded.  The result is pushed as a new
 * constant sexp and the function returns 1.  Otherwise nothing is pushed
 * and it returns 0.  A1 is never freed here.  No node is allocated until
 * folding is certain, so the "can't fold" paths no longer leak.
 */
int
unop_reduce(sexp *a1, eop op)
{
  eval v;			/* folded result; only the member for a1's type is set */
  sexp *a;

  if (a1->exp != cnst || op != neg)
    return 0;

  switch (type_id(a1->type))
    {
    case int_id:
      v.i = -a1->val.i;
      break;
    case long_id:
      v.l = -a1->val.l;
      break;
    case float_id:
      v.d = -a1->val.d;
      break;
    default:
      return 0;
    }

  a = (sexp *)malloc(sizeof(sexp));
  if (a == NULL)
    {
      die(("Out of memory in semantic routines\n"));
      return 0;
    }
  a->exp = cnst;
  a->type = a1->type;
  a->val = v;
  a->op = minus;		/* just in case... (no separate storage needed) */
  push_sexp(a);
  return 1;
}


/*
 * Peephole rewrite of an assignment node A.  `x -= c`, with an int
 * target and a constant right-hand side, becomes `x += -c`.  The constant
 * is negated in place.  All other assignments are left untouched.
 */
void
assign_reduce(sexp *a)
{
  sexp *rhs = a->arg2;

  if (rhs->exp != cnst || type_id(a->type) != int_id || a->op != minusequal)
    return;

  if (sem_debug)
    printf("Reducing -= to +=\n");
  a->op = plusequal;
  rhs->val.i = -rhs->val.i;
}



/* create semantic expressions */

/*
 * Allocate an operator node of kind ET with operator OP, result TYPE and
 * operands ARG1/ARG2.  Either operand may be NULL; unary and nilary nodes
 * are built this way too.  Fields not listed are left unset.
 */
sexp *
create_binary_sexp(etype et, eop op, Type *type, sexp *arg1, sexp *arg2)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->arg1 = arg1;
  node->arg2 = arg2;
  node->exp = et;
  node->op = op;
  node->type = type;
  return node;
}

/*
 * Build a binary operation from the top two stack entries (right operand
 * on top) and push the result.
 *
 * The operands are first promoted to a common type.  Incompatible types
 * are reported via user_error, naming both types with type_name like the
 * other diagnostics here.  The old message passed Type pointers to %d,
 * which is undefined behaviour.  If both operands are constants the
 * operation is folded and the operand nodes are freed.  Otherwise a `bop`
 * node is pushed whose type is the (promoted) left operand's type, with
 * char widened to int.
 */
void
bop_sexp(eop op)
{
  sexp *a2 = pop_sexp();
  sexp *a1 = pop_sexp();

  if (sem_debug)
    printf("Binary op %d\n", (int)op);
  if (!coerce_args(a1, a2))
    user_error(("Incompatible args in %s %s %s\n",
		type_name(a1->type), name_of_eop(op), type_name(a2->type)));
  if (bop_reduce(a1, a2, op))	/* we might reduce it to a constant */
    {
      free(a1);
      free(a2);
    }
  else
    push_sexp(create_binary_sexp(bop, op,
				 (type_id(a1->type) == char_id) ? type_Int() : a1->type,
				 a1, a2));
}


/* Unary arithmetic or logical operations */


/*
 * Build a unary arithmetic/logical operation on the top stack entry and
 * push it.  A foldable constant operand is replaced by the folded constant
 * (pushed by unop_reduce) and freed.  Otherwise a `unop` node is pushed
 * whose type is the operand's, with char widened to int.
 */
void
unop_sexp(eop op)
{
  sexp *operand = pop_sexp();
  Type *rtype;

  if (sem_debug)
    printf("Unary op %d\n", op);

  if (unop_reduce(operand, op))
    {
      free(operand);
      return;
    }

  rtype = (type_id(operand->type) == char_id) ? type_Int() : operand->type;
  push_sexp(create_binary_sexp(unop, op, rtype, operand, NULL));
}

/* Nilary operations */


/* Push an operand-less (`nilop`) node for operator OP with result TYPE. */
void
nilop_sexp(eop op, Type *type)
{
  sexp *leaf;

  if (sem_debug)
    printf("Nilary op %d\n", op);
  leaf = create_binary_sexp(nilop, op, type, NULL, NULL);
  push_sexp(leaf);
}



/* Casting operation */

/*
 * Apply a cast to TYPE to the top stack entry, rewriting it in place via
 * coerce_type().  An illegal cast is reported via user_error, and the
 * entry stays on the stack either way.
 */
void
cast_sexp(Type *type)
{
  sexp *operand = top_sexp(0);

  if (sem_debug)
    printf("Cast op %d\n", type_id(type));
  if (coerce_type(operand, type))
    return;
  user_error(("Attempt to coerce type %s to type %s\n",
	      type_name(operand->type), type_name(type)));
}



/* Constant creation */

/*
 * Allocate a constant node of TYPE holding VAL.  op is set to `plus`,
 * which per the original note means "not stored yet", used for strings.
 */
sexp *
create_const_sexp(Type *type, eval val)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = cnst;
  node->type = type;
  node->val = val;
  node->op = plus;
  return node;
}

/*
 * Create a constant node of TYPE holding VAL and push it.
 *
 * The debug trace previously used "%l", which is not a valid conversion,
 * and "%d" for a double; both are undefined behaviour in printf.  It now
 * uses %ld and %g with matching casts.  A NULL pointer constant is printed
 * as "(null)" instead of being passed to %s.  The unused local was removed.
 */
void
const_sexp(Type *type, eval val)
{
  if (sem_debug)
    switch (type_id(type))
      {
      case int_id:
	printf("Pushing constant %d\n", (int)val.i);
	break;
      case long_id:
	printf("Pushing constant %ld\n", (long)val.l);
	break;
      case float_id:
	printf("Pushing constant %g\n", (double)val.d);
	break;
      case pointer_id:
	printf("Pushing constant %s\n",
	       val.p ? (char *)val.p : "(null)");
	break;
      default:
	break;
      }
  push_sexp(create_const_sexp(type, val));
}


/* Symbol/Pointer creation */

/*
 * Allocate a `ref` node for SYM.  Its type comes from the symbol table
 * entry, or is Undef if the symbol has no entry yet.
 */
sexp *
create_ref_sexp(Symbol *sym)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = ref;
  node->val.p = sym;
  node->type = symtab_get(sym) ? value_type(symtab_get(sym)) : type_Undef();
  return node;
}

/* Push a reference to symbol SYM (see create_ref_sexp). */
void
ref_sexp(Symbol *sym)
{
  if (sem_debug)
    printf("Pushing symbol ref %s\n", symbol_name(sym));
  push_sexp(create_ref_sexp(sym));
}


/* Array references */


/*
 * Allocate an array-reference node for base L indexed by A.  The node's
 * type is the element type of the array that L's pointer type points to.
 */
sexp *
create_aref_sexp(sexp *l, sexp *a)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));
  Type *array_type = type_Pointer_deref(l->type);

  node->exp = aref;
  node->arg1 = l;
  node->arg2 = a;
  node->type = type_Array_elemtype(array_type);
  return node;
}


/* Replace the top two entries (base beneath, index on top) with an aref node. */
void
array_sexp()
{
  sexp *index = pop_sexp();
  sexp *base = pop_sexp();

  if (sem_debug)
    printf("Pushing aref \n");
  push_sexp(create_aref_sexp(base, index));
}



/* Take value of pointer/reference */

/* Allocate a dereference (`val`) node of type T whose operand is A. */
sexp *
create_val_sexp(Type *t, sexp *a)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->type = t;
  node->arg1 = a;
  node->exp = val;
  return node;
}



void				/* Add a dereference */
val_sexp()
{
  sexp *a1 = top_sexp(0);
  Symbol *sym = (Symbol *)a1->val.p;
  Type *t = a1->type;

  if (sem_debug)
    printf("Deref op \n");
  if (a1->exp == ref)		/* variable reference */
    {
      if (symtab_get(sym))
	{
	  Value *value = symtab_get(sym);

	  if (value->loctype == constant) /* have a constant variable */
	    {
	      if (sem_debug)
		printf("Reducing constant symbol\n");
	      a1->exp = cnst;
	      a1->type = value_type(value);
	      a1->val = ((sexp *)value->val)->val;
	      a1->op = minus;	/* say we don't need to store it, just in case... */
	    }
	  else
	    {
	      a1->exp = val;
	      a1->type = value_type(value);
	      a1->arg1 = NULL;
	    }
	}
      else			/* an undefined variable, must be an int... */
	{
	  a1->exp = val;
	  a1->type = pass == 0 ? type_Undef() : type_Undef();
	  a1->arg1 = NULL;
	}
    }
  else if (a1->exp == aref)
    {
	  a1->exp = aval;
    }
  else
    {
      pop_sexp();
				/* *** WE ALLOW INTs TO BE DEREFERENCED...  */
      if (type_id(t) == int_id)
	push_sexp(create_val_sexp(type_Int(), a1));
      else if (type_id(t) == pointer_id)
	push_sexp(create_val_sexp(type_Pointer_deref(t), a1));
      else
	{
	  user_error(("Attempt to dereference a non-pointer type %s\n",
		      type_name(t)));
	  push_sexp(create_val_sexp(type_Undef(), a1));
	}
    }
}

/* PEEK operation */

/*
 * Allocate a PEEK node reading through address expression A.  The result
 * type is int, and SIZE is stored in the node's op field.
 */
sexp *
create_peek_sexp(sexp *a, Int size)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = peek;
  node->op = size;
  node->arg1 = a;
  node->type = type_Int();
  return node;
}


/* Replace the top entry (an address expression) with a PEEK of SIZE. */
void
peek_sexp(Int size)
{
  sexp *addr = pop_sexp();

  if (sem_debug)
    printf("Peek op \n");
  push_sexp(create_peek_sexp(addr, size));
}


/* Assignments */

/*
 * Build an assignment `lvalue OP expr` and push it as an expression.
 *
 * With EXPFLAG set, the stack holds the lvalue beneath the expression.
 * Without it, only the lvalue is on the stack and the expression is the
 * int constant 1; asspost_sexp uses this for x++ / x--.  The expression is
 * coerced to the lvalue's type, and an illegal coercion is reported via
 * user_error.  The new node is then passed through assign_reduce.
 */
void
assign_sexp(eop op, Int expflag)
{
  sexp *lhs, *rhs;

  if (!expflag)
    {
      lhs = pop_sexp();
      rhs = create_const_sexp(type_Int(), (eval)(Int)1);
    }
  else
    {
      rhs = pop_sexp();
      lhs = pop_sexp();
    }

  if (sem_debug)
    printf("Assign op \n");
  if (!coerce_type(rhs, lhs->type))
    user_error(("Attempt to coerce type %s to type %s\n",
		type_name(rhs->type), type_name(lhs->type)));
  push_sexp(create_binary_sexp(assign, op, lhs->type, lhs, rhs));
  assign_reduce(top_sexp(0));
}

/*
 * Post-increment / post-decrement (x++, x--).
 *
 * Builds the assignment `x OP 1` (OP is minusequal for x--, otherwise the
 * step is taken as an increment).  It then applies the opposite step to
 * the assignment's result, `(x += 1) - 1` or `(x -= 1) + 1`, so that the
 * expression yields the value x had before the update.
 * (The unused local from the original was removed.)
 */
void
asspost_sexp(eop op)
{
  eop undo;

  assign_sexp(op, 0);
  undo = (op == minusequal) ? plus : minus;
  const_sexp(type_Int(), (eval)(Int)1);
  bop_sexp(undo);
}

/*
 * Record the initializer of constant symbol SYM.  The top stack entry is
 * popped, coerced to the symbol's type, and, if it is a compile-time
 * constant, stored as the symbol's value.  A failed coercion or a
 * non-constant initializer is reported via user_error, and nothing is
 * stored in that case.
 */
void
const_assign_sexp(Symbol *sym)
{
  Value *value = symtab_get(sym);
  sexp *init = pop_sexp();

  if (!coerce_type(init, value_type(value)))
    {
      user_error(("Attempt to coerce type %s to type %s\n",
		  type_name(init->type), type_name(value_type(value))));
      return;
    }
  if (init->exp != cnst)
    {
      user_error(("No constant in constant declaration of %s\n",
		  symbol_name(sym)));
      return;
    }
  value->val = init;
}


/* Jump expressions */

/*
 * Push a jump node: jump to label L1, then leave label L2.  The labels are
 * stored in arg1/arg2; other fields are left unset.
 */
void
jump_sexp(Symbol *l1, Symbol *l2)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->arg1 = (void *)l1;
  node->arg2 = (void *)l2;
  node->exp = jump;
  push_sexp(node);
}


/* Function calls */

/*
 * Pop entries down to and including the nearest call marker.  The popped
 * arguments are returned as a list linked through `next`, in push order
 * (first-pushed argument first).  Returns NULL when there are no arguments.
 */
sexp *
gather_args()
{
  sexp *list = NULL;
  sexp *item;

  for (item = pop_sexp(); item->exp != callmark; item = pop_sexp())
    {
      item->next = list;
      list = item;
    }
  return list;
}


/* Push a call marker that delimits the start of an argument list. */
void
call_arg_sexp()
{
  sexp *marker = (sexp *)malloc(sizeof(sexp));

  marker->exp = callmark;
  push_sexp(marker);
}

/*
 * Push a call of function SYM.  The arguments are gathered down to the
 * last call marker.  The node's type is the function's return type, or
 * Undef if SYM is not yet in the symbol table.
 */
void
call_sexp(Symbol *sym)
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = call;
  node->val.p = sym;
  node->op = plus;
  node->type = symtab_get(sym)
    ? type_Func_return_type(value_type(symtab_get(sym)))
    : type_Undef();
  node->arg1 = gather_args();
  push_sexp(node);
}

/* Printf */


/*
 * Push a printf call whose arguments are everything pushed since the last
 * call marker (see gather_args).
 * NOTE(review): unlike start_process_sexp, this never sets the node's
 * type, so it holds whatever malloc returned.  Confirm nothing downstream
 * reads `type` for prntf calls.
 */
void
printf_sexp()
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = call;
  node->op = prntf;
  node->arg1 = gather_args();
  push_sexp(node);
}

/* Start_process */

/* Push a start_process call (int-typed) over the arguments since the last call marker. */
void
start_process_sexp()
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = call;
  node->op = start_process;
  node->arg1 = gather_args();
  node->type = type_Int();
  push_sexp(node);
}

/* Replace the top entry with an int-typed kill_process unary node over it. */
void
kill_process_sexp()
{
  sexp *node = (sexp *)malloc(sizeof(sexp));

  node->exp = unop;
  node->op = kill_process;
  node->arg1 = pop_sexp();
  node->type = type_Int();
  push_sexp(node);
}


/* Procedure reporting */

/* report_flag: not read by any routine visible in this file; TODO confirm its users. */
Int report_flag = 1;
/* start_ptr: code_current recorded by start_procedure_stats on pass 1;
   report_procedure_stats prints code_current - start_ptr as the size. */
Int start_ptr;

/* On pass 1, remember where the current procedure's code begins. */
void
start_procedure_stats()
{
  if (pass != 1)
    return;
  start_ptr = code_current;
}

/*
 * On pass 1, print how many bytes of code procedure SYM produced since
 * start_procedure_stats was called.  The byte count is printed via long
 * and %ld.  `Int` may not be `int`, in which case the old "%d" would not
 * match the argument, which is undefined behaviour.
 */
void
report_procedure_stats(Symbol *sym)
{
  if (pass == 1)
    printf("Procedure %s used %ld bytes\n",
	   symbol_name(sym),
	   (long)(code_current - start_ptr));
}
