All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.lucene.util.packed.gen_BulkOperation.py Maven / Gradle / Ivy

There is a newer version: 2024.11.18751.20241128T090041Z-241100
Show newest version
#! /usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from fractions import gcd

"""Code generation for bulk operations"""

# Largest bits-per-value that gets its own generated BulkOperationPackedN
# class; the script falls back to the generic BulkOperationPacked(bpv) above it.
MAX_SPECIALIZED_BITS_PER_VALUE = 24
# Bits-per-value for which a BulkOperationPackedSingleBlock entry is emitted;
# every other slot of the generated single-block table is null.
PACKED_64_SINGLE_BLOCK_BPV = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 16, 21, 32]
# Name of the generated file holding the abstract BulkOperation class.
OUTPUT_FILE = "BulkOperation.java"
# Prologue written at the top of every generated Java file: the
# "automatically generated" notice, the package declaration and the ASF
# license comment.
HEADER = """// This file has been automatically generated, DO NOT EDIT

package org.apache.lucene.util.packed;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

"""

# Java source appended at the end of the generated BulkOperation class: the
# writeLong helper, the computeIterations method and the closing brace of the
# class. This text is emitted verbatim, so it must remain valid Java.
FOOTER="""
  protected int writeLong(long block, byte[] blocks, int blocksOffset) {
    for (int j = 1; j <= 8; ++j) {
      blocks[blocksOffset++] = (byte) (block >>> (64 - (j << 3)));
    }
    return blocksOffset;
  }

  /**
   * For every number of bits per value, there is a minimum number of
   * blocks (b) / values (v) you need to write in order to reach the next block
   * boundary:
   *  - 16 bits per value -> b=2, v=1
   *  - 24 bits per value -> b=3, v=1
   *  - 50 bits per value -> b=25, v=4
   *  - 63 bits per value -> b=63, v=8
   *  - ...
   *
   * A bulk read consists in copying iterations*v values that are
   * contained in iterations*b blocks into a long[]
   * (higher values of iterations are likely to yield a better
   * throughput) => this requires n * (b + 8v) bytes of memory.
   *
   * This method computes iterations as
   * ramBudget / (b + 8v) (since a long is 8 bytes).
   */
  public final int computeIterations(int valueCount, int ramBudget) {
    final int iterations = ramBudget / (byteBlockCount() + 8 * byteValueCount());
    if (iterations == 0) {
      // at least 1
      return 1;
    } else if ((iterations - 1) * byteValueCount() >= valueCount) {
      // don't allocate for more than the size of the reader
      return (int) Math.ceil((double) valueCount / byteValueCount());
    } else {
      return iterations;
    }
  }
}
"""

def is_power_of_two(n):
  """Return True if ``n`` is a positive integer with exactly one bit set.

  ``n & (n - 1)`` clears the lowest set bit, so it is zero only when at most
  one bit is set. Zero also satisfies that test, hence the explicit
  ``n > 0`` guard: 0 is not a power of two.
  """
  return n > 0 and (n & (n - 1)) == 0

def casts(typ):
  """Return the (prefix, suffix) text that casts a Java expression to ``typ``.

  Expressions are already of type ``long``, so that type needs no cast and
  both strings are empty; any other type yields ``"(typ) ("`` and ``")"``.
  """
  if typ == "long":
    return "", ""
  return "(%s) (" % typ, ")"

def hexNoLSuffix(n):
  """Return ``hex(n)`` without a trailing ``L``.

  Some Python 2 interpreters append ``L`` to the hex form of values that do
  not fit in a plain int; strip it so the caller controls the suffix.
  """
  text = hex(n)
  return text[:-1] if text.endswith('L') else text

def masks(bits):
  """Return the (prefix, suffix) text that masks a Java long to ``bits`` bits.

  A full 64-bit value needs no masking, so both strings are empty in that
  case. Otherwise the suffix ANDs with a hexadecimal ``long`` literal whose
  low ``bits`` bits are set.
  """
  if bits != 64:
    literal = hexNoLSuffix((1 << bits) - 1)
    return "(", " & %sL)" % literal
  return "", ""

def get_type(bits):
  """Return the name of the Java primitive integer type that is ``bits`` wide.

  Supported widths are 8 (byte), 16 (short), 32 (int) and 64 (long).

  Raises:
    AssertionError: for any other width. The error is raised explicitly
      rather than through ``assert False`` so that it still fires when
      Python runs with ``-O`` (which strips assert statements and would
      otherwise make this function silently return None).
  """
  java_types = {8: "byte", 16: "short", 32: "int", 64: "long"}
  if bits not in java_types:
    raise AssertionError("no Java primitive type is %r bits wide" % (bits,))
  return java_types[bits]

def block_value_count(bpv, bits=64):
  """Return ``(blocks, values)`` such that ``values * bpv == blocks * bits``.

  Starts from ``bpv`` blocks of ``bits`` bits each (which always hold a whole
  number of ``bpv``-bit values) and halves both counts for as long as both
  stay even.

  Floor division (``//``) is used throughout so the counts are integers
  whether ``/`` means integer or true division in the running interpreter;
  callers feed the counts to ``range`` and ``%d`` formatting.

  Args:
    bpv: number of bits per packed value.
    bits: width of one block in bits (64 by default; the byte-oriented
      decoders pass 8).
  """
  blocks = bpv
  values = blocks * bits // bpv
  while blocks % 2 == 0 and values % 2 == 0:
    blocks //= 2
    values //= 2
  assert values * bpv == bits * blocks, "%d values, %d blocks, %d bits per value" %(values, blocks, bpv)
  return (blocks, values)

def packed64(bpv, f):
  """Write the body of the Java class ``BulkOperationPacked<bpv>`` to ``f``.

  The body consists of a no-arg constructor that calls ``super(bpv)``
  followed by the ``decode`` overrides. For ``bpv == 64`` the overrides are a
  fixed snippet (straight copies into ``long[]``, unsupported for ``int[]``);
  for every other width they are produced by ``p64_decode`` for 32-bit and
  then 64-bit target arrays.

  Args:
    bpv: number of bits per packed value.
    f: writable text file-like object receiving the Java source.
  """
  f.write("\n")
  f.write("  public BulkOperationPacked%d() {\n" % bpv)
  f.write("    super(%d);\n" % bpv)
  f.write("  }\n\n")

  if bpv != 64:
    # int[] decoders first, then long[] decoders.
    for target_bits in (32, 64):
      p64_decode(bpv, f, target_bits)
    return

  # 64 bits per value: blocks and values have the same layout, so decoding
  # is a plain copy and no bit twiddling is generated.
  f.write("""    @Override
    public void decode(long[] blocks, int blocksOffset, long[] values, int valuesOffset, int iterations) {
      System.arraycopy(blocks, blocksOffset, values, valuesOffset, valueCount() * iterations);
    }

    @Override
    public void decode(long[] blocks, int blocksOffset, int[] values, int valuesOffset, int iterations) {
      throw new UnsupportedOperationException();
    }

    @Override
    public void decode(byte[] blocks, int blocksOffset, int[] values, int valuesOffset, int iterations) {
      throw new UnsupportedOperationException();
    }

    @Override
    public void decode(byte[] blocks, int blocksOffset, long[] values, int valuesOffset, int iterations) {
      LongBuffer.wrap(values, valuesOffset, iterations * valueCount()).put(ByteBuffer.wrap(blocks, blocksOffset, 8 * iterations * blockCount()).asLongBuffer());
    }
""")

def _byte_left_shift(b, byte_end, bit_end):
  """Return the left shift that puts byte ``b`` at its place in a value.

  The value ends at bit ``bit_end`` (0 being the most significant bit) of
  byte ``byte_end``. That last byte contributes ``bit_end + 1`` low bits, and
  each byte between ``b`` and ``byte_end`` contributes 8 more.
  """
  return 8 * (byte_end - b - 1) + 1 + bit_end


def p64_decode(bpv, f, bits):
  """Write two Java ``decode`` overrides for ``bpv``-bit packed values.

  The first override reads from ``long[] blocks`` and the second from
  ``byte[] blocks``; both store into an array whose element type is
  ``get_type(bits)``. When the target type is narrower than ``bpv`` the
  generated methods just throw ``UnsupportedOperationException``.

  Args:
    bpv: number of bits per packed value.
    f: writable text file-like object receiving the Java source.
    bits: width in bits of the target array element type.
  """
  values = block_value_count(bpv)[1]
  typ = get_type(bits)
  cast_start, cast_end = casts(typ)
  mask = (1 << bpv) - 1

  # ---- decode(long[] blocks, ...) ----
  f.write("  @Override\n")
  f.write("  public void decode(long[] blocks, int blocksOffset, %s[] values, int valuesOffset, int iterations) {\n" % typ)
  if bits < bpv:
    f.write("    throw new UnsupportedOperationException();\n")
  else:
    f.write("    for (int i = 0; i < iterations; ++i) {\n")

    if is_power_of_two(bpv):
      # Values never straddle a block, so a Java-side shift loop suffices.
      # A mask beyond the Java int range needs an L suffix to be a legal
      # literal; smaller masks keep the historical suffix-less form.
      mask_literal = "%d" % mask
      if mask > 0x7FFFFFFF:
        mask_literal += "L"
      f.write("      final long block = blocks[blocksOffset++];\n")
      f.write("      for (int shift = %d; shift >= 0; shift -= %d) {\n" % (64 - bpv, bpv))
      f.write("        values[valuesOffset++] = %s(block >>> shift) & %s%s;\n" % (cast_start, mask_literal, cast_end))
      f.write("      }\n")
    else:
      # Fully unrolled: one Java statement per value.
      for i in range(values):
        block_offset = i * bpv // 64
        bit_offset = (i * bpv) % 64
        if bit_offset == 0:
          # start of block
          f.write("      final long block%d = blocks[blocksOffset++];\n" % block_offset)
          f.write("      values[valuesOffset++] = %sblock%d >>> %d%s;\n" % (cast_start, block_offset, 64 - bpv, cast_end))
        elif bit_offset + bpv == 64:
          # end of block
          f.write("      values[valuesOffset++] = %sblock%d & %dL%s;\n" % (cast_start, block_offset, mask, cast_end))
        elif bit_offset + bpv < 64:
          # middle of block
          f.write("      values[valuesOffset++] = %s(block%d >>> %d) & %dL%s;\n" % (cast_start, block_offset, 64 - bit_offset - bpv, mask, cast_end))
        else:
          # value spans across 2 blocks
          mask1 = (1 << (64 - bit_offset)) - 1
          shift1 = bit_offset + bpv - 64
          shift2 = 64 - shift1
          f.write("      final long block%d = blocks[blocksOffset++];\n" % (block_offset + 1))
          f.write("      values[valuesOffset++] = %s((block%d & %dL) << %d) | (block%d >>> %d)%s;\n" % (cast_start, block_offset, mask1, shift1, block_offset + 1, shift2, cast_end))
    f.write("    }\n")
  f.write("  }\n\n")

  # ---- decode(byte[] blocks, ...) ----
  byte_values = block_value_count(bpv, 8)[1]

  f.write("  @Override\n")
  f.write("  public void decode(byte[] blocks, int blocksOffset, %s[] values, int valuesOffset, int iterations) {\n" % typ)
  if bits < bpv:
    f.write("    throw new UnsupportedOperationException();\n")
  elif is_power_of_two(bpv) and bpv < 8:
    # Several whole values per byte.
    f.write("    for (int j = 0; j < iterations; ++j) {\n")
    f.write("      final byte block = blocks[blocksOffset++];\n")
    for shift in range(8 - bpv, 0, -bpv):
      f.write("      values[valuesOffset++] = (block >>> %d) & %d;\n" % (shift, mask))
    f.write("      values[valuesOffset++] = block & %d;\n" % mask)
    f.write("    }\n")
  elif bpv == 8:
    # Exactly one value per byte.
    f.write("    for (int j = 0; j < iterations; ++j) {\n")
    f.write("      values[valuesOffset++] = blocks[blocksOffset++] & 0xFF;\n")
    f.write("    }\n")
  elif is_power_of_two(bpv) and bpv > 8:
    # One value made of bpv / 8 whole bytes, most significant byte first.
    f.write("    for (int j = 0; j < iterations; ++j) {\n")
    m = "0xFF" if bits <= 32 else "0xFFL"
    f.write("      values[valuesOffset++] =")
    # Each successive byte sits 8 bits lower than the previous one. (The
    # shift used to be a constant bpv - 8, which is only right for bpv == 16.)
    for shift in range(bpv - 8, 0, -8):
      f.write(" ((blocks[blocksOffset++] & %s) << %d) |" % (m, shift))
    f.write(" (blocks[blocksOffset++] & %s);\n" % m)
    f.write("    }\n")
  else:
    # General case: values may start and end anywhere inside a byte.
    f.write("    for (int i = 0; i < iterations; ++i) {\n")
    for i in range(byte_values):
      byte_start = i * bpv // 8
      bit_start = (i * bpv) % 8
      byte_end = ((i + 1) * bpv - 1) // 8
      bit_end = ((i + 1) * bpv - 1) % 8
      # Declare each source byte the first time a value touches it; a value
      # starting mid-byte reuses the byte declared for the previous value.
      if bit_start == 0:
        f.write("      final %s byte%d = blocks[blocksOffset++] & 0xFF;\n" % (typ, byte_start))
      for b in range(byte_start + 1, byte_end + 1):
        f.write("      final %s byte%d = blocks[blocksOffset++] & 0xFF;\n" % (typ, b))
      f.write("      values[valuesOffset++] =")
      if byte_start == byte_end:
        # Value lies within a single byte.
        if bit_start == 0:
          if bit_end == 7:
            f.write(" byte%d" % byte_start)
          else:
            f.write(" byte%d >>> %d" % (byte_start, 7 - bit_end))
        else:
          if bit_end == 7:
            f.write(" byte%d & %d" % (byte_start, 2 ** (8 - bit_start) - 1))
          else:
            f.write(" (byte%d >>> %d) & %d" % (byte_start, 7 - bit_end, 2 ** (bit_end - bit_start + 1) - 1))
      else:
        # Value spans several bytes: high part, whole middle bytes, low part.
        first_shift = _byte_left_shift(byte_start, byte_end, bit_end)
        if bit_start == 0:
          f.write(" (byte%d << %d)" % (byte_start, first_shift))
        else:
          f.write(" ((byte%d & %d) << %d)" % (byte_start, 2 ** (8 - bit_start) - 1, first_shift))
        for b in range(byte_start + 1, byte_end):
          f.write(" | (byte%d << %d)" % (b, _byte_left_shift(b, byte_end, bit_end)))
        if bit_end == 7:
          f.write(" | byte%d" % byte_end)
        else:
          f.write(" | (byte%d >>> %d)" % (byte_end, 7 - bit_end))
      f.write(";\n")
    f.write("    }\n")
  f.write("  }\n\n")

if __name__ == '__main__':
  # Script entry point. Writes BulkOperation.java (lookup tables and factory
  # method) plus one BulkOperationPacked<bpv>.java for every bits-per-value up
  # to MAX_SPECIALIZED_BITS_PER_VALUE, all in the current directory.

  # Javadoc placed above every generated class declaration.
  class_javadoc = '/**\n * Efficient sequential read/write of packed integers.\n */\n'

  # The with-statements make sure every output file is closed even when
  # generation fails part-way through.
  with open(OUTPUT_FILE, 'w') as f:
    f.write(HEADER)
    f.write('\n')
    f.write(class_javadoc)

    f.write('abstract class BulkOperation implements PackedInts.Decoder, PackedInts.Encoder {\n')
    f.write('  private static final BulkOperation[] packedBulkOps = new BulkOperation[] {\n')

    for bpv in range(1, 65):
      if bpv > MAX_SPECIALIZED_BITS_PER_VALUE:
        # Above the threshold no specialized class is generated; the table
        # entry uses the generic implementation instead.
        f.write('    new BulkOperationPacked(%d),\n' % bpv)
        continue
      with open('BulkOperationPacked%d.java' % bpv, 'w') as f2:
        f2.write(HEADER)
        if bpv == 64:
          # Imports needed by the fixed 64-bit decode snippet.
          f2.write('import java.nio.LongBuffer;\n')
          f2.write('import java.nio.ByteBuffer;\n')
          f2.write('\n')
        f2.write(class_javadoc)
        f2.write('final class BulkOperationPacked%d extends BulkOperationPacked {\n' % bpv)
        packed64(bpv, f2)
        f2.write('}\n')
      f.write('    new BulkOperationPacked%d(),\n' % bpv)

    f.write('  };\n')
    f.write('\n')

    f.write('  // NOTE: this is sparse (some entries are null):\n')
    f.write('  private static final BulkOperation[] packedSingleBlockBulkOps = new BulkOperation[] {\n')
    for bpv in range(1, max(PACKED_64_SINGLE_BLOCK_BPV) + 1):
      if bpv in PACKED_64_SINGLE_BLOCK_BPV:
        f.write('    new BulkOperationPackedSingleBlock(%d),\n' % bpv)
      else:
        f.write('    null,\n')
    f.write('  };\n')
    f.write('\n')

    # Factory method: both tables are indexed by bitsPerValue - 1.
    f.write("\n")
    f.write("  public static BulkOperation of(PackedInts.Format format, int bitsPerValue) {\n")
    f.write("    switch (format) {\n")

    f.write("    case PACKED:\n")
    f.write("      assert packedBulkOps[bitsPerValue - 1] != null;\n")
    f.write("      return packedBulkOps[bitsPerValue - 1];\n")
    f.write("    case PACKED_SINGLE_BLOCK:\n")
    f.write("      assert packedSingleBlockBulkOps[bitsPerValue - 1] != null;\n")
    f.write("      return packedSingleBlockBulkOps[bitsPerValue - 1];\n")
    f.write("    default:\n")
    f.write("      throw new AssertionError();\n")
    f.write("    }\n")
    f.write("  }\n")
    f.write(FOOTER)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy