All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tajo.plan.function.PythonAggFunctionInvoke Maven / Gradle / Ivy

The newest version!
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.tajo.plan.function;

import org.apache.tajo.catalog.CatalogUtil;
import org.apache.tajo.catalog.FunctionDesc;
import org.apache.tajo.common.TajoDataTypes;
import org.apache.tajo.datum.Datum;
import org.apache.tajo.datum.DatumFactory;
import org.apache.tajo.datum.NullDatum;
import org.apache.tajo.plan.function.python.PythonScriptEngine;
import org.apache.tajo.storage.Tuple;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Invokes a Python user-defined aggregation function (UDAF) through a {@link PythonScriptEngine}.
 *
 * <p>The Python side holds the intermediate state of one aggregation context at a time, while an
 * aggregation may use many {@link PythonAggFunctionContext}s (one per aggregation key). This class
 * remembers which context is currently loaded on the Python side. It swaps contexts only when a
 * call arrives for a different one.
 */
public class PythonAggFunctionInvoke extends AggFunctionInvoke implements Cloneable {

  /** Engine that executes the Python UDAF; assigned in {@link #init}. */
  private transient PythonScriptEngine scriptEngine;
  /** Context whose state is currently loaded on the Python side, or null if none has been yet. */
  private transient PythonAggFunctionContext prevContext;
  /**
   * Source of {@link PythonAggFunctionContext} ids. The field is static, so every instance of this
   * class in the JVM shares it. It is atomic so that contexts created from different threads never
   * receive the same id.
   */
  private static final AtomicInteger nextContextId = new AtomicInteger(0);

  /**
   * Holds the intermediate state of one Python UDAF aggregation on the Java side.
   *
   * <p>Aggregated results are kept in the Tajo task rather than in the Python UDAF, so that memory
   * usage can be controlled. Each context stores a snapshot of the Python UDAF class instance as a
   * JSON string, which makes aggregation with keys possible.
   *
   * <p>For each UDAF call with a different aggregation key,
   * {@link PythonAggFunctionInvoke} calls {@link PythonAggFunctionInvoke#updateContextIfNecessary}.
   * That method backs up the intermediate aggregation state of the previous key and restores the
   * state of the current key.
   */
  public static class PythonAggFunctionContext implements FunctionContext {
    /** Identifies this context; compared to decide whether a context switch is needed. */
    final int id;
    /** Snapshot of the Python UDAF class instance, serialized as JSON. */
    String jsonData;

    public PythonAggFunctionContext() {
      // getAndIncrement() is an atomic read-modify-write. A plain "++" on the shared static
      // counter could hand out the same id twice. That would make updateContextIfNecessary
      // treat two different contexts as one and skip a required state swap.
      this.id = nextContextId.getAndIncrement();
    }

    /** @param jsonData JSON snapshot of the Python UDAF instance to store in this context */
    public void setJsonData(String jsonData) {
      this.jsonData = jsonData;
    }

    /** @return the stored JSON snapshot of the Python UDAF instance, or null if none was set */
    public String getJsonData() {
      return jsonData;
    }
  }

  /** @param functionDesc description of the Python UDAF to invoke */
  public PythonAggFunctionInvoke(FunctionDesc functionDesc) {
    super(functionDesc);
  }

  /**
   * Binds this invoke to the script engine supplied by {@code context}.
   *
   * @param context invocation context; its script engine must be a {@link PythonScriptEngine}
   * @throws IOException declared by the overridden method; not thrown by this implementation
   */
  @Override
  public void init(FunctionInvokeContext context) throws IOException {
    this.scriptEngine = (PythonScriptEngine) context.getScriptEngine();
  }

  /** @return a fresh, empty {@link PythonAggFunctionContext} with a unique id */
  @Override
  public FunctionContext newContext() {
    return new PythonAggFunctionContext();
  }

  /**
   * Makes {@code context} the context loaded on the Python side, if it is not already.
   *
   * <p>The context does not need to be swapped on every UDAF call. If the current aggregation
   * context is the same as the previous one, the Python side already contains the necessary
   * intermediate result. Otherwise the previous context (if any) is first backed up from the
   * Python side, and then the given context is restored to the Python side.
   *
   * @param context the context for the upcoming call; must be a {@link PythonAggFunctionContext}
   * @throws RuntimeException wrapping the {@link IOException} if the script engine fails
   */
  private void updateContextIfNecessary(FunctionContext context) {
    PythonAggFunctionContext givenContext = (PythonAggFunctionContext) context;
    if (prevContext == null || prevContext.id != givenContext.id) {
      try {
        if (prevContext != null) {
          scriptEngine.updateJavaSideContext(prevContext);
        }
        scriptEngine.updatePythonSideContext(givenContext);
        prevContext = givenContext;
      } catch (IOException e) {
        throw new RuntimeException(
            "Failed to switch Python UDAF context to context id " + givenContext.id, e);
      }
    }
  }

  /**
   * Feeds one input row to the Python UDAF under {@code context}.
   *
   * @param context aggregation context created by {@link #newContext()}
   * @param params the input row
   */
  @Override
  public void eval(FunctionContext context, Tuple params) {
    updateContextIfNecessary(context);
    scriptEngine.callAggFunc(context, params);
  }

  /**
   * Merges one input tuple into {@code context}.
   *
   * <p>Calls whose first field is blank or null are ignored without touching the Python side.
   *
   * @param context aggregation context created by {@link #newContext()}
   * @param params the tuple to merge
   */
  @Override
  public void merge(FunctionContext context, Tuple params) {
    if (params.isBlankOrNull(0)) {
      return;
    }

    updateContextIfNecessary(context);
    scriptEngine.callAggFunc(context, params);
  }

  /**
   * Returns the intermediate aggregation state of {@code context}.
   *
   * @param context aggregation context created by {@link #newContext()}
   * @return the partial result as a TEXT datum (the script engine serializes it as a JSON string)
   */
  @Override
  public Datum getPartialResult(FunctionContext context) {
    updateContextIfNecessary(context);
    // partial results are stored as json strings.
    String result = scriptEngine.getPartialResult(context);
    return DatumFactory.createText(result);
  }

  /** @return TEXT, since partial results are exchanged as JSON strings */
  @Override
  public TajoDataTypes.DataType getPartialResultType() {
    return CatalogUtil.newSimpleDataType(TajoDataTypes.Type.TEXT);
  }

  /**
   * Computes the final aggregation result for {@code context}.
   *
   * @param context aggregation context created by {@link #newContext()}
   * @return the final result produced by the Python UDAF
   */
  @Override
  public Datum terminate(FunctionContext context) {
    updateContextIfNecessary(context);
    return scriptEngine.getFinalResult(context);
  }

  /**
   * Returns a copy of this invoke that does not inherit the "currently loaded context" tracking.
   *
   * <p>{@code prevContext} describes what is loaded into this instance's script engine. If a copy
   * kept it and was then initialized with another engine, its first context switch would back up
   * that engine's unrelated state into a context owned by the original invoke. That would corrupt
   * the original's aggregation.
   */
  @Override
  public Object clone() throws CloneNotSupportedException {
    PythonAggFunctionInvoke cloned = (PythonAggFunctionInvoke) super.clone();
    cloned.prevContext = null;
    return cloned;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy