| [0b990d] | 1 | //
 | 
|---|
 | 2 | // opt.cc
 | 
|---|
 | 3 | //
 | 
|---|
 | 4 | // Copyright (C) 1996 Limit Point Systems, Inc.
 | 
|---|
 | 5 | //
 | 
|---|
 | 6 | // Author: Curtis Janssen <cljanss@limitpt.com>
 | 
|---|
 | 7 | // Maintainer: LPS
 | 
|---|
 | 8 | //
 | 
|---|
 | 9 | // This file is part of the SC Toolkit.
 | 
|---|
 | 10 | //
 | 
|---|
 | 11 | // The SC Toolkit is free software; you can redistribute it and/or modify
 | 
|---|
 | 12 | // it under the terms of the GNU Library General Public License as published by
 | 
|---|
 | 13 | // the Free Software Foundation; either version 2, or (at your option)
 | 
|---|
 | 14 | // any later version.
 | 
|---|
 | 15 | //
 | 
|---|
 | 16 | // The SC Toolkit is distributed in the hope that it will be useful,
 | 
|---|
 | 17 | // but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
|---|
 | 18 | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
|---|
 | 19 | // GNU Library General Public License for more details.
 | 
|---|
 | 20 | //
 | 
|---|
 | 21 | // You should have received a copy of the GNU Library General Public License
 | 
|---|
 | 22 | // along with the SC Toolkit; see the file COPYING.LIB.  If not, write to
 | 
|---|
 | 23 | // the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
 | 
|---|
 | 24 | //
 | 
|---|
 | 25 | // The U.S. Government is granted a limited license as per AL 91-7.
 | 
|---|
 | 26 | //
 | 
|---|
 | 27 | 
 | 
|---|
 | 28 | #ifdef __GNUC__
 | 
|---|
 | 29 | #pragma implementation
 | 
|---|
 | 30 | #endif
 | 
|---|
 | 31 | 
 | 
|---|
 | 32 | #include <math.h>
 | 
|---|
 | 33 | #include <deque>
 | 
|---|
 | 34 | 
 | 
|---|
 | 35 | #include <math/optimize/opt.h>
 | 
|---|
 | 36 | #include <util/keyval/keyval.h>
 | 
|---|
 | 37 | #include <util/misc/formio.h>
 | 
|---|
 | 38 | #include <util/misc/timer.h>
 | 
|---|
 | 39 | #include <util/state/stateio.h>
 | 
|---|
 | 40 | #include <util/state/state_bin.h>
 | 
|---|
 | 41 | 
 | 
|---|
 | 42 | using namespace std;
 | 
|---|
 | 43 | using namespace sc;
 | 
|---|
 | 44 | 
 | 
|---|
 | 45 | /////////////////////////////////////////////////////////////////////////
 | 
|---|
 | 46 | // Optimize
 | 
|---|
 | 47 | 
 | 
|---|
 | 48 | static ClassDesc Optimize_cd(
 | 
|---|
 | 49 |   typeid(Optimize),"Optimize",2,"virtual public SavableState",
 | 
|---|
 | 50 |   0, 0, 0);
 | 
|---|
 | 51 | 
 | 
|---|
 | 52 | Optimize::Optimize() :
 | 
|---|
 | 53 |   ckpt_(0), ckpt_file(0)
 | 
|---|
 | 54 | {
 | 
|---|
 | 55 | }
 | 
|---|
 | 56 | 
 | 
|---|
 | 57 | Optimize::Optimize(StateIn&s):
 | 
|---|
 | 58 |   SavableState(s)
 | 
|---|
 | 59 | {
 | 
|---|
 | 60 |   s.get(ckpt_,"checkpoint");
 | 
|---|
 | 61 |   s.getstring(ckpt_file);
 | 
|---|
 | 62 |   s.get(max_iterations_,"max_iterations");
 | 
|---|
 | 63 |   s.get(max_stepsize_,"max_stepsize");
 | 
|---|
 | 64 |   if (s.version(::class_desc<Optimize>()) > 1) {
 | 
|---|
 | 65 |       s.get(print_timings_,"print_timings");
 | 
|---|
 | 66 |     }
 | 
|---|
 | 67 |   n_iterations_ = 0;
 | 
|---|
 | 68 |   conv_ << SavableState::restore_state(s);
 | 
|---|
 | 69 |   function_ << SavableState::key_restore_state(s,"function");
 | 
|---|
 | 70 | }
 | 
|---|
 | 71 | 
 | 
|---|
 | 72 | Optimize::Optimize(const Ref<KeyVal>&keyval)
 | 
|---|
 | 73 | {
 | 
|---|
 | 74 |   print_timings_ = keyval->booleanvalue("print_timings");
 | 
|---|
 | 75 |   if (keyval->error() != KeyVal::OK) print_timings_ = 0;
 | 
|---|
 | 76 |   ckpt_ = keyval->booleanvalue("checkpoint");
 | 
|---|
 | 77 |   if (keyval->error() != KeyVal::OK) ckpt_ = 0;
 | 
|---|
 | 78 |   ckpt_file = keyval->pcharvalue("checkpoint_file");
 | 
|---|
 | 79 |   if (keyval->error() != KeyVal::OK) {
 | 
|---|
 | 80 |     ckpt_file = new char[13];
 | 
|---|
 | 81 |     strcpy(ckpt_file,"opt_ckpt.dat");
 | 
|---|
 | 82 |   }
 | 
|---|
 | 83 | 
 | 
|---|
 | 84 |   max_iterations_ = keyval->intvalue("max_iterations");
 | 
|---|
 | 85 |   if (keyval->error() != KeyVal::OK) max_iterations_ = 10;
 | 
|---|
 | 86 |   n_iterations_ = 0;
 | 
|---|
 | 87 | 
 | 
|---|
 | 88 |   max_stepsize_ = keyval->doublevalue("max_stepsize");
 | 
|---|
 | 89 |   if (keyval->error() != KeyVal::OK) max_stepsize_ = 0.6;
 | 
|---|
 | 90 | 
 | 
|---|
 | 91 |   function_ << keyval->describedclassvalue("function");
 | 
|---|
 | 92 | //  if (function_.null()) {
 | 
|---|
 | 93 | //      ExEnv::err0() << "Optimize requires a function keyword" << endl;
 | 
|---|
 | 94 | //      ExEnv::err0() << "which is an object of type Function" << endl;
 | 
|---|
 | 95 | //      abort();
 | 
|---|
 | 96 | //    }
 | 
|---|
 | 97 | // can't assume lineopt's have a function keyword
 | 
|---|
 | 98 | 
 | 
|---|
 | 99 |   conv_ << keyval->describedclassvalue("convergence");
 | 
|---|
 | 100 |   if (conv_.null()) {
 | 
|---|
 | 101 |       double convergence = keyval->doublevalue("convergence");
 | 
|---|
 | 102 |       if (keyval->error() == KeyVal::OK) {
 | 
|---|
 | 103 |           conv_ = new Convergence(convergence);
 | 
|---|
 | 104 |         }
 | 
|---|
 | 105 |     }
 | 
|---|
 | 106 |   if (conv_.null()) conv_ = new Convergence();
 | 
|---|
 | 107 | }
 | 
|---|
 | 108 | 
 | 
|---|
 | 109 | Optimize::~Optimize()
 | 
|---|
 | 110 | {
 | 
|---|
 | 111 |   if (ckpt_file) delete[] ckpt_file;
 | 
|---|
 | 112 |   ckpt_file=0;
 | 
|---|
 | 113 | }
 | 
|---|
 | 114 | 
 | 
|---|
 | 115 | void
 | 
|---|
 | 116 | Optimize::save_data_state(StateOut&s)
 | 
|---|
 | 117 | {
 | 
|---|
 | 118 |   s.put(ckpt_);
 | 
|---|
 | 119 |   s.putstring(ckpt_file);
 | 
|---|
 | 120 |   s.put(max_iterations_);
 | 
|---|
 | 121 |   s.put(max_stepsize_);
 | 
|---|
 | 122 |   s.put(print_timings_);
 | 
|---|
 | 123 |   SavableState::save_state(conv_.pointer(),s);
 | 
|---|
 | 124 |   SavableState::save_state(function_.pointer(),s);
 | 
|---|
 | 125 | }
 | 
|---|
 | 126 | 
 | 
|---|
 | 127 | void
 | 
|---|
 | 128 | Optimize::init()
 | 
|---|
 | 129 | {
 | 
|---|
 | 130 |   n_iterations_ = 0;
 | 
|---|
 | 131 | }
 | 
|---|
 | 132 | 
 | 
|---|
 | 133 | void
 | 
|---|
 | 134 | Optimize::set_checkpoint()
 | 
|---|
 | 135 | {
 | 
|---|
 | 136 |   ckpt_=1;
 | 
|---|
 | 137 | }
 | 
|---|
 | 138 | 
 | 
|---|
 | 139 | void
 | 
|---|
 | 140 | Optimize::set_max_iterations(int mi)
 | 
|---|
 | 141 | {
 | 
|---|
 | 142 |   max_iterations_ = mi;
 | 
|---|
 | 143 | }
 | 
|---|
 | 144 | 
 | 
|---|
 | 145 | void
 | 
|---|
 | 146 | Optimize::set_checkpoint_file(const char *path)
 | 
|---|
 | 147 | {
 | 
|---|
 | 148 |   if (ckpt_file) delete[] ckpt_file;
 | 
|---|
 | 149 |   if (path) {
 | 
|---|
 | 150 |     ckpt_file = new char[strlen(path)+1];
 | 
|---|
 | 151 |     strcpy(ckpt_file,path);
 | 
|---|
 | 152 |   } else
 | 
|---|
 | 153 |     ckpt_file=0;
 | 
|---|
 | 154 | }
 | 
|---|
 | 155 |   
 | 
|---|
 | 156 | void
 | 
|---|
 | 157 | Optimize::set_function(const Ref<Function>& f)
 | 
|---|
 | 158 | {
 | 
|---|
 | 159 |   function_ = f;
 | 
|---|
 | 160 | }
 | 
|---|
 | 161 | 
 | 
|---|
 // Stream class optimize() uses to write checkpoints.  It defaults to
// StateOutBin; an earlier definition (e.g. -DOPTSTATEOUT=...) takes precedence.
#if !defined(OPTSTATEOUT)
#  define OPTSTATEOUT StateOutBin
#endif
 | 
|---|
 | 165 | 
 | 
|---|
 | 166 | int
 | 
|---|
 | 167 | Optimize::optimize()
 | 
|---|
 | 168 | {
 | 
|---|
 | 169 |   int result=0;
 | 
|---|
 | 170 |   while((n_iterations_ < max_iterations_) && (!(result = update()))) {
 | 
|---|
 | 171 |       ++n_iterations_;
 | 
|---|
 | 172 |       if (ckpt_) {
 | 
|---|
 | 173 |         OPTSTATEOUT so(ckpt_file);
 | 
|---|
 | 174 |         this->save_state(so);
 | 
|---|
 | 175 |       }
 | 
|---|
 | 176 |       if (print_timings_) {
 | 
|---|
 | 177 |           tim_print(0);
 | 
|---|
 | 178 |         }
 | 
|---|
 | 179 |     }
 | 
|---|
 | 180 |   return result;
 | 
|---|
 | 181 | }
 | 
|---|
 | 182 | 
 | 
|---|
 | 183 | void
 | 
|---|
 | 184 | Optimize::apply_transform(const Ref<NonlinearTransform> &t)
 | 
|---|
 | 185 | {
 | 
|---|
 | 186 | }
 | 
|---|
 | 187 | 
 | 
|---|
 | 188 | /////////////////////////////////////////////////////////////////////////
 | 
|---|
 | 189 | // LineOpt
 | 
|---|
 | 190 | 
 | 
|---|
 | 191 | static ClassDesc LineOpt_cd(
 | 
|---|
 | 192 |   typeid(LineOpt),"LineOpt",1,"public Optimize",
 | 
|---|
 | 193 |   0, 0, 0);
 | 
|---|
 | 194 | 
 | 
|---|
 | 195 | LineOpt::LineOpt(StateIn&s):
 | 
|---|
 | 196 |   Optimize(s), SavableState(s)
 | 
|---|
 | 197 | {
 | 
|---|
 | 198 |   search_direction_ = matrixkit()->vector(dimension());
 | 
|---|
 | 199 |   search_direction_.restore(s);
 | 
|---|
 | 200 | }
 | 
|---|
 | 201 | 
 | 
|---|
 | 202 | LineOpt::LineOpt(const Ref<KeyVal>&keyval)
 | 
|---|
 | 203 | {
 | 
|---|
 | 204 |   decrease_factor_ = keyval->doublevalue("decrease_factor");    
 | 
|---|
 | 205 |   if (keyval->error() != KeyVal::OK) decrease_factor_ = 0.1;
 | 
|---|
 | 206 | }
 | 
|---|
 | 207 | 
 | 
|---|
 | 208 | LineOpt::~LineOpt()
 | 
|---|
 | 209 | {
 | 
|---|
 | 210 | }
 | 
|---|
 | 211 | 
 | 
|---|
 | 212 | void
 | 
|---|
 | 213 | LineOpt::save_data_state(StateOut&s)
 | 
|---|
 | 214 | {
 | 
|---|
 | 215 |   search_direction_.save(s);
 | 
|---|
 | 216 | }
 | 
|---|
 | 217 | 
 | 
|---|
 | 218 | void
 | 
|---|
 | 219 | LineOpt::init(RefSCVector& direction)
 | 
|---|
 | 220 | {
 | 
|---|
 | 221 |   if (function().null()) {
 | 
|---|
 | 222 |       ExEnv::err0() << "LineOpt requires a function object through" << endl;
 | 
|---|
 | 223 |       ExEnv::err0() << "constructor or init method" << endl;
 | 
|---|
 | 224 |       abort();
 | 
|---|
 | 225 |   }
 | 
|---|
 | 226 |   search_direction_ = direction.copy();
 | 
|---|
 | 227 |   initial_x_ = function()->get_x();
 | 
|---|
 | 228 |   initial_value_ = function()->value();
 | 
|---|
 | 229 |   initial_grad_ = function()->gradient();
 | 
|---|
 | 230 |   Optimize::init();
 | 
|---|
 | 231 | }
 | 
|---|
 | 232 | 
 | 
|---|
 | 233 | void
 | 
|---|
 | 234 | LineOpt::init(RefSCVector& direction, Ref<Function> function )
 | 
|---|
 | 235 | {
 | 
|---|
 | 236 |   set_function(function);
 | 
|---|
 | 237 |   init(direction);
 | 
|---|
 | 238 | }
 | 
|---|
 | 239 | 
 | 
|---|
 | 240 | int
 | 
|---|
 | 241 | LineOpt::sufficient_decrease(RefSCVector& step) {
 | 
|---|
 | 242 | 
 | 
|---|
 | 243 |   double ftarget = initial_value_ + decrease_factor_ *
 | 
|---|
 | 244 |     initial_grad_.scalar_product(step);
 | 
|---|
 | 245 |   
 | 
|---|
 | 246 |   RefSCVector xnext = initial_x_ + step;
 | 
|---|
 | 247 |   function()->set_x(xnext);
 | 
|---|
 | 248 |   Ref<NonlinearTransform> t = function()->change_coordinates();
 | 
|---|
 | 249 |   apply_transform(t);
 | 
|---|
 | 250 | 
 | 
|---|
 | 251 |   return function()->value() <= ftarget;
 | 
|---|
 | 252 | }
 | 
|---|
 | 253 | 
 | 
|---|
 | 254 | void
 | 
|---|
 | 255 | LineOpt::apply_transform(const Ref<NonlinearTransform> &t)
 | 
|---|
 | 256 | {
 | 
|---|
 | 257 |   if (t.null()) return;
 | 
|---|
 | 258 |   apply_transform(t);
 | 
|---|
 | 259 |   t->transform_gradient(search_direction_);
 | 
|---|
 | 260 | }
 | 
|---|
 | 261 | 
 | 
|---|
 | 262 | /////////////////////////////////////////////////////////////////////////
 | 
|---|
 | 263 | // Backtrack
 | 
|---|
 | 264 | 
 | 
|---|
 | 265 | static ClassDesc Backtrack_cd(
 | 
|---|
 | 266 |   typeid(Backtrack),"Backtrack",1,"public LineOpt",
 | 
|---|
 | 267 |   0, create<Backtrack>, 0);
 | 
|---|
 | 268 | 
 | 
|---|
 | 269 | Backtrack::Backtrack(const Ref<KeyVal>& keyval) 
 | 
|---|
 | 270 |   : LineOpt(keyval)
 | 
|---|
 | 271 | { 
 | 
|---|
 | 272 |   backtrack_factor_ = keyval->doublevalue("backtrack_factor");
 | 
|---|
 | 273 |   if (keyval->error() != KeyVal::OK) backtrack_factor_ = 0.1;
 | 
|---|
 | 274 | }
 | 
|---|
 | 275 | 
 | 
|---|
 | 276 | int
 | 
|---|
 | 277 | Backtrack::update() {
 | 
|---|
 | 278 |   
 | 
|---|
 | 279 |   deque<double> values;
 | 
|---|
 | 280 |   int acceptable=0;
 | 
|---|
 | 281 |   int descent=1;
 | 
|---|
 | 282 |   int took_step=0;
 | 
|---|
 | 283 |   int using_step;
 | 
|---|
 | 284 | 
 | 
|---|
 | 285 |   RefSCVector backtrack = -1.0 * backtrack_factor_ * search_direction_;
 | 
|---|
 | 286 |   RefSCVector step = search_direction_.copy();
 | 
|---|
 | 287 |   
 | 
|---|
 | 288 |   // check if line search is needed
 | 
|---|
 | 289 |   if( sufficient_decrease(step) ) {
 | 
|---|
 | 290 |     ExEnv::out0() << endl << indent << 
 | 
|---|
 | 291 |       "Unscaled initial step yields sufficient decrease." << endl;
 | 
|---|
 | 292 |     return 1;
 | 
|---|
 | 293 |   }
 | 
|---|
 | 294 | 
 | 
|---|
 | 295 |   ExEnv::out0() << endl << indent 
 | 
|---|
 | 296 |     << "Unscaled initial step does not yield a sufficient decrease."
 | 
|---|
 | 297 |     << endl << indent
 | 
|---|
 | 298 |     << "Initiating backtracking line search." << endl; 
 | 
|---|
 | 299 | 
 | 
|---|
 | 300 |   // perform a simple backtrack
 | 
|---|
 | 301 |   values.push_back( function()->value() );  
 | 
|---|
 | 302 |   for(int i=0; i<max_iterations_ && !acceptable && descent; ++i) {
 | 
|---|
 | 303 | 
 | 
|---|
 | 304 |     step = step + backtrack;
 | 
|---|
 | 305 | 
 | 
|---|
 | 306 |     if ( sqrt(step.scalar_product(step)) >= 0.1 * 
 | 
|---|
 | 307 |          sqrt(search_direction_.scalar_product(search_direction_)) ) {
 | 
|---|
 | 308 | 
 | 
|---|
 | 309 |       ++took_step;    
 | 
|---|
 | 310 |       if( sufficient_decrease(step) ) {
 | 
|---|
 | 311 |         ExEnv::out0() << endl << indent << "Backtrack " << i+1 
 | 
|---|
 | 312 |                       << " yields a sufficient decrease." << endl;
 | 
|---|
 | 313 |         acceptable = 1; 
 | 
|---|
 | 314 |         using_step = i+1;
 | 
|---|
 | 315 |       }
 | 
|---|
 | 316 |       
 | 
|---|
 | 317 |       else if ( values.back() < function()->value() ) {
 | 
|---|
 | 318 |         ExEnv::out0() << endl << indent << "Backtrack " << i+1 
 | 
|---|
 | 319 |                       << " increases value; terminating search." << endl;
 | 
|---|
 | 320 |         acceptable = 1;
 | 
|---|
 | 321 |         using_step = i;
 | 
|---|
 | 322 |       }
 | 
|---|
 | 323 |       
 | 
|---|
 | 324 |       else {
 | 
|---|
 | 325 |         ExEnv::out0() << endl << indent << "Backtrack " << i+1 
 | 
|---|
 | 326 |                       << " does not yield a sufficient decrease." << endl;
 | 
|---|
 | 327 |         using_step = i+1;
 | 
|---|
 | 328 |       }
 | 
|---|
 | 329 | 
 | 
|---|
 | 330 |       values.push_back( function()->value() );
 | 
|---|
 | 331 |     }
 | 
|---|
 | 332 | 
 | 
|---|
 | 333 |     else { 
 | 
|---|
 | 334 |       ExEnv::out0() << indent << 
 | 
|---|
 | 335 |         "Search direction does not appear to be a descent direction;" <<
 | 
|---|
 | 336 |         " terminating search." << endl;
 | 
|---|
 | 337 |       descent = 0;
 | 
|---|
 | 338 |     }
 | 
|---|
 | 339 |   }
 | 
|---|
 | 340 |              
 | 
|---|
 | 341 | 
 | 
|---|
 | 342 |   if ( !acceptable && descent ) {
 | 
|---|
 | 343 |     ExEnv::out0() << indent << 
 | 
|---|
 | 344 |       "Maximum number of backtrack iterations has been exceeded." << endl;
 | 
|---|
 | 345 |     acceptable = 1;
 | 
|---|
 | 346 |   }
 | 
|---|
 | 347 |  
 | 
|---|
 | 348 |   for(int i=0; i <= took_step; ++i) {
 | 
|---|
 | 349 |     if(i==0) ExEnv::out0() << indent << "initial step    " << " value: ";
 | 
|---|
 | 350 |     else ExEnv::out0() << indent << "backtrack step " << i << " value: ";
 | 
|---|
 | 351 |     ExEnv::out0() << scprintf("%15.10lf", values.front()) << endl;
 | 
|---|
 | 352 |     values.pop_front();
 | 
|---|
 | 353 |   }
 | 
|---|
 | 354 | 
 | 
|---|
 | 355 |   if(descent) { 
 | 
|---|
 | 356 |     ExEnv::out0() << indent << "Using step " << using_step << endl;
 | 
|---|
 | 357 |    
 | 
|---|
 | 358 |     // use next to last step if value went up
 | 
|---|
 | 359 |     if( using_step != took_step ) {
 | 
|---|
 | 360 |       function()->set_x( function()->get_x() - backtrack );
 | 
|---|
 | 361 |       Ref<NonlinearTransform> t = function()->change_coordinates();
 | 
|---|
 | 362 |       apply_transform(t);
 | 
|---|
 | 363 |     }
 | 
|---|
 | 364 |   }
 | 
|---|
 | 365 |   else {
 | 
|---|
 | 366 |     function()->set_x( initial_x_ );
 | 
|---|
 | 367 |     Ref<NonlinearTransform> t = function()->change_coordinates();
 | 
|---|
 | 368 |     apply_transform(t);
 | 
|---|
 | 369 |   }
 | 
|---|
 | 370 | 
 | 
|---|
 | 371 |   // returning 0 only if search direction is not descent direction
 | 
|---|
 | 372 |   return acceptable;
 | 
|---|
 | 373 | }
 | 
|---|
 | 374 | 
 | 
|---|
 | 375 | /////////////////////////////////////////////////////////////////////////////
 | 
|---|
 | 376 | 
 | 
|---|
 | 377 | // Local Variables:
 | 
|---|
 | 378 | // mode: c++
 | 
|---|
 | 379 | // c-file-style: "CLJ"
 | 
|---|
 | 380 | // End:
 | 
|---|