All Downloads are FREE. Search and download functionalities are using the official Maven repository.

streaming.controller.py Maven / Gradle / Ivy

The newest version!
#! /usr/bin/env python
import sys
import os
import logging

from datetime import datetime
try:
    from dateutil import parser
    USE_DATEUTIL = True
except ImportError:
    USE_DATEUTIL = False

from pig_util import write_user_exception, udf_logging

FIELD_DELIMITER = ','
TUPLE_START = '('
TUPLE_END = ')'
BAG_START = '{'
BAG_END = '}'
MAP_START = '['
MAP_END = ']'
MAP_KEY = '#'
PARAMETER_DELIMITER = '\t'
PRE_WRAP_DELIM = '|'
POST_WRAP_DELIM = '_'
NULL_BYTE = "-"
END_RECORD_DELIM = '|_\n'
END_RECORD_DELIM_LENGTH = len(END_RECORD_DELIM)

WRAPPED_FIELD_DELIMITER = PRE_WRAP_DELIM + FIELD_DELIMITER + POST_WRAP_DELIM
WRAPPED_TUPLE_START = PRE_WRAP_DELIM + TUPLE_START + POST_WRAP_DELIM
WRAPPED_TUPLE_END = PRE_WRAP_DELIM + TUPLE_END + POST_WRAP_DELIM
WRAPPED_BAG_START = PRE_WRAP_DELIM + BAG_START + POST_WRAP_DELIM
WRAPPED_BAG_END = PRE_WRAP_DELIM + BAG_END + POST_WRAP_DELIM
WRAPPED_MAP_START = PRE_WRAP_DELIM + MAP_START + POST_WRAP_DELIM
WRAPPED_MAP_END = PRE_WRAP_DELIM + MAP_END + POST_WRAP_DELIM
WRAPPED_PARAMETER_DELIMITER = PRE_WRAP_DELIM + PARAMETER_DELIMITER + POST_WRAP_DELIM
WRAPPED_NULL_BYTE = PRE_WRAP_DELIM + NULL_BYTE + POST_WRAP_DELIM

TYPE_TUPLE = TUPLE_START
TYPE_BAG = BAG_START
TYPE_MAP = MAP_START

TYPE_BOOLEAN = "B"
TYPE_INTEGER = "I"
TYPE_LONG = "L"
TYPE_FLOAT = "F"
TYPE_DOUBLE = "D"
TYPE_BYTEARRAY = "A"
TYPE_CHARARRAY = "C"
TYPE_DATETIME = "T"
TYPE_BIGINTEGER = "N"
TYPE_BIGDECIMAL = "E"

END_OF_STREAM = TYPE_CHARARRAY + "\x04" + END_RECORD_DELIM
TURN_ON_OUTPUT_CAPTURING = TYPE_CHARARRAY + "TURN_ON_OUTPUT_CAPTURING" + END_RECORD_DELIM
NUM_LINES_OFFSET_TRACE = int(os.environ.get('PYTHON_TRACE_OFFSET', 0))

