All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.util.automaton.UTF32ToUTF8.py Maven / Gradle / Ivy

The newest version!
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
import os
import sys
import random

MAX_UNICODE = 0x10FFFF

# TODO
#   - could be more minimal
#     - eg when bracket lands on a utf8 boundary, like 3 - 2047 -- they can share the two * edges
#     - also 3 2048 or 3 65536 -- it should not have an * down the red path, but it does

# MASKS[0] is bottom 1-bit
# MASKS[1] is bottom 2-bits
# ...

utf8Ranges = [(0, 127),
              (128, 2047),
              (2048, 65535),
              (65536, 1114111)]

typeToColor = {'startend': 'purple',
               'start': 'blue',
               'end': 'red'}

class FSA:

  def __init__(self):
    # maps fromNode -> (startUTF8, endUTF8, endNode)
    self.states = {}
    self.nodeUpto = 0

  def run(self, bytes):
    state = self.start
    for b in bytes:
      found = False
      oldState = state
      for label, s, e, n in self.states[state][1:]:
        if b >= s and b <= e:
          if found:
            raise RuntimeError('state %s has ambiguous output for byte %s' % (oldState, b))
          state = n
          found = True
      if not found:
        return -1
      
    return state
        
  def addEdge(self, n1, n2, v1, v2, label):
    """
    Adds edge from n1-n2, utf8 byte range v1-v2.
    """
    assert n1 in self.states
    assert type(v1) is types.IntType
    assert type(v2) is types.IntType
    self.states[n1].append((label, v1, v2, n2))

  def addNode(self, label=None):
    try:
      self.states[self.nodeUpto] = [label]
      return self.nodeUpto
    finally:
      self.nodeUpto += 1

  def toDOT(self, label):
    __l = []
    w = __l.append
    endNode = startNode = None
    for id, details in self.states.items():
      name = details[0]
      if name == 'end':
        endNode = id
      elif name == 'start':
        startNode = id

    w('digraph %s {' % label)
    w('  rankdir=LR;')
    w('  size="8,5";')
    w('  node [color=white label=""]; Ns;')

    w('  node [color=black];')
    w('  node [shape=doublecircle, label=""]; N%s [label="%s"];' % (endNode, endNode))
    w('  node [shape=circle];')

    w('  N%s [label="%s"];' % (startNode, startNode))
    w('  Ns -> N%s;' % startNode)
    for id, details in self.states.items():
      edges = details[1:]
      w('  N%s [label="%s"];' % (id, id))
      for type, s, e, dest in edges:
        c = typeToColor.get(type, 'black')
        if type == 'all*':
          # special case -- matches any utf8 byte at this point
          label = '*'
        elif s == e:
          label = '%s' % binary(s)
        else:
          label = '%s-%s' % (binary(s), binary(e))
        w('  N%s -> N%s [label="%s" color="%s"];' % (id, dest, label, c))
      if name == 'end':
        endNode = id
      elif name == 'start':
        startNode = id
    w('}')
    return '\n'.join(__l)

  def toPNG(self, label, pngOut):
    open('tmp.dot', 'wb').write(self.toDOT(label))
    if os.system('dot -Tpng tmp.dot -o %s' % pngOut):
      raise RuntimeException('dot failed')
    

MASKS = []
v = 2
for i in range(32):
  MASKS.append(v-1)
  v *= 2

def binary(x):
  if x == 0:
    return '00000000'
  
  l = []
  while x > 0:
    if x & 1 == 1:
      l.append('1')
    else:
      l.append('0')
    x = x >> 1

  # big endian!
  l.reverse()

  l2 = []
  while len(l) > 0:
    s = ''.join(l[-8:])
    if len(s) < 8:
      s = '0'*(8-len(s)) + s
    l2.append(s)
    del l[-8:]

  return ' '.join(l2)

def getUTF8Rest(code, numBytes):
  l = []
  for i in range(numBytes):
    l.append((128 | (code & MASKS[5]), 6))
    code = code >> 6
  l.reverse()
  return tuple(l)

def toUTF8(code):
  # code = Unicode code point
  assert code >= 0
  assert code <= MAX_UNICODE

  if code < 128:
    # 0xxxxxxx
    bytes = ((code, 7),)
  elif code < 2048:
    # 110yyyxx 10xxxxxx
    byte1 = (6 << 5) | (code >> 6)
    bytes = ((byte1, 5),) + getUTF8Rest(code, 1)
  elif code < 65536:
    # 1110yyyy 10yyyyxx 10xxxxxx
    len = 3
    byte1 = (14 << 4) | (code >> 12)
    bytes = ((byte1, 4),) + getUTF8Rest(code, 2)
  else:
    # 11110zzz 10zzyyyy 10yyyyxx 10xxxxxx
    len = 4
    byte1 = (30 << 3) | (code >> 18)
    bytes = ((byte1, 3),) + getUTF8Rest(code, 3)

  return bytes

def all(fsa, startNode, endNode, startCode, endCode, left):
  if len(left) == 0:
    fsa.addEdge(startNode, endNode, startCode, endCode, 'all')
  else:
    lastN = fsa.addNode()
    fsa.addEdge(startNode, lastN, startCode, endCode, 'all')
    while len(left) > 1:
      n = fsa.addNode()
      fsa.addEdge(lastN, n, 128, 191, 'all*')
      left = left[1:]
      lastN = n
    fsa.addEdge(lastN, endNode, 128, 191, 'all*')
          
def start(fsa, startNode, endNode, utf8, doAll):
  if len(utf8) == 1:
    fsa.addEdge(startNode, endNode, utf8[0][0], utf8[0][0] | MASKS[utf8[0][1]-1], 'start')
  else:
    n = fsa.addNode()
    fsa.addEdge(startNode, n, utf8[0][0], utf8[0][0], 'start')
    start(fsa, n, endNode, utf8[1:], True)
    end = utf8[0][0] | MASKS[utf8[0][1]-1]
    if doAll and utf8[0][0] != end:
      all(fsa, startNode, endNode, utf8[0][0]+1, end, utf8[1:])

def end(fsa, startNode, endNode, utf8, doAll):
  if len(utf8) == 1:
    fsa.addEdge(startNode, endNode, utf8[0][0] & ~MASKS[utf8[0][1]-1], utf8[0][0], 'end')
  else:
    if utf8[0][1] == 5:
      # special case -- avoid created unused edges (utf8 doesn't accept certain byte sequences):
      start = 194
    else:
      start = utf8[0][0] & (~MASKS[utf8[0][1]-1])
    if doAll and utf8[0][0] != start:
      all(fsa, startNode, endNode, start, utf8[0][0]-1, utf8[1:])
    n = fsa.addNode()
    fsa.addEdge(startNode, n, utf8[0][0], utf8[0][0], 'end')
    end(fsa, n, endNode, utf8[1:], True)

def build(fsa,
          startNode, endNode,
          startUTF8, endUTF8):

  # Break into start, middle, end:
  if startUTF8[0][0] == endUTF8[0][0]:
    # Degen case: lead with the same byte:
    if len(startUTF8) == 1 and len(endUTF8) == 1:
      fsa.addEdge(startNode, endNode, startUTF8[0][0], endUTF8[0][0], 'startend')
      return
    else:
      assert len(startUTF8) != 1
      assert len(endUTF8) != 1
      n = fsa.addNode()
      # single value edge
      fsa.addEdge(startNode, n, startUTF8[0][0], startUTF8[0][0], 'single')
      build(fsa, n, endNode, startUTF8[1:], endUTF8[1:])
  elif len(startUTF8) == len(endUTF8):
    if len(startUTF8) == 1:
      fsa.addEdge(startNode, endNode, startUTF8[0][0], endUTF8[0][0], 'startend')
    else:
      start(fsa, startNode, endNode, startUTF8, False)
      if endUTF8[0][0] - startUTF8[0][0] > 1:
        all(fsa, startNode, endNode, startUTF8[0][0]+1, endUTF8[0][0]-1, startUTF8[1:])
      end(fsa, startNode, endNode, endUTF8, False)
  else:
    # start
    start(fsa, startNode, endNode, startUTF8, True)

    # possibly middle
    byteCount = 1+len(startUTF8)
    while byteCount < len(endUTF8):
      s = toUTF8(utf8Ranges[byteCount-1][0])
      e = toUTF8(utf8Ranges[byteCount-1][1])
      all(fsa, startNode, endNode,
          s[0][0],
          e[0][0],
          s[1:])
      byteCount += 1

    # end
    end(fsa, startNode, endNode, endUTF8, True)

def main():

  if len(sys.argv) not in (3, 4):
    print
    print 'Usage: python %s startUTF32 endUTF32 [testCode]' % sys.argv[0]
    print
    sys.exit(1)

  utf32Start = int(sys.argv[1])
  utf32End = int(sys.argv[2])

  if utf32Start > utf32End:
    print 'ERROR: start must be <= end'
    sys.exit(1)

  fsa = FSA()
  fsa.start = fsa.addNode('start')
  fsa.end = fsa.addNode('end')

  print 's=%s' % ' '.join([binary(x[0]) for x in toUTF8(utf32Start)])
  print 'e=%s' % ' '.join([binary(x[0]) for x in toUTF8(utf32End)])

  if len(sys.argv) == 4:
    print 't=%s [%s]' % \
          (' '.join([binary(x[0]) for x in toUTF8(int(sys.argv[3]))]),
           ' '.join(['%2x' % x[0] for x in toUTF8(int(sys.argv[3]))]))
  
  build(fsa, fsa.start, fsa.end,
        toUTF8(utf32Start),
        toUTF8(utf32End))

  fsa.toPNG('test', '/tmp/outpy.png')
  print 'Saved to /tmp/outpy.png...'

  test(fsa, utf32Start, utf32End, 100000);

def test(fsa, utf32Start, utf32End, count):

  # verify correct ints are accepted
  for i in range(count):
    r = random.randint(utf32Start, utf32End)
    dest = fsa.run([tup[0] for tup in toUTF8(r)])
    if dest != fsa.end:
      print 'FAILED: valid %s (%s) is not accepted' % (r, ' '.join([binary(x[0]) for x in toUTF8(r)]))
      return False

  invalidRange = MAX_UNICODE - (utf32End - utf32Start + 1)
  if invalidRange >= 0:
    # verify invalid ints are not accepted
    for i in range(count):
      r = random.randint(0, invalidRange-1)
      if r >= utf32Start:
        r = utf32End + 1 + r - utf32Start
      dest = fsa.run([tup[0] for tup in toUTF8(r)])
      if dest != -1:
        print 'FAILED: invalid %s (%s) is accepted' % (r, ' '.join([binary(x[0]) for x in toUTF8(r)]))
        return False

  return True

def stress():

  print 'Testing...'

  iter = 0
  while True:
    if iter % 10 == 0:
      print '%s...' % iter
    iter += 1

    v1 = random.randint(0, MAX_UNICODE)
    v2 = random.randint(0, MAX_UNICODE)
    if v2 < v1:
      v1, v2 = v2, v1

    utf32Start = v1
    utf32End = v2

    fsa = FSA()
    fsa.start = fsa.addNode('start')
    fsa.end = fsa.addNode('end')
    build(fsa, fsa.start, fsa.end,
          toUTF8(utf32Start),
          toUTF8(utf32End))

    if not test(fsa, utf32Start, utf32End, 10000):
      print 'FAILED on utf32Start=%s utf32End=%s' % (utf32Start, utf32End)

if __name__ == '__main__':
  if len(sys.argv) > 1:
    main()
  else:
    stress()




© 2015 - 2025 Weber Informatics LLC | Privacy Policy