All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.factorization.fm.Entry Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.factorization.fm;

import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.math.MathUtils;

import java.util.Arrays;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

class Entry {

    @Nonnull
    protected final HeapBuffer _buf;
    @Nonnegative
    protected final int _size;
    @Nonnegative
    protected final int _factors;

    // temporary variables used only in training phase
    protected int _key;
    @Nonnegative
    protected long _offset;

    Entry(@Nonnull HeapBuffer buf, int factors) {
        this._buf = buf;
        this._size = Entry.sizeOf(factors);
        this._factors = factors;
    }

    Entry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
        this(buf, 1, key, offset);
    }

    Entry(@Nonnull HeapBuffer buf, int factors, int key, @Nonnegative long offset) {
        this(buf, factors, Entry.sizeOf(factors), key, offset);
    }

    private Entry(@Nonnull HeapBuffer buf, int factors, int size, int key,
            @Nonnegative long offset) {
        this._buf = buf;
        this._size = size;
        this._factors = factors;
        this._key = key;
        this._offset = offset;
    }

    final int getSize() {
        return _size;
    }

    final int getKey() {
        return _key;
    }

    final long getOffset() {
        return _offset;
    }

    final void setOffset(final long offset) {
        this._offset = offset;
    }

    final float getW() {
        return _buf.getFloat(_offset);
    }

    final void setW(final float value) {
        _buf.putFloat(_offset, value);
    }

    final void getV(@Nonnull final float[] Vf) {
        final long offset = _offset;
        final int len = Vf.length;
        for (int f = 0; f < len; f++) {
            long index = offset + SizeOf.FLOAT * f;
            Vf[f] = _buf.getFloat(index);
        }
    }

    final void setV(@Nonnull final float[] Vf) {
        final long offset = _offset;
        final int len = Vf.length;
        for (int f = 0; f < len; f++) {
            long index = offset + SizeOf.FLOAT * f;
            _buf.putFloat(index, Vf[f]);
        }
    }

    final float getV(final int f) {
        long index = _offset + SizeOf.FLOAT * f;
        return _buf.getFloat(index);
    }

    final void setV(final int f, final float value) {
        long index = _offset + SizeOf.FLOAT * f;
        _buf.putFloat(index, value);
    }

    double getSumOfSquaredGradients(@Nonnegative int f) {
        throw new UnsupportedOperationException();
    }

    void addGradient(@Nonnegative int f, float grad) {
        throw new UnsupportedOperationException();
    }

    final float updateZ(final float gradW, final float alpha) {
        float w = getW();
        return updateZ(0, w, gradW, alpha);
    }

    float updateZ(@Nonnegative int f, float W, float gradW, float alpha) {
        throw new UnsupportedOperationException();
    }

    final double updateN(final float gradW) {
        return updateN(0, gradW);
    }

    double updateN(@Nonnegative int f, float gradW) {
        throw new UnsupportedOperationException();
    }

    boolean removable() {
        if (!isEntryW(_key)) {// entry for V
            final long offset = _offset;
            for (int f = 0; f < _factors; f++) {
                final float Vf = _buf.getFloat(offset + SizeOf.FLOAT * f);
                if (!MathUtils.closeToZero(Vf, 1E-9f)) {
                    return false;
                }
            }
        }
        return true;
    }

    void clear() {};

    static int sizeOf(@Nonnegative final int factors) {
        Preconditions.checkArgument(factors >= 1, "Factors must be greater than 0: " + factors);
        return SizeOf.FLOAT * factors;
    }

    static boolean isEntryW(final int i) {
        return i < 0;
    }

    @Override
    public String toString() {
        if (Entry.isEntryW(_key)) {
            return "W=" + getW();
        } else {
            float[] Vf = new float[_factors];
            getV(Vf);
            return "V=" + Arrays.toString(Vf);
        }
    }

    static final class AdaGradEntry extends Entry {

        final long _gg_offset;

        AdaGradEntry(@Nonnull HeapBuffer buf, int key, @Nonnegative long offset) {
            this(buf, 1, key, offset);
        }

        AdaGradEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key,
                @Nonnegative long offset) {
            super(buf, factors, AdaGradEntry.sizeOf(factors), key, offset);
            this._gg_offset = _offset + Entry.sizeOf(factors);
        }

        @Override
        double getSumOfSquaredGradients(@Nonnegative final int f) {
            Preconditions.checkArgument(f >= 0);

            long offset = _gg_offset + SizeOf.DOUBLE * f;
            return _buf.getDouble(offset);
        }

        @Override
        void addGradient(@Nonnegative final int f, final float grad) {
            Preconditions.checkArgument(f >= 0);

            long offset = _gg_offset + SizeOf.DOUBLE * f;
            double v = _buf.getDouble(offset);
            v += grad * grad;
            _buf.putDouble(offset, v);
        }

        @Override
        void clear() {
            for (int f = 0; f < _factors; f++) {
                long offset = _gg_offset + SizeOf.DOUBLE * f;
                _buf.putDouble(offset, 0.d);
            }
        }

        static int sizeOf(@Nonnegative final int factors) {
            return Entry.sizeOf(factors) + SizeOf.DOUBLE * factors;
        }

        @Override
        public String toString() {
            final double[] gg = new double[_factors];
            for (int f = 0; f < _factors; f++) {
                gg[f] = getSumOfSquaredGradients(f);
            }
            return super.toString() + ", gg=" + Arrays.toString(gg);
        }

    }

    static final class FTRLEntry extends Entry {

        final long _z_offset;

        FTRLEntry(@Nonnull HeapBuffer buf, int key, long offset) {
            this(buf, 1, key, offset);
        }

        FTRLEntry(@Nonnull HeapBuffer buf, @Nonnegative int factors, int key, long offset) {
            super(buf, factors, FTRLEntry.sizeOf(factors), key, offset);
            this._z_offset = _offset + Entry.sizeOf(factors);
        }

        @Override
        float updateZ(final int f, final float W, final float gradW, final float alpha) {
            Preconditions.checkArgument(f >= 0);

            final long zOffset = offsetZ(f);

            final float z = _buf.getFloat(zOffset);
            final double n = _buf.getFloat(offsetN(f)); // implicit cast to float

            double gg = gradW * gradW;
            float sigma = (float) ((Math.sqrt(n + gg) - Math.sqrt(n)) / alpha);

            final float newZ = z + gradW - sigma * W;
            if (!NumberUtils.isFinite(newZ)) {
                throw new IllegalStateException(
                    "Got newZ " + newZ + " where z=" + z + ", gradW=" + gradW + ", sigma=" + sigma
                            + ", W=" + W + ", n=" + n + ", gg=" + gg + ", alpha=" + alpha);
            }
            _buf.putFloat(zOffset, newZ);
            return newZ;
        }

        @Override
        double updateN(final int f, final float gradW) {
            Preconditions.checkArgument(f >= 0);

            final long nOffset = offsetN(f);

            final double n = _buf.getFloat(nOffset);
            final double newN = n + gradW * gradW;
            if (!NumberUtils.isFinite(newN)) {
                throw new IllegalStateException(
                    "Got newN " + newN + " where n=" + n + ", gradW=" + gradW);
            }
            _buf.putFloat(nOffset, NumberUtils.castToFloat(newN)); // cast may throw ArithmeticException
            return newN;
        }

        private long offsetZ(@Nonnegative final int f) {
            return _z_offset + SizeOf.FLOAT * f;
        }

        private long offsetN(@Nonnegative final int f) {
            return _z_offset + SizeOf.FLOAT * (_factors + f);
        }

        @Override
        void clear() {
            for (int f = 0; f < _factors; f++) {
                _buf.putFloat(offsetZ(f), 0.f);
                _buf.putFloat(offsetN(f), 0.f);
            }
        }

        static int sizeOf(@Nonnegative final int factors) {
            return Entry.sizeOf(factors) + (SizeOf.FLOAT + SizeOf.FLOAT) * factors;
        }

        @Override
        public String toString() {
            final float[] Z = new float[_factors];
            final float[] N = new float[_factors];
            for (int f = 0; f < _factors; f++) {
                Z[f] = _buf.getFloat(offsetZ(f));
                N[f] = _buf.getFloat(offsetN(f));
            }
            return super.toString() + ", Z=" + Arrays.toString(Z) + ", N=" + Arrays.toString(N);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy