Alien-XGBoost/xgboost/dmlc-core/include/dmlc/parameter.h
/*!
* Copyright (c) 2015 by Contributors
* \file parameter.h
* \brief Provide lightweight util to do parameter setup and checking.
*/
#ifndef DMLC_PARAMETER_H_
#define DMLC_PARAMETER_H_
#include <cctype>
#include <cstddef>
#include <cstdlib>
#include <sstream>
#include <stdexcept>
#include <limits>
#include <map>
#include <set>
#include <typeinfo>
#include <string>
#include <vector>
#include <algorithm>
#include <utility>
#include <iostream>
#include "./base.h"
#include "./json.h"
#include "./logging.h"
#include "./type_traits.h"
#include "./optional.h"
namespace dmlc {
// this file is backward compatible with non-c++11
/*! \brief Error throwed by parameter checking */
struct ParamError : public dmlc::Error {
/*!
* \brief constructor
* \param msg error message
*/
explicit ParamError(const std::string &msg)
: dmlc::Error(msg) {}
};
/*!
* \brief Get environment variable with default.
* \param key the name of the environment variable.
* \param default_value the value returned when the environment variable is not set.
* \tparam ValueType the type the text of the variable is parsed into.
* \return the parsed value of the variable, or default_value when it is not set.
*/
template<typename ValueType>
inline ValueType GetEnv(const char *key,
ValueType default_value);
/*! \brief internal namespace for parameter management */
namespace parameter {
// forward declare ParamManager
class ParamManager;
// forward declare FieldAccessEntry
class FieldAccessEntry;
// forward declare FieldEntry
template<typename DType>
class FieldEntry;
// forward declare ParamManagerSingleton
template<typename PType>
struct ParamManagerSingleton;
/*! \brief policy applied to keyword arguments that match no declared field */
enum ParamInitOption {
  /*! \brief keys without a matching field are tolerated */
  kAllowUnknown = 0,
  /*! \brief every key has to correspond to a declared field */
  kAllMatch = 1,
  /*! \brief unmatched keys are tolerated only when written as __*__ */
  kAllowHidden = 2
};
} // namespace parameter
/*!
 * \brief Textual description of a single parameter field.
 */
struct ParamFieldInfo {
  /*! \brief key under which the field is registered */
  std::string name;
  /*! \brief name of the field's type */
  std::string type;
  /*!
   * \brief extended type description.
   *  Besides the type name this carries the default value (or the
   *  "required" marker) and, for enums, the accepted names.
   */
  std::string type_info_str;
  /*! \brief free-form description of the field */
  std::string description;
};
/*!
* \brief Parameter is the base type every parameter struct should inherit from.
* The following code is a complete example to setup parameters.
* \code
* struct Param : public dmlc::Parameter<Param> {
* float learning_rate;
* int num_hidden;
* std::string name;
* // declare parameters in header file
* DMLC_DECLARE_PARAMETER(Param) {
* DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000);
* DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f);
* DMLC_DECLARE_FIELD(name).set_default("hello");
* }
* };
* // register it in cc file
* DMLC_REGISTER_PARAMETER(Param);
* \endcode
*
* After that, the Param struct will get all the functions defined in Parameter.
* \tparam PType the type of parameter struct
*
* \sa DMLC_DECLARE_FIELD, DMLC_REGISTER_PARAMETER, DMLC_DECLARE_PARAMETER
*/
template<typename PType>
struct Parameter {
public:
/*!
* \brief initialize the parameter by keyword arguments.
* This function will initialize the parameter struct, check consistency
* and throw error if something wrong happens.
*
* \param kwargs map of keyword arguments, or vector of pairs
* \parma option The option on initialization.
* \tparam Container container type
* \throw ParamError when something go wrong.
*/
template<typename Container>
inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowHidden) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
NULL,
option);
}
/*!
* \brief initialize the parameter by keyword arguments.
* This is same as Init, but allow unknown arguments.
*
* \param kwargs map of keyword arguments, or vector of pairs
* \tparam Container container type
* \throw ParamError when something go wrong.
* \return vector of pairs of unknown arguments.
*/
template<typename Container>
inline std::vector<std::pair<std::string, std::string> >
InitAllowUnknown(const Container &kwargs) {
std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
&unknown, parameter::kAllowUnknown);
return unknown;
}
/*!
* \brief Return a dictionary representation of the parameters
* \return A dictionary that maps key -> value
*/
inline std::map<std::string, std::string> __DICT__() const {
std::vector<std::pair<std::string, std::string> > vec
= PType::__MANAGER__()->GetDict(this->head());
return std::map<std::string, std::string>(vec.begin(), vec.end());
}
/*!
* \brief Write the parameters in JSON format.
* \param writer JSONWriter used for writing.
*/
inline void Save(dmlc::JSONWriter *writer) const {
writer->Write(this->__DICT__());
}
/*!
* \brief Load the parameters from JSON.
* \param reader JSONReader used for loading.
* \throw ParamError when something go wrong.
*/
inline void Load(dmlc::JSONReader *reader) {
std::map<std::string, std::string> kwargs;
reader->Read(&kwargs);
this->Init(kwargs);
}
/*!
* \brief Get the fields of the parameters.
* \return List of ParamFieldInfo of each field.
*/
inline static std::vector<ParamFieldInfo> __FIELDS__() {
return PType::__MANAGER__()->GetFieldInfo();
}
/*!
* \brief Print docstring of the parameter
* \return the printed docstring
*/
inline static std::string __DOC__() {
std::ostringstream os;
PType::__MANAGER__()->PrintDocString(os);
return os.str();
}
protected:
/*!
* \brief internal function to allow declare of a parameter memember
* \param manager the parameter manager
* \param key the key name of the parameter
* \param ref the reference to the parameter in the struct.
*/
template<typename DType>
inline parameter::FieldEntry<DType>& DECLARE(
parameter::ParamManagerSingleton<PType> *manager,
const std::string &key, DType &ref) { // NOLINT(*)
parameter::FieldEntry<DType> *e =
new parameter::FieldEntry<DType>();
e->Init(key, this->head(), ref);
manager->manager.AddEntry(key, e);
return *e;
}
private:
/*! \return Get head pointer of child structure */
inline PType *head() const {
return static_cast<PType*>(const_cast<Parameter<PType>*>(this));
}
};
//! \cond Doxygen_Suppress
/*!
* \brief macro used to declare the fields of a parameter struct
*
* Example:
* \code
* struct Param : public dmlc::Parameter<Param> {
* // declare parameters in header file
* DMLC_DECLARE_PARAMETER(Param) {
* // details of declarations
* }
* };
* \endcode
*
* This macro is placed inside the body of the parameter struct and must be
* followed by a brace-enclosed block of field declarations.
* It expands to the declaration of the static __MANAGER__() accessor and to the
* header of the inline __DECLARE__ member function; the block that follows
* becomes the body of __DECLARE__.
* Refer to example code in Parameter for details
*
* \param PType the name of parameter struct.
* \sa Parameter
*/
#define DMLC_DECLARE_PARAMETER(PType) \
static ::dmlc::parameter::ParamManager *__MANAGER__(); \
inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \
/*!
* \brief macro to declare a field
* Expands to a call of this->DECLARE with the stringified field name as key.
* It refers to `this` and to a variable named `manager`, so it can only be
* used inside the block that follows DMLC_DECLARE_PARAMETER.
* The expression yields the created field entry, so configuration calls can be
* chained onto it.
* \param FieldName the name of the field.
*/
#define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName)
/*!
* \brief macro to declare an alias of a field
* Expands to a call of AddAlias on the manager with both names stringified.
* It refers to a variable named `manager`, so it can only be used inside the
* block that follows DMLC_DECLARE_PARAMETER.
* \param FieldName the name of the field.
* \param AliasName the name of the alias, must be declared after the field is declared.
*/
#define DMLC_DECLARE_ALIAS(FieldName, AliasName) manager->manager.AddAlias(#FieldName, #AliasName)
/*!
* \brief Macro used to register parameter.
*
* This macro needs to be put in a source file so that registration only happens once.
* It defines PType::__MANAGER__(), which keeps a function-local static
* ParamManagerSingleton<PType> and returns a pointer to its manager.
* It also defines a static reference variable that is initialized by calling
* PType::__MANAGER__().
* Refer to example code in Parameter for details
* \param PType the type of parameter struct.
* \sa Parameter
*/
#define DMLC_REGISTER_PARAMETER(PType) \
::dmlc::parameter::ParamManager *PType::__MANAGER__() { \
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \
//! \endcond
/*!
* \brief internal namespace for parameter management
* There is no need to use it directly in normal case
*/
namespace parameter {
/*!
* \brief FieldAccessEntry interface to help manage the parameters
* Each entry can be used to access one parameter in the Parameter struct.
*
* This is an internal interface that is used to manage parameters
*/
class FieldAccessEntry {
 public:
  /*!
   * \brief constructor, starts without a default value.
   *  index_ is zero-initialized so that the entry never carries an
   *  indeterminate value before ParamManager::AddEntry assigns the real one.
   */
  FieldAccessEntry()
      : has_default_(false), index_(0) {}
  /*! \brief destructor */
  virtual ~FieldAccessEntry() {}
  /*!
   * \brief set the default value.
   * \param head the pointer to the head of the struct
   * \throw error if no default is presented
   */
  virtual void SetDefault(void *head) const = 0;
  /*!
   * \brief set the parameter by string value
   * \param head the pointer to the head of the struct
   * \param value the value to be set
   */
  virtual void Set(void *head, const std::string &value) const = 0;
  /*!
   * \brief check whether the value currently stored in the field is acceptable.
   *  The base implementation accepts everything; subclasses override it to
   *  add their own checks.
   * \param head the pointer to the head of the struct
   */
  virtual void Check(void *head) const {}
  /*!
   * \brief get the string representation of value.
   * \param head the pointer to the head of the struct
   */
  virtual std::string GetStringValue(void *head) const = 0;
  /*!
   * \brief Get field information
   * \return the corresponding field information
   */
  virtual ParamFieldInfo GetFieldInfo() const = 0;

 protected:
  /*! \brief whether this parameter have default value */
  bool has_default_;
  /*! \brief positional index of parameter in struct, assigned by ParamManager */
  size_t index_;
  /*! \brief parameter key name */
  std::string key_;
  /*! \brief parameter type */
  std::string type_;
  /*! \brief description of the parameter */
  std::string description_;
  /*!
   * \brief print string representation of default value
   * \param os the stream to print the default value to.
   */
  virtual void PrintDefaultValueString(std::ostream &os) const = 0;  // NOLINT(*)
  // allow ParamManager to modify self
  friend class ParamManager;
};
/*!
* \brief manager class to handle parameter structure for each type
* An manager will be created for each parameter structure.
*/
class ParamManager {
 public:
  /*!
   * \brief destructor, deletes every entry handed over through AddEntry.
   * NOTE(review): the class owns raw pointers but declares no copy
   * constructor or assignment operator, so a copied instance would delete
   * the same entries a second time; instances should not be copied.
   */
  ~ParamManager() {
    for (size_t i = 0; i < entry_.size(); ++i) {
      delete entry_[i];
    }
  }
  /*!
   * \brief find the access entry by parameter key
   * \param key the key (or alias) of the parameter.
   * \return pointer to FieldAccessEntry, NULL if nothing is found.
   */
  inline FieldAccessEntry *Find(const std::string &key) const {
    std::map<std::string, FieldAccessEntry*>::const_iterator it =
        entry_map_.find(key);
    if (it == entry_map_.end()) return NULL;
    return it->second;
  }
  /*!
   * \brief set parameter by keyword arguments.
   *  Every argument whose key is registered is parsed with Set and validated
   *  with Check; afterwards every registered field that was not supplied is
   *  reset through SetDefault.
   * \param head head to the parameter field.
   * \param begin begin iterator of original kwargs
   * \param end end iterator of original kwargs
   * \param unknown_args optional, used to hold unknown arguments
   *          When it is specified, unknown arguments will be stored into here,
   *          instead of raise an error; option is not consulted in that case.
   * \param option how unknown arguments are treated when unknown_args == NULL
   * \tparam RandomAccessIterator iterator type; elements expose first/second
   * \throw ParamError when there is unknown argument and unknown_args == NULL
   *          and option does not allow it, or required argument is missing.
   */
  template<typename RandomAccessIterator>
  inline void RunInit(void *head,
                      RandomAccessIterator begin,
                      RandomAccessIterator end,
                      std::vector<std::pair<std::string, std::string> > *unknown_args,
                      parameter::ParamInitOption option) const {
    // entries that received an explicit value; the rest get their default below
    std::set<FieldAccessEntry*> selected_args;
    for (RandomAccessIterator it = begin; it != end; ++it) {
      FieldAccessEntry *e = Find(it->first);
      if (e != NULL) {
        e->Set(head, it->second);
        e->Check(head);
        selected_args.insert(e);
      } else {
        if (unknown_args != NULL) {
          unknown_args->push_back(*it);
        } else {
          if (option != parameter::kAllowUnknown) {
            // hidden keys have the form __name__ with at least one inner character
            if (option == parameter::kAllowHidden &&
                it->first.length() > 4 &&
                it->first.find("__") == 0 &&
                it->first.rfind("__") == it->first.length()-2) {
              continue;
            }
            std::ostringstream os;
            os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
            os << "----------------\n";
            PrintDocString(os);
            throw dmlc::ParamError(os.str());
          }
        }
      }
    }
    // aliases map to the same entry, so selected_args is keyed by entry pointer
    for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
         it != entry_map_.end(); ++it) {
      if (selected_args.count(it->second) == 0) {
        it->second->SetDefault(head);
      }
    }
  }
  /*!
   * \brief internal function to add entry to manager,
   *  The manager will take ownership of the entry.
   * \param key the key to the parameters
   * \param e the pointer to the new entry.
   */
  inline void AddEntry(const std::string &key, FieldAccessEntry *e) {
    // TODO(bing) better error message
    if (entry_map_.count(key) != 0) {
      // Ownership was handed over, but the entry will never be stored:
      // release it here so that the failure path does not leak it.
      delete e;
      LOG(FATAL) << "key " << key << " has already been registered in " << name_;
      return;  // not expected to be reached; guards against using the freed entry
    }
    e->index_ = entry_.size();
    entry_.push_back(e);
    entry_map_[key] = e;
  }
  /*!
   * \brief internal function to register an additional key for an existing field.
   *  The alias resolves to the same entry as the field.
   * \param field the key of an already registered field
   * \param alias the additional key, must not be registered yet
   */
  inline void AddAlias(const std::string& field, const std::string& alias) {
    if (entry_map_.count(field) == 0) {
      LOG(FATAL) << "key " << field << " has not been registered in " << name_;
    }
    if (entry_map_.count(alias) != 0) {
      LOG(FATAL) << "Alias " << alias << " has already been registered in " << name_;
    }
    entry_map_[alias] = entry_map_[field];
  }
  /*!
   * \brief set the name of parameter manager
   * \param name the name to set
   */
  inline void set_name(const std::string &name) {
    name_ = name;
  }
  /*!
   * \brief get field information of each field.
   * \return field information, in declaration order
   */
  inline std::vector<ParamFieldInfo> GetFieldInfo() const {
    std::vector<ParamFieldInfo> ret(entry_.size());
    for (size_t i = 0; i < entry_.size(); ++i) {
      ret[i] = entry_[i]->GetFieldInfo();
    }
    return ret;
  }
  /*!
   * \brief Print readable docstring to ostream, add newline.
   * \param os the stream to print the docstring to.
   */
  inline void PrintDocString(std::ostream &os) const {  // NOLINT(*)
    for (size_t i = 0; i < entry_.size(); ++i) {
      ParamFieldInfo info = entry_[i]->GetFieldInfo();
      os << info.name << " : " << info.type_info_str << '\n';
      if (info.description.length() != 0) {
        os << "    " << info.description << '\n';
      }
    }
  }
  /*!
   * \brief Get internal parameters in vector of pairs.
   *  One pair is produced per registered key, aliases included.
   * \param head the head of the struct.
   * \return the parameter dictionary.
   */
  inline std::vector<std::pair<std::string, std::string> > GetDict(void * head) const {
    std::vector<std::pair<std::string, std::string> > ret;
    for (std::map<std::string, FieldAccessEntry*>::const_iterator
            it = entry_map_.begin(); it != entry_map_.end(); ++it) {
      ret.push_back(std::make_pair(it->first, it->second->GetStringValue(head)));
    }
    return ret;
  }

 private:
  /*! \brief parameter struct name */
  std::string name_;
  /*! \brief positional list of entries */
  std::vector<FieldAccessEntry*> entry_;
  /*! \brief map from key to entry */
  std::map<std::string, FieldAccessEntry*> entry_map_;
};
//! \cond Doxygen_Suppress
// The following piece of code will be template heavy and less documented
// singleton parameter manager for certain type, used for initialization
template<typename PType>
struct ParamManagerSingleton {
  /*! \brief manager holding every field declared by PType */
  ParamManager manager;
  /*!
   * \brief build the manager by running the field declarations of PType once
   * \param param_name name of the parameter struct, used in error messages
   */
  explicit ParamManagerSingleton(const std::string &param_name) {
    // Name the manager before declaring the fields, so that errors reported
    // during declaration (e.g. a duplicated key) mention the struct name.
    manager.set_name(param_name);
    PType param;
    param.__DECLARE__(this);
  }
};
// Base class of FieldEntry
// implement set_default
template<typename TEntry, typename DType>
class FieldEntryBase : public FieldAccessEntry {
 public:
  // the concrete entry type, returned by the chaining setters
  typedef TEntry EntryType;
  // parse the value with operator>> and store it into the field;
  // only whitespace may follow the parsed value
  virtual void Set(void *head, const std::string &value) const {
    std::istringstream parser(value);
    parser >> this->Get(head);
    if (!parser.fail()) {
      while (!parser.eof()) {
        const int next = parser.get();
        if (next == EOF) {
          // reaching the end is fine: reset the flags raised by get()
          parser.clear();
          break;
        }
        if (!isspace(next)) {
          // leftover non-space character: treat the input as malformed
          parser.setstate(std::ios::failbit);
          break;
        }
      }
    }
    if (!parser.fail()) return;
    std::ostringstream msg;
    msg << "Invalid Parameter format for " << key_
        << " expect " << type_ << " but value=\'" << value<< '\'';
    throw dmlc::ParamError(msg.str());
  }
  // string form of the value currently stored in the field
  virtual std::string GetStringValue(void *head) const {
    std::ostringstream text;
    PrintValue(text, this->Get(head));
    return text.str();
  }
  // describe the field: name, type, default/required marker, description
  virtual ParamFieldInfo GetFieldInfo() const {
    std::ostringstream type_info;
    type_info << type_;
    if (has_default_) {
      type_info << ',' << " optional, default=";
      PrintDefaultValueString(type_info);
    } else {
      type_info << ", required";
    }
    ParamFieldInfo info;
    info.name = key_;
    info.type = type_;
    info.type_info_str = type_info.str();
    info.description = description_;
    return info;
  }
  // reset the field to its default, fails when the field is required
  virtual void SetDefault(void *head) const {
    if (has_default_) {
      this->Get(head) = default_value_;
      return;
    }
    std::ostringstream msg;
    msg << "Required parameter " << key_
        << " of " << type_ << " is not presented";
    throw dmlc::ParamError(msg.str());
  }
  // view this object as the derived entry type
  inline TEntry &self() {
    TEntry *derived = static_cast<TEntry*>(this);
    return *derived;
  }
  // record the default value; returns self to allow chaining
  inline TEntry &set_default(const DType &default_value) {
    default_value_ = default_value;
    has_default_ = true;
    return this->self();
  }
  // record the description; returns self to allow chaining
  inline TEntry &describe(const std::string &description) {
    description_ = description;
    return this->self();
  }
  // bind the entry to a key and to the location of the field in the struct
  inline void Init(const std::string &key,
                   void *head, DType &ref) { // NOLINT(*)
    this->key_ = key;
    if (this->type_.length() == 0) {
      this->type_ = dmlc::type_name<DType>();
    }
    // the field is remembered as a byte offset from the head of the struct
    this->offset_ =
        reinterpret_cast<char*>(&ref) - static_cast<char*>(head);  // NOLINT(*)
  }

 protected:
  // write a value to the stream
  virtual void PrintValue(std::ostream &os, DType value) const { // NOLINT(*)
    os << value;
  }
  // write the default value to the stream
  virtual void PrintDefaultValueString(std::ostream &os) const {  // NOLINT(*)
    PrintValue(os, default_value_);
  }
  // get the internal representation of parameter
  // for example if this entry corresponds field param.learning_rate
  // then Get(&param) will return reference to param.learning_rate
  inline DType &Get(void *head) const {
    char *field_addr = static_cast<char*>(head) + offset_;
    return *reinterpret_cast<DType*>(field_addr);  // NOLINT(*)
  }
  // byte offset of the field from the head of the struct
  ptrdiff_t offset_;
  // default value of field
  DType default_value_;
};
// parameter base for numeric types that have range
template<typename TEntry, typename DType>
class FieldEntryNumeric
    : public FieldEntryBase<TEntry, DType> {
 public:
  // starts without any bound
  FieldEntryNumeric()
      : has_begin_(false), has_end_(false) {}
  // restrict the value to the closed interval [begin, end]
  virtual TEntry &set_range(DType begin, DType end) {
    begin_ = begin;
    has_begin_ = true;
    end_ = end;
    has_end_ = true;
    return this->self();
  }
  // require the value to be at least begin
  virtual TEntry &set_lower_bound(DType begin) {
    begin_ = begin;
    has_begin_ = true;
    return this->self();
  }
  // verify that the stored value respects the configured bounds
  virtual void Check(void *head) const {
    FieldEntryBase<TEntry, DType>::Check(head);
    DType v = this->Get(head);
    const bool bounded_both_sides = has_begin_ && has_end_;
    if (bounded_both_sides) {
      if (v < begin_ || v > end_) {
        std::ostringstream msg;
        msg << "value " << v << " for Parameter " << this->key_
            << " exceed bound [" << begin_ << ',' << end_ <<']';
        throw dmlc::ParamError(msg.str());
      }
      return;
    }
    if (has_begin_ && v < begin_) {
      std::ostringstream msg;
      msg << "value " << v << " for Parameter " << this->key_
          << " should be greater equal to " << begin_;
      throw dmlc::ParamError(msg.str());
    }
    if (has_end_ && v > end_) {
      std::ostringstream msg;
      msg << "value " << v << " for Parameter " << this->key_
          << " should be smaller equal to " << end_;
      throw dmlc::ParamError(msg.str());
    }
  }

 protected:
  // whether a lower / an upper bound has been configured
  bool has_begin_, has_end_;
  // the bounds themselves, meaningful only when the matching flag is set
  DType begin_, end_;
};
/*!
 * \brief FieldEntry defines how a field of type DType is parsed and checked.
 *  The generic version derives from FieldEntryNumeric when DType is
 *  arithmetic and from FieldEntryBase otherwise; it can be specialized to
 *  give a type its own behavior.
 * \tparam DType the data type of the entry.
 */
template<typename DType>
class FieldEntry
    : public IfThenElseType<
          dmlc::is_arithmetic<DType>::value,
          FieldEntryNumeric<FieldEntry<DType>, DType>,
          FieldEntryBase<FieldEntry<DType>, DType> >::Type {
};
// specialize define for int(enum)
template<>
class FieldEntry<int>
    : public FieldEntryNumeric<FieldEntry<int>, int> {
 public:
  // construct
  // The constructor is named with the injected class name: spelling it
  // FieldEntry<int>() is rejected by compilers in C++20 mode.
  FieldEntry() : is_enum_(false) {}
  // parent
  typedef FieldEntryNumeric<FieldEntry<int>, int> Parent;
  // override set: when enums were declared, the value must be one of the
  // declared names and is translated to its integer before being stored
  virtual void Set(void *head, const std::string &value) const {
    if (is_enum_) {
      std::map<std::string, int>::const_iterator it = enum_map_.find(value);
      std::ostringstream os;
      if (it == enum_map_.end()) {
        os << "Invalid Input: \'" << value;
        os << "\', valid values are: ";
        PrintEnums(os);
        throw dmlc::ParamError(os.str());
      } else {
        os << it->second;
        Parent::Set(head, os.str());
      }
    } else {
      Parent::Set(head, value);
    }
  }
  // field description; for enums the type info lists the accepted names
  virtual ParamFieldInfo GetFieldInfo() const {
    if (is_enum_) {
      ParamFieldInfo info;
      std::ostringstream os;
      info.name = key_;
      info.type = type_;
      PrintEnums(os);
      if (has_default_) {
        os << ',' << "optional, default=";
        PrintDefaultValueString(os);
      } else {
        os << ", required";
      }
      info.type_info_str = os.str();
      info.description = description_;
      return info;
    } else {
      return Parent::GetFieldInfo();
    }
  }
  // add enum: register a name for an integer value and switch the entry to
  // enum mode; throws ParamError when the name or the value is already used
  inline FieldEntry<int> &add_enum(const std::string &key, int value) {
    if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
        enum_back_map_.count(value) != 0) {
      std::ostringstream os;
      os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
      os << "Enums: ";
      for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
           it != enum_map_.end(); ++it) {
        os << "(" << it->first << ": " << it->second << "), ";
      }
      throw dmlc::ParamError(os.str());
    }
    enum_map_[key] = value;
    enum_back_map_[value] = key;
    is_enum_ = true;
    return this->self();
  }

 protected:
  // enum flag
  bool is_enum_;
  // enum map: name -> value
  std::map<std::string, int> enum_map_;
  // reverse enum map: value -> name
  std::map<int, std::string> enum_back_map_;
  // override print behavior: the default is shown in single quotes
  virtual void PrintDefaultValueString(std::ostream &os) const {  // NOLINT(*)
    os << '\'';
    PrintValue(os, default_value_);
    os << '\'';
  }
  // override print default: enums are printed by name
  virtual void PrintValue(std::ostream &os, int value) const {  // NOLINT(*)
    if (is_enum_) {
      CHECK_NE(enum_back_map_.count(value), 0U)
          << "Value not found in enum declared";
      // find() instead of at(): this header is meant to build without C++11
      os << enum_back_map_.find(value)->second;
    } else {
      os << value;
    }
  }

 private:
  // print the accepted enum names as {'a', 'b', ...}
  inline void PrintEnums(std::ostream &os) const {  // NOLINT(*)
    os << '{';
    for (std::map<std::string, int>::const_iterator
             it = enum_map_.begin(); it != enum_map_.end(); ++it) {
      if (it != enum_map_.begin()) {
        os << ", ";
      }
      os << "\'" << it->first << '\'';
    }
    os << '}';
  }
};
// specialize define for optional<int>(enum)
template<>
class FieldEntry<optional<int> >
    : public FieldEntryBase<FieldEntry<optional<int> >, optional<int> > {
 public:
  // construct
  // The constructor is named with the injected class name: spelling it with
  // the template arguments is rejected by compilers in C++20 mode.
  FieldEntry() : is_enum_(false) {}
  // parent
  typedef FieldEntryBase<FieldEntry<optional<int> >, optional<int> > Parent;
  // override set: in enum mode every value except "None" must be a declared
  // name and is translated to its integer; "None" and non-enum values are
  // handed to the generic parser unchanged
  virtual void Set(void *head, const std::string &value) const {
    if (is_enum_ && value != "None") {
      std::map<std::string, int>::const_iterator it = enum_map_.find(value);
      std::ostringstream os;
      if (it == enum_map_.end()) {
        os << "Invalid Input: \'" << value;
        os << "\', valid values are: ";
        PrintEnums(os);
        throw dmlc::ParamError(os.str());
      } else {
        os << it->second;
        Parent::Set(head, os.str());
      }
    } else {
      Parent::Set(head, value);
    }
  }
  // field description; for enums the type info lists the accepted names
  virtual ParamFieldInfo GetFieldInfo() const {
    if (is_enum_) {
      ParamFieldInfo info;
      std::ostringstream os;
      info.name = key_;
      info.type = type_;
      PrintEnums(os);
      if (has_default_) {
        os << ',' << "optional, default=";
        PrintDefaultValueString(os);
      } else {
        os << ", required";
      }
      info.type_info_str = os.str();
      info.description = description_;
      return info;
    } else {
      return Parent::GetFieldInfo();
    }
  }
  // add enum: register a name for an integer value and switch the entry to
  // enum mode; throws ParamError when the name or the value is already used
  inline FieldEntry<optional<int> > &add_enum(const std::string &key, int value) {
    CHECK_NE(key, "None") << "None is reserved for empty optional<int>";
    if ((enum_map_.size() != 0 && enum_map_.count(key) != 0) || \
        enum_back_map_.count(value) != 0) {
      std::ostringstream os;
      os << "Enum " << "(" << key << ": " << value << " exisit!" << ")\n";
      os << "Enums: ";
      for (std::map<std::string, int>::const_iterator it = enum_map_.begin();
           it != enum_map_.end(); ++it) {
        os << "(" << it->first << ": " << it->second << "), ";
      }
      throw dmlc::ParamError(os.str());
    }
    enum_map_[key] = value;
    enum_back_map_[value] = key;
    is_enum_ = true;
    return this->self();
  }

 protected:
  // enum flag
  bool is_enum_;
  // enum map: name -> value
  std::map<std::string, int> enum_map_;
  // reverse enum map: value -> name
  std::map<int, std::string> enum_back_map_;
  // override print behavior: the default is shown in single quotes
  virtual void PrintDefaultValueString(std::ostream &os) const {  // NOLINT(*)
    os << '\'';
    PrintValue(os, default_value_);
    os << '\'';
  }
  // override print default: enums are printed by name, empty as "None"
  virtual void PrintValue(std::ostream &os, optional<int> value) const {  // NOLINT(*)
    if (is_enum_) {
      if (!value) {
        os << "None";
      } else {
        CHECK_NE(enum_back_map_.count(value.value()), 0U)
            << "Value not found in enum declared";
        // find() instead of at(): this header is meant to build without C++11
        os << enum_back_map_.find(value.value())->second;
      }
    } else {
      os << value;
    }
  }

 private:
  // print the accepted values as {None, 'a', 'b', ...}
  inline void PrintEnums(std::ostream &os) const {  // NOLINT(*)
    os << "{None";
    for (std::map<std::string, int>::const_iterator
             it = enum_map_.begin(); it != enum_map_.end(); ++it) {
      os << ", ";
      os << "\'" << it->first << '\'';
    }
    os << '}';
  }
};
// specialize define for string
template<>
class FieldEntry<std::string>
    : public FieldEntryBase<FieldEntry<std::string>, std::string> {
 public:
  // shorthand for the base class
  typedef FieldEntryBase<FieldEntry<std::string>, std::string> Parent;
  // strings are stored verbatim, without going through stream parsing
  virtual void Set(void *head, const std::string &value) const {
    std::string &field = this->Get(head);
    field = value;
  }
  // the default value is shown wrapped in single quotes
  virtual void PrintDefaultValueString(std::ostream &os) const {  // NOLINT(*)
    os << '\'';
    os << default_value_;
    os << '\'';
  }
};
// specialize define for bool
template<>
class FieldEntry<bool>
    : public FieldEntryBase<FieldEntry<bool>, bool> {
 public:
  // parent class
  typedef FieldEntryBase<FieldEntry<bool>, bool> Parent;
  // override set: accepts "true"/"false" in any letter case, "1" and "0";
  // anything else raises ParamError
  virtual void Set(void *head, const std::string &value) const {
    std::string lower_case(value);
    for (std::string::size_type i = 0; i < lower_case.length(); ++i) {
      // tolower is only defined for values representable as unsigned char
      // (or EOF); a plain char may be negative, so convert first.
      lower_case[i] = static_cast<char>(
          std::tolower(static_cast<unsigned char>(lower_case[i])));
    }
    bool &ref = this->Get(head);
    if (lower_case == "true") {
      ref = true;
    } else if (lower_case == "false") {
      ref = false;
    } else if (lower_case == "1") {
      ref = true;
    } else if (lower_case == "0") {
      ref = false;
    } else {
      std::ostringstream os;
      os << "Invalid Parameter format for " << key_
         << " expect " << type_ << " but value=\'" << value<< '\'';
      throw dmlc::ParamError(os.str());
    }
  }

 protected:
  // print default string: booleans are shown as True / False
  virtual void PrintValue(std::ostream &os, bool value) const {  // NOLINT(*)
    if (value) {
      os << "True";
    } else {
      os << "False";
    }
  }
};
// specialize define for float. Uses stof for platform independent handling of
// INF, -INF, NAN, etc.
#if DMLC_USE_CXX11
template <>
class FieldEntry<float> : public FieldEntryNumeric<FieldEntry<float>, float> {
 public:
  // parent
  typedef FieldEntryNumeric<FieldEntry<float>, float> Parent;
  // override set: parse with std::stof.
  // Every parse failure is reported as ParamError, and the field is only
  // written once the whole input has been validated. As in the generic
  // FieldEntryBase::Set, trailing whitespace is accepted while any other
  // trailing character makes the value invalid.
  virtual void Set(void *head, const std::string &value) const {
    std::size_t pos = 0;  // number of characters consumed by std::stof
    float parsed = 0.0f;
    try {
      parsed = std::stof(value, &pos);
    } catch (const std::invalid_argument &) {
      std::ostringstream os;
      os << "Invalid Parameter format for " << key_ << " expect " << type_
         << " but value=\'" << value << '\'';
      throw dmlc::ParamError(os.str());
    } catch (const std::out_of_range &) {
      std::ostringstream os;
      os << "Out of range value for " << key_ << " expect " << type_
         << " but value=\'" << value << '\'';
      throw dmlc::ParamError(os.str());
    }
    for (; pos < value.length(); ++pos) {
      if (!std::isspace(static_cast<unsigned char>(value[pos]))) {
        std::ostringstream os;
        os << "Invalid Parameter format for " << key_ << " expect " << type_
           << " but value=\'" << value << '\'';
        throw dmlc::ParamError(os.str());
      }
    }
    this->Get(head) = parsed;
  }
};
// specialize define for double. Uses stod for platform independent handling of
// INF, -INF, NAN, etc.
template <>
class FieldEntry<double>
    : public FieldEntryNumeric<FieldEntry<double>, double> {
 public:
  // parent
  typedef FieldEntryNumeric<FieldEntry<double>, double> Parent;
  // override set: parse with std::stod.
  // Every parse failure is reported as ParamError, and the field is only
  // written once the whole input has been validated. As in the generic
  // FieldEntryBase::Set, trailing whitespace is accepted while any other
  // trailing character makes the value invalid.
  virtual void Set(void *head, const std::string &value) const {
    std::size_t pos = 0;  // number of characters consumed by std::stod
    double parsed = 0.0;
    try {
      parsed = std::stod(value, &pos);
    } catch (const std::invalid_argument &) {
      std::ostringstream os;
      os << "Invalid Parameter format for " << key_ << " expect " << type_
         << " but value=\'" << value << '\'';
      throw dmlc::ParamError(os.str());
    } catch (const std::out_of_range &) {
      std::ostringstream os;
      os << "Out of range value for " << key_ << " expect " << type_
         << " but value=\'" << value << '\'';
      throw dmlc::ParamError(os.str());
    }
    for (; pos < value.length(); ++pos) {
      if (!std::isspace(static_cast<unsigned char>(value[pos]))) {
        std::ostringstream os;
        os << "Invalid Parameter format for " << key_ << " expect " << type_
           << " but value=\'" << value << '\'';
        throw dmlc::ParamError(os.str());
      }
    }
    this->Get(head) = parsed;
  }
};
#endif // DMLC_USE_CXX11
} // namespace parameter
//! \endcond
// implement GetEnv
/*!
 * \brief Read an environment variable and parse it as ValueType.
 *  The text is parsed by a temporary FieldEntry<ValueType>, so the accepted
 *  formats are the same as for a parameter field of that type.
 * \param key the name of the environment variable.
 * \param default_value returned unchanged when the variable is not set.
 * \return the parsed value, or default_value when the variable is not set.
 */
template<typename ValueType>
inline ValueType GetEnv(const char *key,
                        ValueType default_value) {
  // <cstdlib> only guarantees the std-qualified name
  const char *val = std::getenv(key);
  if (val == NULL) return default_value;
  ValueType ret;
  parameter::FieldEntry<ValueType> e;
  // ret serves as its own "struct head": the field lives at offset 0
  e.Init(key, &ret, ret);
  e.Set(&ret, val);
  return ret;
}
} // namespace dmlc
#endif // DMLC_PARAMETER_H_