class PythonStreamingController:
    def __init__(self, profiling_mode=False):
        self.profiling_mode = profiling_mode

        self.input_count = 0
        self.next_input_count_to_log = 1

    def main(self,
             module_name, file_path, func_name, cache_path,
             output_stream_path, error_stream_path, log_file_name, is_illustrate_str):
        sys.stdin = os.fdopen(sys.stdin.fileno(), 'rb', 0)

        #Need to ensure that user functions can't write to the streams we use to
        #communicate with pig.
        self.stream_output = os.fdopen(sys.stdout.fileno(), 'wb', 0)
        self.stream_error = os.fdopen(sys.stderr.fileno(), 'wb', 0)

        self.input_stream = sys.stdin
        self.output_stream = open(output_stream_path, 'a')
        sys.stderr = open(error_stream_path, 'w')
        is_illustrate = is_illustrate_str == "true"

        sys.path.append(file_path)
        sys.path.append(cache_path)
        sys.path.append('.')

        logging.basicConfig(filename=log_file_name, format="%(asctime)s %(levelname)s %(message)s", level=udf_logging.udf_log_level)
        logging.info("To reduce the amount of information being logged only a small subset of rows are logged at the INFO level.  Call udf_logging.set_log_level_debug in pig_util to see all rows being processed.")

        input_str = self.get_next_input()

        try:
            func = __import__(module_name, globals(), locals(), [func_name], -1).__dict__[func_name]
        except:
            #These errors should always be caused by user code.
            write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE)
            self.close_controller(-1)

        if is_illustrate or udf_logging.udf_log_level != logging.DEBUG:
            #Only log output for illustrate after we get the flag to capture output.
            sys.stdout = open(os.devnull, 'w')
        else:
            sys.stdout = self.output_stream

        while input_str != END_OF_STREAM:
            should_log = False
            if self.input_count == self.next_input_count_to_log:
                should_log = True
                log_message = logging.info
                self.update_next_input_count_to_log()
            elif udf_logging.udf_log_level == logging.DEBUG:
                should_log = True
                log_message = logging.debug

            try:
                try:
                    if should_log:
                        log_message("Row %s: Serialized Input: %s" % (self.input_count, input_str))
                    inputs = deserialize_input(input_str)
                    if should_log:
                        log_message("Row %s: Deserialized Input: %s" % (self.input_count, unicode(inputs)))
                except:
                    #Capture errors where the user passes in bad data.
                    write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE)
                    self.close_controller(-3)

                try:
                    func_output = func(*inputs)
                    if should_log:
                        try:
                            log_message("Row %s: UDF Output: %s" % (self.input_count, unicode(func_output)))
                        except:
                            #This is probably an error with unicoding the output.  Calling unicode on bytearray will
                            #throw an exception.  Since its just a log statement, just skip and carry on.
                            logging.exception("Couldn't log output.  Try to continue.")
                except:
                    #These errors should always be caused by user code.
                    write_user_exception(module_name, self.stream_error, NUM_LINES_OFFSET_TRACE)
                    self.close_controller(-2)

                output = serialize_output(func_output)
                if should_log:
                    log_message("Row %s: Serialized Output: %s" % (self.input_count, output))

                self.stream_output.write( "%s%s" % (output, END_RECORD_DELIM) )
            except Exception as e:
                #This should only catch internal exceptions with the controller
                #and pig- not with user code.
                import traceback
                traceback.print_exc(file=self.stream_error)
                sys.exit(-3)

            sys.stdout.flush()
            sys.stderr.flush()
            self.stream_output.flush()
            self.stream_error.flush()

            input_str = self.get_next_input()

    def get_next_input(self):
        input_stream = self.input_stream
        output_stream = self.output_stream

        input_str = input_stream.readline()

        while input_str.endswith(END_RECORD_DELIM) == False:
            line = input_stream.readline()
            if line == '':
                input_str = ''
                break
            input_str += line

        if input_str == '':
            return END_OF_STREAM

        if input_str == TURN_ON_OUTPUT_CAPTURING:
            logging.debug("Turned on Output Capturing")
            sys.stdout = output_stream
            return self.get_next_input()

        if input_str == END_OF_STREAM:
            return input_str

        self.input_count += 1

        return input_str[:-END_RECORD_DELIM_LENGTH]

    def update_next_input_count_to_log(self):
        """
        Want to log enough rows that you can see progress being made and see timings without wasting time logging thousands of rows.
        Show first 10 rows, and then the first 5 rows of every order of magnitude (10-15, 100-105, 1000-1005, ...)
        """
        if self.next_input_count_to_log < 10:
            self.next_input_count_to_log = self.next_input_count_to_log + 1
        elif self.next_input_count_to_log % 10 == 5:
            self.next_input_count_to_log = (self.next_input_count_to_log - 5) * 10
        else:
            self.next_input_count_to_log = self.next_input_count_to_log + 1

    def close_controller(self, exit_code):
        sys.stderr.close()
        self.stream_error.write("\n")
        self.stream_error.close()
        sys.stdout.close()
        self.stream_output.write("\n")
        self.stream_output.close()
        sys.exit(exit_code)

def deserialize_input(input_str):
    if len(input_str) == 0:
        return []

    return [_deserialize_input(param, 0, len(param)-1) for param in input_str.split(WRAPPED_PARAMETER_DELIMITER)]

def _deserialize_input(input_str, si, ei):
    if ei - si < 1:
        #Handle all of the cases where you can have valid empty input.
        if ei == si:
            if input_str[si] == TYPE_CHARARRAY:
                return u""
            elif input_str[si] == TYPE_BYTEARRAY:
                return bytearray("")
            else:
                raise Exception("Got input type flag %s, but no data to go with it.\nInput string: %s\nSlice: %s" % (input_str[si], input_str, input_str[si:ei+1]))
        else:
            raise Exception("Start index %d greater than end index %d.\nInput string: %s\n, Slice: %s" % (si, ei, input_str[si:ei+1]))

    first = input_str[si]
    schema = input_str[si+1] if first == PRE_WRAP_DELIM else first

    if schema == NULL_BYTE:
        return None
    elif schema == TYPE_TUPLE or schema == TYPE_MAP or schema == TYPE_BAG:
        return _deserialize_collection(input_str, schema, si+3, ei-3)
    elif schema == TYPE_CHARARRAY:
        return unicode(input_str[si+1:ei+1], 'utf-8')
    elif schema == TYPE_BYTEARRAY:
        return bytearray(input_str[si+1:ei+1])
    elif schema == TYPE_INTEGER:
        return int(input_str[si+1:ei+1])
    elif schema == TYPE_LONG or schema == TYPE_BIGINTEGER:
        return long(input_str[si+1:ei+1])
    elif schema == TYPE_FLOAT or schema == TYPE_DOUBLE or schema == TYPE_BIGDECIMAL:
        return float(input_str[si+1:ei+1])
    elif schema == TYPE_BOOLEAN:
        return input_str[si+1:ei+1] == "true"
    elif schema == TYPE_DATETIME:
        #Format is "yyyy-MM-ddTHH:mm:ss.SSS+00:00" or "2013-08-23T18:14:03.123+ZZ"
        if USE_DATEUTIL:
            return parser.parse(input_str[si+1:ei+1])
        else:
            #Try to use datetime even though it doesn't handle time zones properly,
            #We only use the first 3 microsecond digits and drop time zone (first 23 characters)
            return datetime.strptime(input_str[si+1:si+24], "%Y-%m-%dT%H:%M:%S.%f")
    else:
        raise Exception("Can't determine type of input: %s" % input_str[si:ei+1])

def _deserialize_collection(input_str, return_type, si, ei):
    list_result = []
    append_to_list_result = list_result.append
    dict_result = {}

    index = si
    field_start = si
    depth = 0

    key = None

    # recurse to deserialize elements if the collection is not empty
    if ei-si+1 > 0:
        while True:
            if index >= ei - 2:
                if return_type == TYPE_MAP:
                    dict_result[key] = _deserialize_input(input_str, value_start, ei)
                else:
                    append_to_list_result(_deserialize_input(input_str, field_start, ei))
                break

            if return_type == TYPE_MAP and not key:
                key_index = input_str.find(MAP_KEY, index)
                key = unicode(input_str[index+1:key_index], 'utf-8')
                index = key_index + 1
                value_start = key_index + 1
                continue

            if not (input_str[index] == PRE_WRAP_DELIM and input_str[index+2] == POST_WRAP_DELIM):
                prewrap_index = input_str.find(PRE_WRAP_DELIM, index+1)
                index = (prewrap_index if prewrap_index != -1 else end_index)
                continue

            mid = input_str[index+1]

            if mid == BAG_START or mid == TUPLE_START or mid == MAP_START:
                depth += 1
            elif mid == BAG_END or mid == TUPLE_END or mid == MAP_END:
                depth -= 1
            elif depth == 0 and mid == FIELD_DELIMITER:
                if return_type == TYPE_MAP:
                    dict_result[key] = _deserialize_input(input_str, value_start, index - 1)
                    key = None
                else:
                    append_to_list_result(_deserialize_input(input_str, field_start, index - 1))
                field_start = index + 3

            index += 3

    if return_type == TYPE_MAP:
        return dict_result
    elif return_type == TYPE_TUPLE:
        return tuple(list_result)
    else:
        return list_result

def wrap_tuple(o, serialized_item):
    if type(o) != tuple:
        return WRAPPED_TUPLE_START + serialized_item + WRAPPED_TUPLE_END
    else:
        return serialized_item

def serialize_output(output, utfEncodeAllFields=False):
    """
    @param utfEncodeStrings - Generally we want to utf encode only strings.  But for
        Maps we utf encode everything because on the Java side we don't know the schema
        for maps so we wouldn't be able to tell which fields were encoded or not.
    """

    output_type = type(output)

    if output is None:
        return WRAPPED_NULL_BYTE
    elif output_type == tuple:
        return (WRAPPED_TUPLE_START +
                WRAPPED_FIELD_DELIMITER.join([serialize_output(o, utfEncodeAllFields) for o in output]) +
                WRAPPED_TUPLE_END)
    elif output_type == list:
        return (WRAPPED_BAG_START +
                WRAPPED_FIELD_DELIMITER.join([wrap_tuple(o, serialize_output(o, utfEncodeAllFields)) for o in output]) +
                WRAPPED_BAG_END)
    elif output_type == dict:
        return (WRAPPED_MAP_START +
                WRAPPED_FIELD_DELIMITER.join(['%s%s%s' % (k.encode('utf-8'), MAP_KEY, serialize_output(v, True)) for k, v in output.iteritems()]) +
                WRAPPED_MAP_END)
    elif output_type == bool:
        return ("true" if output else "false")
    elif output_type == bytearray:
        return str(output)
    elif output_type == datetime:
        return output.isoformat()
    elif utfEncodeAllFields or output_type == str or output_type == unicode:
        #unicode is necessary in cases where we're encoding non-strings.
        return unicode(output).encode('utf-8')
    else:
        return str(output)

if __name__ == '__main__':
    controller = PythonStreamingController()
    controller.main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4],
                    sys.argv[5], sys.argv[6], sys.argv[7], sys.argv[8])




© 2015 - 2024 Weber Informatics LLC | Privacy Policy