All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.ftvec.ranking.BprSamplingUDTF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.ftvec.ranking;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.IntWritable;

import hivemall.UDTFWithOptions;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.BitUtils;
import hivemall.utils.lang.Primitives;

@Description(name = "bpr_sampling",
        value = "_FUNC_(int userId, List posItems [, const string options])"
                + "- Returns a relation consists of ")
public final class BprSamplingUDTF extends UDTFWithOptions {

    private PrimitiveObjectInspector userOI;
    private ListObjectInspector itemListOI;
    private PrimitiveObjectInspector itemElemOI;

    // Need to avoid
    // org.apache.hive.com.esotericsoftware.kryo.KryoException: java.lang.ArrayIndexOutOfBoundsException: 1
    @Nullable
    private transient PositiveOnlyFeedback feedback;

    // sampling options
    private int maxItemId;
    private float samplingRate;
    private boolean withoutReplacement;
    private boolean pairSampling;

    private Object[] forwardObjs;
    private IntWritable userId;
    private IntWritable posItemId;
    private IntWritable negItemId;

    public BprSamplingUDTF() {}

    @Override
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("sampling", "sampling_rate", true,
            "Sampling rates of positive items [default: 1.0]");
        opts.addOption("without_replacement", false,
            "Do sampling without-replacement sampling [default: false]");
        opts.addOption("uniform_pair_sampling", "pair_sampling", false,
            "Sampling pairs uniform from feedbacks [default: false]");
        opts.addOption("maxcol", "max_itemid", true, "Max item id index [default: -1]");
        return opts;
    }

    @Override
    protected CommandLine processOptions(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        CommandLine cl = null;

        int maxItemId = -1;
        float samplingRate = 1.f;
        boolean withoutReplacement = false;
        boolean pairSampling = false;

        if (argOIs.length >= 3) {
            String args = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(args);

            maxItemId = Primitives.parseInt(cl.getOptionValue("max_itemid"), maxItemId);
            withoutReplacement = cl.hasOption("without_replacement");
            pairSampling = cl.hasOption("uniform_pair_sampling");

            samplingRate = Primitives.parseFloat(cl.getOptionValue("sampling_rate"), samplingRate);
            if (withoutReplacement && samplingRate > 1.f) {
                throw new UDFArgumentException("sampling_rate MUST be in less than or equals to 1"
                        + " where without-replacement is true: " + samplingRate);
            }
        }

        this.maxItemId = maxItemId;
        this.samplingRate = samplingRate;
        this.withoutReplacement = withoutReplacement;
        this.pairSampling = pairSampling;
        return cl;
    }

    @Override
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
            throws UDFArgumentException {
        if (argOIs.length != 2 && argOIs.length != 3) {
            throw new UDFArgumentException(
                "_FUNC_(int userid, array itemid, [, const string options])"
                        + " takes at least two arguments");
        }
        this.userOI = HiveUtils.asIntegerOI(argOIs[0]);
        this.itemListOI = HiveUtils.asListOI(argOIs[1]);
        this.itemElemOI = HiveUtils.asIntegerOI(itemListOI.getListElementObjectInspector());

        processOptions(argOIs);

        this.userId = new IntWritable();
        this.posItemId = new IntWritable();
        this.negItemId = new IntWritable();
        this.forwardObjs = new Object[] {userId, posItemId, negItemId};

        ArrayList fieldNames = new ArrayList();
        ArrayList fieldOIs = new ArrayList();
        fieldNames.add("user");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("pos_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("neg_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    @Override
    public void process(@Nonnull Object[] args) throws HiveException {
        if (feedback == null) {
            this.feedback = pairSampling ? new PerEventPositiveOnlyFeedback(maxItemId)
                    : new PositiveOnlyFeedback(maxItemId);
        }

        int userId = PrimitiveObjectInspectorUtils.getInt(args[0], userOI);
        validateIndex(userId);

        addFeedback(userId, args[1]);
    }

    @Nullable
    private void addFeedback(final int userId, @Nonnull final Object arg)
            throws UDFArgumentException {
        final int size = itemListOI.getListLength(arg);
        if (size == 0) {
            return;
        }

        int maxItemId = feedback.getMaxItemId();
        final IntArrayList posItems = new IntArrayList(size);
        for (int i = 0; i < size; i++) {
            Object elem = itemListOI.getListElement(arg, i);
            if (elem == null) {
                continue;
            }
            int index = PrimitiveObjectInspectorUtils.getInt(elem, itemElemOI);
            validateIndex(index);
            maxItemId = Math.max(index, maxItemId);
            posItems.add(index);
        }

        feedback.addFeedback(userId, posItems);
        feedback.setMaxItemId(maxItemId);
    }

    @Override
    public void close() throws HiveException {
        int feedbacks = feedback.getTotalFeedbacks();
        if (feedbacks == 0) {
            return;
        }
        int numSamples = (int) (feedbacks * samplingRate);

        if (pairSampling) {
            PerEventPositiveOnlyFeedback evFeedback = (PerEventPositiveOnlyFeedback) feedback;
            if (withoutReplacement) {
                uniformPairSamplingWithoutReplacement(evFeedback, numSamples);
            } else {
                uniformPairSamplingWithReplacement(evFeedback, numSamples);
            }
        } else {
            if (withoutReplacement) {
                uniformUserSamplingWithoutReplacement(feedback, numSamples);
            } else {
                uniformUserSamplingWithReplacement(feedback, numSamples);
            }
        }
    }

    private void forward(final int user, final int posItem, final int negItem)
            throws HiveException {
        assert (user >= 0) : user;
        assert (posItem >= 0) : posItem;
        assert (negItem >= 0) : negItem;

        userId.set(user);
        posItemId.set(posItem);
        negItemId.set(negItem);
        forward(forwardObjs);
    }

    /**
     * Sampling pairs uniform for each user with replacement. Sample a user. Then, sample a pair.
     */
    private void uniformUserSamplingWithReplacement(@Nonnull final PositiveOnlyFeedback feedback,
            final int numSamples) throws HiveException {
        final int numUsers = feedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        final int numItems = maxItemId + 1;
        final int[] users = feedback.getUsers();
        assert (users.length == numUsers);

        final Random rand = new Random(31L);
        for (int i = 0; i < numSamples; i++) {
            int user = users[rand.nextInt(numUsers)];

            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            int size = posItems.size();
            assert (size > 0) : size;
            if (size == numItems) {// cannot draw a negative item      
                --i;
                continue;
            }

            int posItemIndex = rand.nextInt(size);
            int posItem = posItems.fastGet(posItemIndex);
            int negItem;
            do {
                negItem = rand.nextInt(maxItemId);
            } while (posItems.contains(negItem));

            forward(user, posItem, negItem);
        }
    }

    /**
     * Sampling pairs uniform for each user without replacement. Sample a user. Then, sample a pair.
     * 
     * Caution: This is not a perfect 'without sampling' but it does 'without sampling' for positive
     * feedbacks.
     */
    private void uniformUserSamplingWithoutReplacement(@Nonnull final PositiveOnlyFeedback feedback,
            final int numSamples) throws HiveException {
        int numUsers = feedback.getNumUsers();
        if (numUsers == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }
        final int numItems = maxItemId + 1;
        final BitSet userBits = new BitSet(numUsers);
        feedback.getUsers(userBits);

        final Random rand = new Random(31L);
        for (int i = 0; i < numSamples && numUsers > 0; i++) {
            int nthUser = rand.nextInt(numUsers);
            int user = BitUtils.indexOfSetBit(userBits, nthUser);
            if (user == -1) {
                throw new HiveException(
                    "Cannot find " + nthUser + "-th user among " + numUsers + " users");
            }

            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;
            int size = posItems.size();
            assert (size > 0) : size;
            if (size == numItems) {// cannot draw a negative item                
                --i;
                continue;
            }

            int posItemIndex = rand.nextInt(size);
            int posItem = posItems.fastGet(posItemIndex);
            int negItem;
            do {
                negItem = rand.nextInt(maxItemId);
            } while (posItems.contains(negItem));

            posItems.remove(posItemIndex);
            if (posItems.isEmpty()) {
                feedback.removeFeedback(user);
                userBits.clear(user);
                --numUsers;
            }

            forward(user, posItem, negItem);
        }
    }

    /**
     * Sampling pairs uniform from feedbacks with replacement.
     */
    private void uniformPairSamplingWithReplacement(
            @Nonnull final PerEventPositiveOnlyFeedback feedback, final int numSamples)
            throws HiveException {
        final int numFeedbacks = feedback.getTotalFeedbacks();
        if (numFeedbacks == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }

        final Random rand = new Random(31L);
        for (int i = 0; i < numSamples; i++) {
            int index = rand.nextInt(numFeedbacks);
            int user = feedback.getUser(index);
            int posItem = feedback.getPositiveItem(index);

            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;

            int negItem;
            do {
                negItem = rand.nextInt(maxItemId);
            } while (posItems.contains(negItem));

            forward(user, posItem, negItem);
        }
    }

    /**
     * Sampling pairs uniform from feedbacks without replacement.
     * 
     * Caution: This is not a perfect 'without sampling' but it does 'without sampling' for positive
     * feedbacks.
     */
    private void uniformPairSamplingWithoutReplacement(
            @Nonnull final PerEventPositiveOnlyFeedback feedback, final int numSamples)
            throws HiveException {
        final int numFeedbacks = feedback.getTotalFeedbacks();
        if (numFeedbacks == 0) {
            return;
        }
        final int maxItemId = feedback.getMaxItemId();
        if (maxItemId <= 0) {
            throw new HiveException("Invalid maxItemId: " + maxItemId);
        }

        final Random rand = new Random(31L);
        final int[] perm = feedback.getRandomIndex(rand);
        for (int index : perm) {
            int user = feedback.getUser(index);
            int posItem = feedback.getPositiveItem(index);

            IntArrayList posItems = feedback.getItems(user, true);
            assert (posItems != null) : user;

            int negItem;
            do {
                negItem = rand.nextInt(maxItemId);
            } while (posItems.contains(negItem));

            forward(user, posItem, negItem);
        }
    }

    private static void validateIndex(final int index) throws UDFArgumentException {
        if (index < 0) {
            throw new UDFArgumentException("Negative index is not allowed: " + index);
        }
    }


}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy