All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.factorization.fm.FFMStringFeatureMapModel Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.factorization.fm;

import hivemall.factorization.fm.Entry.AdaGradEntry;
import hivemall.factorization.fm.Entry.FTRLEntry;
import hivemall.factorization.fm.FMHyperParameters.FFMHyperParameters;
import hivemall.utils.buffer.HeapBuffer;
import hivemall.utils.collections.lists.LongArrayList;
import hivemall.utils.lang.NumberUtils;
import it.unimi.dsi.fastutil.ints.Int2LongMap;
import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap;

import java.text.NumberFormat;
import java.util.Locale;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.roaringbitmap.RoaringBitmap;

public final class FFMStringFeatureMapModel extends FieldAwareFactorizationMachineModel {
    private static final int DEFAULT_MAPSIZE = 65536;

    // LEARNING PARAMS
    private float _w0;
    @Nonnull
    final Int2LongMap _map;
    @Nonnull
    final HeapBuffer _buf;

    @Nonnull
    private final LongArrayList _freelistW;
    @Nonnull
    private final LongArrayList _freelistV;

    private boolean _initV;
    @Nonnull
    private RoaringBitmap _removedV;

    // hyperparams
    private final int _numFields;

    private final int _entrySizeW;
    private final int _entrySizeV;

    // statistics
    private long _bytesAllocated, _bytesUsed;
    private int _numAllocatedW, _numReusedW, _numRemovedW;
    private int _numAllocatedV, _numReusedV, _numRemovedV;

    public FFMStringFeatureMapModel(@Nonnull FFMHyperParameters params) {
        super(params);
        this._w0 = 0.f;
        this._map = new Int2LongOpenHashMap(DEFAULT_MAPSIZE);
        _map.defaultReturnValue(-1L);
        this._buf = new HeapBuffer(HeapBuffer.DEFAULT_CHUNK_SIZE);
        this._freelistW = new LongArrayList();
        this._freelistV = new LongArrayList();
        this._initV = true;
        this._removedV = new RoaringBitmap();
        this._numFields = params.numFields;
        this._entrySizeW = entrySize(1, _useFTRL, _useAdaGrad);
        this._entrySizeV = entrySize(_factor, _useFTRL, _useAdaGrad);
    }

    private static int entrySize(@Nonnegative int factors, boolean ftrl, boolean adagrad) {
        if (ftrl) {
            return FTRLEntry.sizeOf(factors);
        } else if (adagrad) {
            return AdaGradEntry.sizeOf(factors);
        } else {
            return Entry.sizeOf(factors);
        }
    }

    void disableInitV() {
        this._initV = false;
    }

    @Override
    public int getSize() {
        return _map.size();
    }

    @Override
    public float getW0() {
        return _w0;
    }

    @Override
    protected void setW0(float nextW0) {
        this._w0 = nextW0;
    }

    @Override
    public float getW(@Nonnull final Feature x) {
        int j = Feature.toIntFeature(x);

        Entry entry = getEntry(j);
        if (entry == null) {
            return 0.f;
        }
        return entry.getW();
    }

    @Override
    protected void setW(@Nonnull final Feature x, final float nextWi) {
        final int j = Feature.toIntFeature(x);

        Entry entry = getEntry(j);
        if (entry == null) {
            entry = newEntry(j, nextWi);
            long ptr = entry.getOffset();
            _map.put(j, ptr);
        } else {
            entry.setW(nextWi);
        }
    }

    /**
     * @return V_x,yField,f
     */
    @Override
    public float getV(@Nonnull final Feature x, @Nonnull final int yField, final int f) {
        final int j = Feature.toIntFeature(x, yField, _numFields);

        Entry entry = getEntry(j);
        if (entry == null) {
            if (_initV == false) {
                return 0.f;
            } else if (_removedV.contains(j)) {
                return 0.f;
            }
            float[] V = initV();
            entry = newEntry(j, V);
            long ptr = entry.getOffset();
            _map.put(j, ptr);
            return V[f];
        }
        return entry.getV(f);
    }

    @Override
    protected void setV(@Nonnull final Feature x, @Nonnull final int yField, final int f,
            final float nextVif) {
        final int j = Feature.toIntFeature(x, yField, _numFields);

        Entry entry = getEntry(j);
        if (entry == null) {
            if (_initV == false) {
                return;
            } else if (_removedV.contains(j)) {
                return;
            }
            float[] V = initV();
            entry = newEntry(j, V);
            long ptr = entry.getOffset();
            _map.put(j, ptr);
        }
        entry.setV(f, nextVif);
    }

    @Override
    protected Entry getEntryW(@Nonnull final Feature x) {
        final int j = Feature.toIntFeature(x);

        Entry entry = getEntry(j);
        if (entry == null) {
            entry = newEntry(j, 0.f);
            long ptr = entry.getOffset();
            _map.put(j, ptr);
        }
        return entry;
    }

    @Override
    protected Entry getEntryV(@Nonnull final Feature x, @Nonnull final int yField) {
        final int j = Feature.toIntFeature(x, yField, _numFields);

        Entry entry = getEntry(j);
        if (entry == null) {
            if (_initV == false) {
                return null;
            } else if (_removedV.contains(j)) {
                return null;
            }
            float[] V = initV();
            entry = newEntry(j, V);
            long ptr = entry.getOffset();
            _map.put(j, ptr);
        }
        return entry;
    }

    @Override
    protected void removeEntry(@Nonnull final Entry entry) {
        final int j = entry.getKey();
        final long ptr = _map.remove(j);
        if (ptr == -1L) {
            return; // should never be happen.
        }
        entry.clear();
        if (Entry.isEntryW(j)) {
            _freelistW.add(ptr);
            this._numRemovedW++;
            this._bytesUsed -= _entrySizeW;
        } else {
            _removedV.add(j);
            _freelistV.add(ptr);
            this._numRemovedV++;
            this._bytesUsed -= _entrySizeV;
        }
    }

    @Nonnull
    protected final Entry newEntry(final int key, final float W) {
        final long ptr;
        if (_freelistW.isEmpty()) {
            ptr = _buf.allocate(_entrySizeW);
            this._numAllocatedW++;
            this._bytesAllocated += _entrySizeW;
            this._bytesUsed += _entrySizeW;
        } else {// reuse removed entry
            ptr = _freelistW.remove();
            this._numReusedW++;
        }
        final Entry entry;
        if (_useFTRL) {
            entry = new FTRLEntry(_buf, key, ptr);
        } else if (_useAdaGrad) {
            entry = new AdaGradEntry(_buf, key, ptr);
        } else {
            entry = new Entry(_buf, key, ptr);
        }

        entry.setW(W);
        return entry;
    }

    @Nonnull
    protected final Entry newEntry(final int key, @Nonnull final float[] V) {
        final long ptr;
        if (_freelistV.isEmpty()) {
            ptr = _buf.allocate(_entrySizeV);
            this._numAllocatedV++;
            this._bytesAllocated += _entrySizeV;
            this._bytesUsed += _entrySizeV;
        } else {// reuse removed entry
            ptr = _freelistV.remove();
            this._numReusedV++;
        }
        final Entry entry;
        if (_useFTRL) {
            entry = new FTRLEntry(_buf, _factor, key, ptr);
        } else if (_useAdaGrad) {
            entry = new AdaGradEntry(_buf, _factor, key, ptr);
        } else {
            entry = new Entry(_buf, _factor, key, ptr);
        }

        entry.setV(V);
        return entry;
    }

    @Nullable
    private Entry getEntry(final int key) {
        final long ptr = _map.get(key);
        if (ptr == -1L) {
            return null;
        }
        return getEntry(key, ptr);
    }

    @Nonnull
    private Entry getEntry(final int key, @Nonnegative final long ptr) {
        if (Entry.isEntryW(key)) {
            if (_useFTRL) {
                return new FTRLEntry(_buf, key, ptr);
            } else if (_useAdaGrad) {
                return new AdaGradEntry(_buf, key, ptr);
            } else {
                return new Entry(_buf, key, ptr);
            }
        } else {
            if (_useFTRL) {
                return new FTRLEntry(_buf, _factor, key, ptr);
            } else if (_useAdaGrad) {
                return new AdaGradEntry(_buf, _factor, key, ptr);
            } else {
                return new Entry(_buf, _factor, key, ptr);
            }
        }
    }

    @Nonnull
    String getStatistics() {
        final NumberFormat fmt = NumberFormat.getIntegerInstance(Locale.US);
        return "FFMStringFeatureMapModel [bytesAllocated=" + NumberUtils.prettySize(_bytesAllocated)
                + ", bytesUsed=" + NumberUtils.prettySize(_bytesUsed) + ", numAllocatedW="
                + fmt.format(_numAllocatedW) + ", numReusedW=" + fmt.format(_numReusedW)
                + ", numRemovedW=" + fmt.format(_numRemovedW) + ", numAllocatedV="
                + fmt.format(_numAllocatedV) + ", numReusedV=" + fmt.format(_numReusedV)
                + ", numRemovedV=" + fmt.format(_numRemovedV) + "]";
    }

    @Override
    public String toString() {
        return getStatistics();
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy