import re, random, copy
import board
"""
Team Sealab 2021
jack@performancedrivers.com
bob@syristatides.com

This module parses ant programs.
class Program parses legal programs and also supports syntactic sugar such as labels.
class SuperProgram is a superset of Program that also allows inspection of internal
                   state via meta commands such as Assert and Print.
"""

def try_intify(text):
  """Return int(text) when text parses as an integer, otherwise text unchanged.

  Used on whitespace-split program tokens: '12' -> 12, '-1' -> -1, while a
  word such as 'Move' is returned as-is.
  """
  try:
    return int(text)
  except (ValueError, TypeError, OverflowError):
    # ValueError: not an integer literal; TypeError: not str/number at all;
    # OverflowError: e.g. an infinite float. Anything else propagates.
    return text

def mk_mark(i, st):
  """Build the state function for ``Mark i st``.

  The returned function sets marker i (own color: 'redmark' for RED_HILL
  ants, 'blkmark' otherwise) on the ant's current cell and moves to state st.
  """
  def mark(bob, ant):
    own_markers = 'redmark' if ant.color == board.RED_HILL else 'blkmark'
    bob[ant.pos][own_markers][i] = 1
    ant.state = st
  return mark
def mk_unmark(i, st):
  """Build the state function for ``Unmark i st``.

  The returned function clears marker i (own color: 'redmark' for RED_HILL
  ants, 'blkmark' otherwise) on the ant's current cell and moves to state st.
  """
  def unmark(bob, ant):
    own_markers = 'redmark' if ant.color == board.RED_HILL else 'blkmark'
    bob[ant.pos][own_markers][i] = 0
    ant.state = st
  return unmark
def mk_pickup(st1, st2):
  """Build the state function for ``PickUp st1 st2``.

  If the current cell has food and the ant carries none, take one unit and go
  to st1; otherwise leave everything unchanged and go to st2.
  """
  def pickup(bob, ant):
    cell = bob[ant.pos]
    if not (cell['food'] and not ant.has_food):
      ant.state = st2
      return
    cell['food'] -= 1
    ant.state = st1
    ant.has_food = 1
  return pickup
def mk_drop(st):
  """Build the state function for ``Drop st``.

  Any carried food is added to the current cell; the ant ends up carrying
  nothing and moves to st either way.
  """
  def drop(bob, ant):
    cell = bob[ant.pos]
    carrying = ant.has_food
    if carrying:
      cell['food'] += 1
    ant.has_food = 0
    ant.state = st
  return drop
def mk_turn(dir, st):
  """Build the state function for ``Turn Left|Right st``.

  Raises Exception("Unk turn") for any other direction word.
  """
  if dir not in ('Left', 'Right'):
    raise Exception("Unk turn")
  if dir == 'Left':
    return mk_turn_left(st)
  return mk_turn_right(st)
def mk_turn_left(st):
  """Build a state function that rotates the ant left (board.dir_left) and goes to st."""
  def turn_left(bob, ant):
    rotated = board.dir_left(ant.dir)
    ant.dir = rotated
    ant.state = st
  return turn_left
def mk_turn_right(st):
  """Build a state function that rotates the ant right (board.dir_right) and goes to st."""
  def turn_right(bob, ant):
    ant.dir = board.dir_right(ant.dir)
    ant.state = st
  return turn_right
def mk_move(st1, st2):
  """Build the state function for ``Move st1 st2``.

  The ant tries to step one cell in its facing direction. If that cell holds
  no ant and its type is board.CLEAR, the ant moves there, 14 is added to
  ant.resting, it goes to st1, and check_dead_ants is called. Otherwise the
  ant stays put and goes to st2.
  """
  def move(bob, ant):
    dest = bob.rel_pos(ant.pos, ant.dir)
    if bob[dest]['ant']:
      ant.state = st2  # cell ahead is occupied
    elif bob[dest]['type'] == board.CLEAR:
      bob[ant.pos]['ant'] = None
      ant.pos = dest
      bob[dest]['ant'] = ant
      # Program.execute counts resting down before the ant may act again.
      ant.resting += 14
      ant.state = st1
      check_dead_ants(bob, ant)
    else:
      ant.state = st2  # cell ahead is not CLEAR (e.g. rock)
    return
  return move
def check_dead_ants(bob, ant):
  """Hook run after every successful move; currently a no-op returning None."""
  return None

class Rand(object):
  """Deterministic pseudo-random number generator.

  Seed sequence: s_0 = seed, s_{k+1} = (s_k * 22695477 + 1) mod 2**32.
  The n-th call (n counting from 1) returns ((s_{n+4} // 65536) % 16384) % maxint.

  Storing the sequence reduced mod 2**32 gives exactly the same call results
  as the unreduced big-integer sequence, because only bits 16..29 of each s_k
  are ever used. It keeps every stored value to 32 bits instead of growing by
  about 25 bits per step.
  """
  _MULT = 22695477
  _MASK = 0xFFFFFFFF

  def __init__(self, seed):
    self.s0 = seed
    self.si_list = [seed]   # s_0, s_1, ... computed lazily by __getitem__
    self.call_cnt = 0
    return

  def make_x(self):
    # Floor division keeps this an int on both Python 2 and Python 3.
    return (self[self.call_cnt + 4] // 65536) % 16384

  def __call__(self, maxint):
    self.call_cnt += 1
    return self.make_x() % maxint

  def __getitem__(self, wanti):
    seq = self.si_list
    while wanti >= len(seq):
      seq.append((seq[-1] * self._MULT + 1) & self._MASK)
    return seq[wanti]

# Module-wide Rand generator seeded with 12345.
# NOTE(review): nothing visible here calls G_rand (mk_flip draws from the
# stdlib `random` module instead); confirm whether it is still needed.
#random.seed(123567)
G_rand = Rand(12345)
def mk_flip(maxint, st1, st2):
  """Build the state function for ``Flip p st1 st2``.

  Chooses x uniformly from 0..p-1 and goes to st1 when x == 0, otherwise to
  st2. When st1 == st2 (e.g. the translated Goto / Sleep lines) the result is
  an unconditional jump. A p of 1 always picks x == 0, so the ant always goes
  to st1. p <= 0 is also treated as "always st1"; this matches the old
  randint(0, 0) result for p == 0.
  """
  if st1 == st2:
    def skip_flip(bob, ant):
      ant.state = st1
      return
    return skip_flip
  if maxint <= 1:
    def always_first(bob, ant):
      ant.state = st1
      return
    return always_first
  def flip(bob, ant):
    # randrange(p) draws from 0..p-1. randint(0, p) would draw from 0..p,
    # i.e. p+1 outcomes, making st1 rarer than specified.
    if random.randrange(maxint) == 0:
      ant.state = st1
    else:
      ant.state = st2
    return
  return flip

def check_friend(bob, ant, where):
  """Sense 'Friend': is the ant at `where` the same color as `ant`?

  Returns the (falsy) cell value unchanged when no ant is there, else a bool.
  """
  occupant = bob[where]['ant']
  if not occupant:
    return occupant
  return occupant.color == ant.color
def check_foe(bob, ant, where):
  """Sense 'Foe': is the ant at `where` a different color from `ant`?

  Returns the (falsy) cell value unchanged when no ant is there, else a bool.
  """
  occupant = bob[where]['ant']
  if not occupant:
    return occupant
  return occupant.color != ant.color
def check_food(bob, ant, where):
  """Sense 'Food': return the amount of loose food on `where` (truthy if any)."""
  cell = bob[where]
  return cell['food']
def check_food_friend(bob, ant, where):
  """Sense 'FriendWithFood': a same-color ant on `where` carrying food.

  Returns the falsy cell value if no ant is there, False for a foe, and the
  friend's has_food value otherwise.
  """
  occupant = bob[where]['ant']
  if not occupant:
    return occupant
  if occupant.color == ant.color:
    return occupant.has_food
  return False
def check_food_foe(bob, ant, where):
  """Sense 'FoeWithFood': an other-color ant on `where` carrying food.

  Returns the falsy cell value if no ant is there, False for a friend, and
  the foe's has_food value otherwise.
  """
  occupant = bob[where]['ant']
  if not occupant:
    return occupant
  if occupant.color != ant.color:
    return occupant.has_food
  return False
def check_rock(bob, ant, where):
  """Sense 'Rock': True when the cell's type is board.ROCK."""
  cell_type = bob[where]['type']
  return cell_type == board.ROCK
def check_foe_mark(bob, ant, where):
  """Sense 'FoeMarker': is *any* marker of the other color set on `where`?

  Ants cannot tell which foe marker is set, only that some is. RED_HILL ants
  look at 'blkmark', all others at 'redmark'. Returns a bool.
  (Previously defined twice with identical bodies. filter(None, ...) was also
  always truthy on Python 3, so any() is used instead.)
  """
  foe_key = 'blkmark' if ant.color == board.RED_HILL else 'redmark'
  return any(bob[where][foe_key])
def check_home(bob, ant, where):
  """Sense 'Home': does the cell's hill value equal the ant's color?"""
  cell_hill = bob[where]['hill']
  return ant.color == cell_hill
def check_foe_home(bob, ant, where):
  """Sense 'FoeHome': the cell belongs to a hill that is not the ant's color.

  Returns the falsy hill value unchanged for non-hill cells, else a bool.
  """
  cell_hill = bob[where]['hill']
  if not cell_hill:
    return cell_hill
  return ant.color != cell_hill
def mk_check_mark(int_mark):
  """Build a 'Marker(i)' sensor: reads own-color marker `int_mark` on a cell."""
  def check_mark(bob, ant, where):
    own_markers = 'redmark' if ant.color == board.RED_HILL else 'blkmark'
    return bob[where][own_markers][int_mark]
  return check_mark

def mk_sense(dir, st1, st2, cond, marker_num=None):
  """Build the state function for ``Sense sensedir st1 st2 cond [marker]``.

  cond is one of the names listed in Program's docstring. 'Marker' also needs
  marker_num. An unknown cond (or 'Marker' without a number) raises Exception
  immediately. An unknown sensedir is only detected, and raised, when the
  returned function runs.
  """
  # FoeMarker: ants can sense each of their own markers individually, but only
  # the presence of *some* marker of the other color.
  fixed_checks = {
    'Friend': check_friend,
    'Foe': check_foe,
    'Food': check_food,
    'FriendWithFood': check_food_friend,
    'FoeWithFood': check_food_foe,
    'Rock': check_rock,
    'FoeMarker': check_foe_mark,
    'Home': check_home,
    'FoeHome': check_foe_home,
  }
  if cond in fixed_checks:
    check = fixed_checks[cond]
  elif cond == 'Marker' and marker_num is not None:
    check = mk_check_mark(marker_num)
  else:
    raise Exception("Unknown Sense Condition %s" % (repr(cond)))

  def sense(bob, ant):
    if dir == 'Here':
      target = ant.pos
    elif dir == 'Ahead':
      target = bob.rel_pos(ant.pos, ant.dir)
    elif dir == 'LeftAhead':
      target = bob.rel_pos(ant.pos, board.dir_left(ant.dir))
    elif dir == 'RightAhead':
      target = bob.rel_pos(ant.pos, board.dir_right(ant.dir))
    else:
      raise Exception("Unk dir %s" % (repr(dir)))
    ant.state = st1 if check(bob, ant, target) else st2
  return sense

class Program(object):
  """
  Interpreter for a parsed ant program: one state function per instruction.

  Sense sensedir st1 st2 cond   Go to state st1 if cond holds in sensedir;
                                and to state st2 otherwise.
    <sensedir> in:
      Here
      Ahead
      LeftAhead
      RightAhead
    <cond> in:
    Friend             /* cell contains an ant of the same color */
    | Foe                /* cell contains an ant of the other color */
    | FriendWithFood     /* cell contains an ant of the same color carrying food */
    | FoeWithFood        /* cell contains an ant of the other color carrying food */
    | Food               /* cell contains food (not being carried by an ant) */
    | Rock               /* cell is rocky */
    | Marker(marker)     /* cell is marked with a marker of this ant's color */
    | FoeMarker          /* cell is marked with *some* marker of the other color */
    | Home               /* cell belongs to this ant's anthill */
    | FoeHome            /* cell belongs to the other anthill */
  Mark i st                     Set mark i in current cell and go to st.
    <i> in 0..5
  Unmark i st                   Clear mark i in current cell and go to st.
  PickUp st1 st2                Pick up food from current cell and go to st1;
                                go to st2 if there is no food in the current cell.
  Drop st                       Drop food in current cell and go to st.
  Turn lr st                    Turn left or right and go to st.
  Move st1 st2                  Move forward and go to st1;
                                go to st2 if the cell ahead is blocked.
  Flip p st1 st2                Choose a random number x from 0 to p-1;
                                go to st1 if x=0 and st2 otherwise.
  """
  makers = {'Sense':mk_sense,
            'Mark':mk_mark,
            'Unmark':mk_unmark,
            'PickUp':mk_pickup,
            'Drop':mk_drop,
            'Turn':mk_turn,
            'Move':mk_move,
            'Flip':mk_flip,
           }
  def __init__(self, board, filename = None, **opts):
    # `opts` is accepted for call compatibility and is currently unused.
    self.board = board
    # One slot per state; parse_program fills slot i from final line i.
    self.states = [None] * 10000
    self.debug = None
    if (filename):
      self.parse_program(filename)
    return

  def parse_program(self, filename):
    """Preprocess `filename` and build one state function per final line.

    Sets self.final_prog (instruction strings, index == state number) and
    self.intermed (the preprocessor's per-line dicts). On a line that cannot
    be turned into a state, prints diagnostics and raises Exception.
    """
    import debugger  # NOTE(review): unused here; kept in case importing it has side effects
    with open(filename) as fob:
      (final, intermed) = preprocessor(fob.readlines())
    self.intermed = intermed
    self.final_prog = final
    for (i, line) in enumerate(final):
      parts = line.strip().split()
      if not parts:
        continue
      cmd = parts[0]
      args = [try_intify(p) for p in parts[1:]]
      # Look the maker up before the try so the handler below never sees an
      # unbound (or stale, from an earlier line) make_func.
      make_func = self.makers.get(cmd)
      try:
        if make_func is None:
          raise Exception("Unknown command %r" % (cmd,))
        self.states[i] = make_func(*args)
      except Exception as e:
        print("ERROR! %s" % str(e))
        print("%s %s" % (make_func, args))
        print(str([x for x in self.intermed if x['line_no'] == i]))
        raise Exception("Failed to build state %d from %r: %s" % (i, line, e))
    return

  def execute(self, ant):
    """Run one round for `ant`.

    While ant.resting is positive the round only counts it down. The ant runs
    its current state's function on the round where resting is (or has just
    reached) zero, and ant.stepped is then incremented.
    """
    if (ant.resting):
      ant.resting -= 1
    if (ant.resting): # still resting after this round's countdown
      return
    func = self.states[ant.state]
    func(self.board, ant) # modifies ant, and maybe board
    ant.stepped += 1
    return

  def execute_all(self, antlist):
    """Run one round for each ant, in list order."""
    for (ant) in antlist:
      self.execute(ant)
    return

class MetaCommand(Exception):
  """Raised by meta instructions (Type, Assert, Print) once they have run.

  SuperProgram.execute catches it, moves the ant to the next state and
  executes again, so a meta instruction does not use up the ant's turn.
  """
def mk_change_type(to_color):
  """Build the meta instruction ``Type n``: recolor the ant's sprite.

  `to_color` indexes a fixed 7-entry palette (0..6). Any other value raises
  KeyError here, at build time. When run, the returned function copies
  ant.base_image, paints every pixel that is neither ant.fill_color nor the
  border color (taken from pixel (0, 0)) with the palette color, stores the
  result in ant.image, sets ant.type = to_color and raises MetaCommand.
  """
  palette = dict(enumerate([(0,0,0), (0,255,0), (255,255,255), (255,255,0),
                            (0,255,255), (128,0,210), (50,50,200)]))
  paint = palette[to_color]
  def change_type(board, ant):
    img = ant.base_image.copy()
    width, height = img.size
    border = img.getpixel((0, 0))
    for x in range(width):
      for y in range(height):
        if img.getpixel((x, y)) not in [ant.fill_color, border]:
          img.putpixel((x, y), paint)
    ant.image = img
    ant.type = to_color
    raise MetaCommand()
  return change_type

# When False (the existing behaviour), every Assert instruction is a no-op
# meta command. Set to True to have Assert evaluate its expression.
ASSERTS_ENABLED = False

def mk_assert(*bits):
  """Build the meta instruction ``Assert <python expression>``.

  `bits` are the instruction's whitespace-split tokens (some may already be
  ints from try_intify); they are rejoined with single spaces into the
  expression. With ASSERTS_ENABLED the returned function eval()s it (names
  `board` and `ant` are in scope) and raises Exception if it is false.
  Either way it then raises MetaCommand.
  """
  dothis = ' '.join(map(str, bits))
  def myassert(board, ant):
    # SECURITY: eval() of program text; only load trusted .ant files.
    yesno = eval('1 and (%s)' % (dothis))
    if (not yesno):
      raise Exception("Ant failed assert %s" % (repr(dothis)))
    raise MetaCommand()
  def pass_assert(board, ant):
    raise MetaCommand()
  if ASSERTS_ENABLED:
    return myassert
  return pass_assert

def mk_print(*bits):
  """Build the meta instruction ``Print <python expression>``.

  `bits` are the instruction's whitespace-split tokens (some may already be
  ints from try_intify); they are rejoined with single spaces. When run, the
  expression is eval()'d (names `board` and `ant` in scope), the expression
  and repr of its value are printed, and MetaCommand is raised.
  """
  dothis = ' '.join(map(str, bits))
  def myprint(board, ant):
    # SECURITY: eval() of program text; only load trusted .ant files.
    printme = eval(dothis)
    # Single pre-formatted argument: identical output on Python 2 and 3.
    print("%s %s" % (dothis, repr(printme)))
    raise MetaCommand()
  return myprint

class SuperProgram(Program):
  """ Superset of commands in Program: adds the meta commands Type, Assert, Print.

  Meta commands raise MetaCommand when executed. execute() catches it, moves
  the ant to the next state and executes again, so they do not use a turn.
  """
  makers = Program.makers.copy()
  makers['Type'] = mk_change_type
  makers['Assert'] = mk_assert
  makers['Print'] = mk_print

  def parse_program(self, *args):
    """Parse as Program does, then write the non-meta lines to ./prog.out."""
    Program.parse_program(self, *args)
    # NOTE(review): meta lines are dropped from prog.out without renumbering,
    # so any jump target after a dropped line is shifted by the number of
    # meta lines before it. prog.out is only equivalent when no meta command
    # precedes a jump target -- confirm this is intended.
    with open('prog.out', 'w+') as fob:
      for (line) in self.final_prog:
        line = line.strip()
        parts = line.split()
        if parts and parts[0] in Program.makers:
          fob.write(line + "\n")
    return

  def execute(self, ant):
    """Run one round; a meta command falls through to the following state."""
    try:
      Program.execute(self, ant)
    except MetaCommand: # do again, doesn't consume a move
      ant.state += 1
      self.execute(ant)
    return
  

# '@name', '@!name', '@name+N', '@name-N' label references; group 3 is the rest.
_LABEL_REF_RE = re.compile(r'@(\!?\w+)([+-]?\d*)(.*)')

def tokenize(line):
  """ Split a line into words, (label, offset) tuples and ints.

  Each '@name[+N|-N]' reference becomes a (name, offset) tuple; offset is 0
  when absent or not a number, and a leading '!' stays part of the name.
  All other text is split on whitespace and passed through try_intify.
  Lines starting with 'include' are never scanned for label references.
  """
  pieces = []
  if not line.startswith('include'):
    m = _LABEL_REF_RE.search(line)
    while m:
      pieces.append(line[:m.start()])
      try:
        offset = int(m.group(2))
      except ValueError:  # '', '+' or '-' alone
        offset = 0
      pieces.append((m.group(1), offset))
      line = m.group(3)
      m = _LABEL_REF_RE.search(line)
  pieces.append(line)

  tokens = []
  for piece in pieces:
    if isinstance(piece, tuple):  # (label, offset)
      tokens.append(piece)
    else:
      tokens.extend(try_intify(word) for word in piece.split())
  return tokens

# 'VAR name = value'; the name stops at the first '=' so values may contain '='.
_VAR_RE = re.compile(r'VAR (.*?)\s*=\s*(.*)')

def scan_and_update_vars(lines, vars):
  """Record every `VAR name = value` line of `lines` into the dict `vars`.

  Lines are stripped before matching. Names and values are stored as stripped
  strings, and later definitions overwrite earlier ones. Mutates `vars` in
  place and returns None.
  """
  for line in lines:
    m = _VAR_RE.match(line.strip())
    if m:
      (name, val) = m.groups()
      vars[name.strip()] = val.strip()
  return

# Number of `include` lines expanded so far, across all pre_first calls.
# pre_first embeds it in the names of labels from each included copy so they
# stay distinct.
Ginclude_count = 0
def pre_first(lines, filename=None, **vars):
  """ First stage of the preprocessor.

  Cleans the lines and returns a list of dicts, one per line, with the lines
  of each included file spliced in right after its `include` line. Performs
  includes and `$name` variable substitution here.

  lines    -- raw source lines
  filename -- stored in every dict's 'filename' (None for the top-level file)
  vars     -- variable bindings; `VAR name = value` lines found in `lines`
              are merged in first and override same-named bindings

  Raises Exception for a label defined twice in `lines`, and KeyError when a
  non-skipped line uses an undefined $variable.
  """
  scan_and_update_vars(lines, vars)
  global Ginclude_count
  label_re = re.compile(r'^(\w+):$')
  comment_re = re.compile(r'^\s*#')
  i = 0 # input line counter -- never reused as a loop variable below
  j = 0 # count of lines not skipped so far
  out = []
  seen_labels = {}
  for raw in lines:
    line = raw.strip()
    d = {'input_line_no':i,
         'line_no':j, # final line number
         'skip':0, # 'comment' or whatever if this line shouldn't appear in the output
         'nopurge':0, # if true, never skip this line
         'rawline':line, # original input
         'out':line, # current translation
         'labels':{}, # labels defined on this line
         'toks':[], # tokenized rawline
         'filename':filename,
        }
    i += 1
    if (not line):
      d['skip'] = 'whitespace'
    if (comment_re.match(line)):
      d['skip'] = 'comment'
    if (line.startswith('VAR')):
      d['skip'] = 'Var'
    label_m = label_re.match(line)
    if label_m:
      lab = label_m.group(1)
      if (lab in seen_labels):
        raise Exception("Doubled up label %s at line %d" % (line, i))
      seen_labels[lab] = 1
      d['labels'][lab] = 1
      d['skip'] = 'labeled'

    # Substitute only where the result is used (tokenized lines), so a `$`
    # inside a comment cannot raise KeyError.
    if vars and not d['skip']:
      parts = line.split()
      for (k, p) in enumerate(parts):
        if (p.startswith('$')):
          parts[k] = vars[p[1:]]
      line = ' '.join(map(str, parts))

    if (not d['skip']):
      d['toks'] = tokenize(line)
      j += 1
    else:
      d['out'] = '# ' + d['out']
    out.append(d)
    toks = d['toks']  # this line's tokens; not rebound below
    if (toks and toks[0] == 'include'):
      Ginclude_count += 1
      fname = toks[1]
      d['skip'] = 'Included %s %d' % (fname, Ginclude_count)
      # Tokens may be ints after try_intify, so stringify before joining.
      rest = ' '.join(map(str, toks[2:]))
      try:
        # SECURITY: include arguments are eval()'d; only use trusted files.
        incvars = eval(rest)
      except Exception as e:
        print("EVAL FAIL %s" % str(e))
        incvars = {}
      cpv = vars.copy() # include our vars
      cpv.update(incvars) # override w/ include vars
      with open(fname) as inc_fob:
        inc_lines = inc_fob.readlines()
      extras = pre_first(inc_lines, fname, **cpv)
      for (extra) in extras:
        # Make labels local to this included copy.
        extra['labels'] = dict(('%s_%s_%d' % (fname, k, Ginclude_count), 1)
                               for k in extra['labels'])
        extra_toks = extra['toks']
        for (k, t) in enumerate(extra_toks):
          if isinstance(t, tuple):
            if (t[0].startswith('!')): # don't munge the name
              extra_toks[k] = (t[0][1:], t[1])
            else: # make it local to the included file
              extra_toks[k] = ('%s_%s_%d' % (fname, t[0], Ginclude_count), t[1])
        out.append(extra)
    if (toks and toks[0] == 'Do'):
      d['skip'] = 'Do Start'
    elif (toks and toks[0] == 'EndDo'):
      d['skip'] = 'Do End'
  return out

def pre_second(lines):
  """ Second stage of the preprocessor.

  Expands, in place, every command that causes substitutions:
    Sleep n        -> n no-op 'Flip 1 PASS PASS' lines
    Do n ... EndDo -> the body repeated n times after the EndDo line
    Turn Around x  -> three 'Turn Right' lines, the last going to x
  Lines inside a Do body are scanned after expansion, so a Sleep or
  Turn Around in the body is expanded in every repetition. Returns `lines`.
  Raises Exception for a Do with an empty body or no matching EndDo.
  """
  i = 0
  while (i < len(lines)):
    line = lines[i]
    toks = line['toks']
    if (toks and toks[0] == 'Sleep'):
      line['skip'] = 'Sleep expanded'
      count = toks[1]
      # create a single line that does nothing
      sleep_line = pre_first(['Flip 1 PASS PASS'])[0]
      sleep_line['nopurge'] = 'Sleep (line %d)' % (line['input_line_no'])
      sleep_line['input_line_no'] = 'Fake'
      sleep_line['line_no'] = 'Replaceme'
      for (dummy) in range(count):
        lines.insert(i+1, copy.deepcopy(sleep_line))
      i += count
    elif (toks and toks[0] == 'Do'):
      repeat_count = toks[1]
      # pull the body out, up to (not including) the matching EndDo
      j = i + 1
      repeatme = []
      while j < len(lines) and (not lines[j]['toks'] or lines[j]['toks'][0] != 'EndDo'):
        repeatme.append(lines.pop(j))
      if j >= len(lines):
        raise Exception("Do at input line %s has no matching EndDo" % (line['input_line_no'],))
      if not repeatme:
        raise Exception("Empty Do block at input line %s" % (line['input_line_no'],))
      # copies go right after the EndDo line (which is now at i+1)
      offset = 2
      for (dummy) in range(repeat_count):
        for (l) in copy.deepcopy(repeatme):
          lines.insert(i+offset, l)
          offset += 1
      # Deliberately no skip past the copies: they are scanned next.
    elif (toks and toks[0] == 'Turn' and len(toks) > 1 and toks[1] == 'Around'):
      line['skip'] = 'Turn Around (Rightx3)'
      turn_line = pre_first(['Turn Right PASS'])[0]
      turn_line['nopurge'] = 'Turn Aournd'
      turn_line['input_line_no'] = 'Fake (Turn Around)'
      turn_line['line_no'] = 'Replaceme'
      # Insert one w/ whatever the last tok is
      cp = copy.deepcopy(turn_line)
      cp['toks'][2] = toks[2]
      lines.insert(i+1, cp) # turns out third
      # Insert two w/ PASS
      lines.insert(i+1, copy.deepcopy(turn_line))
      lines.insert(i+1, copy.deepcopy(turn_line))
      i += 3
    i += 1
  return lines

def pre_renumber(obs):
  """Number the non-skipped entries of `obs` 0, 1, 2, ... in order.

  Skipped entries get line_no 'NA'. Mutates `obs` in place; returns None.
  """
  next_no = 0
  for ob in obs:
    if ob['skip']:
      ob['line_no'] = 'NA'
      continue
    ob['line_no'] = next_no
    next_no += 1

def pre_third(obs):
  """ Third pass: drop or rewrite redundant / invalid lines (in place).

  * `Goto @label` whose label is on the next real line -> skipped.
  * `Mark i ...` with i < 0 -> skipped ('Invalid Mark').
  * `Sense ... Marker i` with i < 0 -> rewritten to `Flip 1 st2 st2`
    (always go to st2). A negative marker is never set, since Mark i < 0 is
    dropped above. Left alone, mk_check_mark would index the marker list
    with the negative number and read a real marker.
  """
  # remove NOOP Gotos
  for (i, ob) in enumerate(obs):
    toks = ob['toks']
    if not (toks and toks[0] == 'Goto' and len(toks) > 1):
      continue
    target = toks[1]
    # tokenize() turns '@name+N' into (name, N); only offset 0 lands on the label
    if not (isinstance(target, tuple) and target[1] == 0):
      continue
    lab = target[0]
    j = i + 1
    while (j < len(obs)):
      if (obs[j]['labels']): # might be our guy
        if (lab in obs[j]['labels']): # we are redundant
          ob['skip'] = 'Redundant Goto'
          break
        else:
          j += 1
      elif (obs[j]['skip']): # not real lines
        j += 1
      else:
        break
  # Mark -1 is invalid, mark it skipped
  # NOTE(review): skipping falls through to the next line rather than to the
  # Mark's target state; equivalent only when that target is PASS.
  for ob in obs:
    toks = ob['toks']
    if (toks and toks[0] == 'Mark' and len(toks) > 1
        and isinstance(toks[1], int) and toks[1] < 0):
      ob['skip'] = 'Invalid Mark'
  # Sense ... Marker -1 can never be true: jump straight to st2
  for ob in obs:
    toks = ob['toks']
    if (toks and toks[0] == 'Sense' and len(toks) >= 6 and toks[-2] == 'Marker'
        and isinstance(toks[-1], int) and toks[-1] < 0):
      st2 = toks[3]  # Sense sensedir st1 st2 Marker i
      ob['toks'] = ['Flip', 1, st2, st2]
  return

def pre_fourth(obs):
  """ Fourth pass: resolve labels, PASS and REPEAT, translate Goto, rebuild 'out'.

  * A label maps to the line_no of the first non-'NA' entry after the
    labelled entry; (label, offset) tokens become that number plus offset.
  * PASS becomes the next real line's number; REPEAT becomes the line's own.
  * `Goto N` becomes 'Flip 1 N N', a Flip with equal targets, i.e. a jump.
  Raises Exception when a used label, or a PASS, has no real line after it,
  and KeyError for an undefined label.
  """
  def first_real_line_no(start):
    # line_no of the first entry at index >= start that is output, else None
    for k in range(start, len(obs)):
      if obs[k]['line_no'] != 'NA':
        return obs[k]['line_no']
    return None

  labels = {} # {'name':line_no or None}
  for (i, ob) in enumerate(obs):
    if (ob['labels']):
      target = first_real_line_no(i + 1)
      for (lab) in ob['labels']:
        labels[lab] = target

  # replace (label, offset) tokens with line numbers
  for (ob) in obs:
    toks = ob['toks']
    for (k, t) in enumerate(toks):
      if isinstance(t, tuple):
        (lab, offset) = t
        target = labels[lab]
        if target is None:
          raise Exception("Label %r is not followed by any instruction" % (lab,))
        toks[k] = target + offset

  # translate PASS, REPEAT
  for (i, ob) in enumerate(obs):
    pass_to = None  # computed once per line, on the first PASS
    for (k, t) in enumerate(ob['toks']):
      if (t == 'PASS'):
        if (pass_to is None):
          pass_to = first_real_line_no(i + 1)
          if pass_to is None:
            raise Exception("PASS on the last instruction has nowhere to go: %r" % (ob['toks'],))
        ob['toks'][k] = pass_to
      elif (t == 'REPEAT'):
        ob['toks'][k] = ob['line_no']

  # translate Goto's: ['Goto', N] -> ['Flip 1 N', N] -> text 'Flip 1 N N'
  for (ob) in obs:
    toks = ob['toks']
    if (toks and toks[0] == 'Goto'):
      toks[0] = 'Flip 1 %d' % (toks[1])

  for (ob) in obs: # recalc 'out'
    ob['out'] = ' '.join(map(str, ob['toks']))
  return

def preprocessor(lines):
  """Run the four preprocessing passes over raw program lines.

  Returns (plines, obs). plines is the 'out' text of every non-skipped entry,
  in order, so index == state number. obs is the full list of per-line dicts
  after all passes.
  """
  obs = pre_first(lines)   # clean, tokenize, includes, variables
  pre_renumber(obs)
  pre_second(obs)          # expansions: Sleep, Do/EndDo, Turn Around
  pre_renumber(obs)
  pre_third(obs)           # removals of redundant/invalid lines
  pre_renumber(obs)
  pre_fourth(obs)          # labels, PASS/REPEAT, Goto -> final text
  plines = [ob['out'] for ob in obs if not ob['skip']]
  return (plines, obs)
    
def test_pre():
  """Smoke-run the preprocessor over ./test.ant (result discarded)."""
  with open('test.ant') as fob:
    # A real list: preprocessor's passes take len() of it and index it.
    lines = [x.strip() for x in fob]
  preprocessor(lines)
  
if __name__ == '__main__':
  # Parse ./test.ant with the extended (meta-command) instruction set;
  # SuperProgram.parse_program also writes the plain instructions to prog.out.
  p = SuperProgram(None, filename='test.ant')

