Skip to content

Documentation for mutax

Enabling parallel execution on CPU

By default, JAX on CPU only uses a single core. To enable parallel execution on all available CPU cores, set the jax_num_cpu_devices configuration option appropriately. This should be done at the beginning of your code as follows:

import multiprocessing
import jax

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())

mutax.differential_evolution(func: Callable[[jax.Array], jax.Array], /, bounds: jax.Array, *, key: jax.Array | None = None, strategy: Literal['rand1bin', 'best1bin'] = 'best1bin', maxiter: int = 1000, popsize: int = 15, tol: float = 0.01, atol: float = 0, mutation: float | tuple[float, float] = (0.5, 1.0), recombination: float = 0.8, disp: bool = False, polish: bool = True, updating: Literal['immediate', 'deferred'] = 'immediate', workers: int | Callable[[Callable[[jax.Array], jax.Array], jax.Array], jax.Array] = 1, x0: jax.Array | None = None, vectorized: bool = False) -> OptimizeResults

Find the global minimum of a multivariate function.

Uses the Differential Evolution algorithm to find the global minimum of the given objective function within the specified bounds.

Arguments:

  • func: The objective function to be minimized. It must take a single argument (a 1D array) and return a scalar.
  • bounds: A 2D array specifying the lower and upper bounds for each dimension of the input space.
  • key: A JAX random key for stochastic operations. You can use e.g. jax.random.key(seed) to generate a key. If not given, a default key is used, which may change between runs.
  • strategy: The differential evolution strategy to use. Can be either "rand1bin" or "best1bin". The "rand1bin" strategy uses a randomly selected population member as the base vector, while "best1bin" uses the best population member found so far.
  • maxiter: The maximum number of generations to evolve the population.
  • popsize: Multiplier for setting the total population size. The population size is determined by popsize * dim.
  • tol: Relative tolerance for convergence.
  • atol: Absolute tolerance for convergence.
  • mutation: A float or a tuple of two floats specifying the mutation factor. If a tuple is provided, the mutation factor is sampled uniformly from this range for each mutation.
  • recombination: A float in [0, 1] specifying the recombination probability.
  • disp: Whether to print progress messages at each iteration.
  • polish: Whether to perform a local optimization using BFGS at the end of the evolution process to attempt to refine the best solution found. For this local optimization to be effective, the objective function should be differentiable.
  • updating: Strategy for updating the population. Can be either "immediate" or "deferred". "immediate" updates individuals as soon as a better trial vector is found, while "deferred" updates the population after all trial vectors have been evaluated.
  • workers: Number of JAX devices (CPUs/GPUs/TPUs) used for evaluating the objective function. Uses Parajax for parallelization. If set to -1, uses all available JAX devices. Alternatively, if a callable is provided, it should be a callable as workers(func, x), where x is a 2D array with each row being a different input to be evaluated. The callable should return a 1D array of function values. Setting this argument to a value other than 1 will override updating to "deferred".
  • x0: Optional initial guess.
  • vectorized: If True, indicates that func accepts a 2D array where each column is a different input to be evaluated. If used, it will override updating to "deferred".

Returns:

An OptimizeResults object containing the optimization results.

Reference:

R. Storn and K. Price, “Differential Evolution - A Simple and Efficient Heuristic for global Optimization over Continuous Spaces,” Journal of Global Optimization, vol. 11, no. 4, pp. 341-359, Dec. 1997, doi: 10.1023/a:1008202821328.

Source code in mutax/__init__.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
@eqx.filter_jit
def differential_evolution(  # noqa: C901, PLR0912, PLR0913, PLR0915
    func: Callable[[jax.Array], jax.Array],
    /,
    bounds: jax.Array,
    *,
    key: jax.Array | None = None,
    strategy: Literal["rand1bin", "best1bin"] = "best1bin",
    maxiter: int = 1_000,
    popsize: int = 15,
    tol: float = 0.01,
    atol: float = 0,
    mutation: float | tuple[float, float] = (0.5, 1.0),
    recombination: float = 0.8,
    disp: bool = False,
    polish: bool = True,
    updating: Literal["immediate", "deferred"] = "immediate",
    workers: int
    | Callable[[Callable[[jax.Array], jax.Array], jax.Array], jax.Array] = 1,
    x0: jax.Array | None = None,
    vectorized: bool = False,
) -> OptimizeResults:
    """Find the global minimum of a multivariate function.

    Uses the Differential Evolution algorithm to find the global minimum of the
    given objective function within the specified bounds.

    **Arguments:**

    - `func`: The objective function to be minimized. It must take a single argument
    (a 1D array) and return a scalar.
    - `bounds`: A 2D array specifying the lower and upper bounds for each dimension of
      the input space.
    - `key`: A JAX random key for stochastic operations. You can use e.g.
      `jax.random.key(seed)` to generate a key. If not given, a default key is used,
      which may change between runs.
    - `strategy`: The differential evolution strategy to use. Can be either "rand1bin"
      or "best1bin". The "rand1bin" strategy uses a randomly selected population member
      as the base vector, while "best1bin" uses the best population member found so far.
    - `maxiter`: The maximum number of generations to evolve the population.
    - `popsize`: Multiplier for setting the total population size. The population size
      is determined by `popsize * dim`.
    - `tol`: Relative tolerance for convergence.
    - `atol`: Absolute tolerance for convergence.
    - `mutation`: A float or a tuple of two floats specifying the mutation factor. If a
      tuple is provided, the mutation factor is sampled uniformly from this range for
      each mutation.
    - `recombination`: A float in [0, 1] specifying the recombination probability.
    - `disp`: Whether to print progress messages at each iteration.
    - `polish`: Whether to perform a local optimization using BFGS at the end of the
      evolution process to attempt to refine the best solution found. For this local
      optimization to be effective, the objective function should be differentiable.
    - `updating`: Strategy for updating the population. Can be either "immediate" or
      "deferred". "immediate" updates individuals as soon as a better trial vector is
      found, while "deferred" updates the population after all trial vectors have been
      evaluated.
    - `workers`: Number of JAX devices (CPUs/GPUs/TPUs) used for evaluating the
      objective function. Uses [Parajax](https://github.com/gerlero/parajax) for
      parallelization. If set to -1, uses all available JAX devices. Alternatively, if
      a callable is provided, it should be a callable as ```workers(func, x)```, where
      `x` is a 2D array with each row being a different input to be evaluated. The
      callable should return a 1D array of function values. Setting this argument to a
      value other than 1 will override `updating` to "deferred".
    - `x0`: Optional initial guess.
    - `vectorized`: If `True`, indicates that `func` accepts a 2D array where each
      column is a different input to be evaluated. If used, it will override `updating`
      to "deferred".

    **Returns:**

    An `OptimizeResults` object containing the optimization results.

    **Reference:**

    R. Storn and K. Price, “Differential Evolution - A Simple and Efficient Heuristic
    for global Optimization over Continuous Spaces,” Journal of Global Optimization,
    vol. 11, no. 4, pp. 341-359, Dec. 1997, doi: 10.1023/a:1008202821328.
    """
    dim = len(bounds)
    lower = jnp.array([b[0] for b in bounds])
    upper = jnp.array([b[1] for b in bounds])
    popsize *= dim

    if key is None:
        key = jax.random.key(hash("mutax"))

    if not callable(workers) and workers < 1 and workers != -1:
        msg = "workers must be a positive integer or -1"
        raise ValueError(msg)

    if workers != 1 and updating == "immediate":
        msg = (
            "differential_evolution: the 'workers' keyword has overridden "
            "updating='immediate' to updating='deferred'"
        )
        warnings.warn(msg, UserWarning, stacklevel=2)
        updating = "deferred"
    if workers == 1 and vectorized and updating == "immediate":
        msg = (
            "differential_evolution: the 'vectorized' keyword has overridden "
            "updating='immediate' to updating='deferred'"
        )
        warnings.warn(msg, UserWarning, stacklevel=2)
        updating = "deferred"

    if workers == 1:
        if vectorized:

            def single_func(x: jax.Array) -> jax.Array:
                return func(x[None, :])[0]

            def vmapped_func(x: jax.Array) -> jax.Array:
                return func(x.T)
        else:
            single_func = func
            vmapped_func = jax.vmap(func)
    elif callable(workers):
        if vectorized:
            msg = "If 'workers' is a callable, 'vectorized' must be False"
            raise ValueError(msg)

        def single_func(x: jax.Array) -> jax.Array:
            return func(x[None, :])[0]

        def vmapped_func(x: jax.Array) -> jax.Array:
            return workers(func, x)  # ty: ignore[call-top-callable]
    else:
        max_devices = None if workers == -1 else workers
        if vectorized:

            def single_func(x: jax.Array) -> jax.Array:
                return func(x[None, :])[0]

            vmapped_func = parallelize(lambda x: func(x.T), max_devices=max_devices)
        else:
            single_func = func
            vmapped_func = parallelize(jax.vmap(func), max_devices=max_devices)

    # Initialize population (Latin hypercube sampling)
    segsize = 1.0 / popsize
    key, subkey = jax.random.split(key)
    pop = lower + (upper - lower) * jnp.stack(
        jax.vmap(
            lambda k: (
                segsize * jax.random.uniform(k, (popsize,))
                + jnp.linspace(0.0, 1.0, popsize, endpoint=False)
            )[jax.random.permutation(k, popsize)]
        )(jax.random.split(subkey, dim)),
        axis=1,
    )

    if x0 is not None:
        pop = pop.at[0].set(jnp.asarray(x0))

    fitness = vmapped_func(pop)

    def make_trial(
        pop: jax.Array, fitness: jax.Array, i: int, key: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        key, subkey = jax.random.split(key)

        if strategy == "best1bin":
            # Use best member as base vector
            best_idx = jnp.argmin(fitness)

            # Select two distinct indices from 0..pop_size-1 excluding i and best_idx
            idxs = jnp.arange(popsize)
            idxs = jnp.where(idxs == i, popsize, idxs)
            idxs = jnp.where(idxs == best_idx, popsize + 1, idxs)
            idx_perm = jax.random.permutation(subkey, idxs)
            r1, r2 = idx_perm[:2]
            r1 = jnp.where(r1 == popsize, idx_perm[2], r1)
            r1 = jnp.where(r1 == popsize + 1, idx_perm[3], r1)
            r2 = jnp.where(r2 == popsize, idx_perm[4], r2)
            r2 = jnp.where(r2 == popsize + 1, idx_perm[5], r2)

            # Mutation
            try:
                mut_lower, mut_upper = mutation  # ty: ignore[not-iterable]
            except TypeError:
                mut_val: float | jax.Array = mutation  # ty: ignore[invalid-assignment]
            else:
                key, subkey = jax.random.split(key)
                mut_val = jax.random.uniform(
                    subkey, (), minval=mut_lower, maxval=mut_upper
                )

            mutant = pop[best_idx] + mut_val * (pop[r1] - pop[r2])

        elif strategy == "rand1bin":
            # Use random member as base vector
            # Select three distinct indices from 0..pop_size-1 excluding i
            idxs = jnp.arange(popsize)
            idxs = jnp.where(idxs == i, popsize, idxs)
            idx_perm = jax.random.permutation(subkey, idxs)
            r1, r2, r3 = idx_perm[:3]
            r1 = jnp.where(r1 == popsize, idx_perm[3], r1)
            r2 = jnp.where(r2 == popsize, idx_perm[4], r2)
            r3 = jnp.where(r3 == popsize, idx_perm[5], r3)

            # Mutation
            try:
                mut_lower, mut_upper = mutation  # ty: ignore[not-iterable]
            except TypeError:
                mut_val: float | jax.Array = mutation  # ty: ignore[invalid-assignment]
            else:
                key, subkey = jax.random.split(key)
                mut_val = jax.random.uniform(
                    subkey, (), minval=mut_lower, maxval=mut_upper
                )

            mutant = pop[r1] + mut_val * (pop[r2] - pop[r3])

        else:
            msg = f"Unrecognized strategy '{strategy}'"
            raise ValueError(msg)

        mutant = jnp.clip(mutant, lower, upper)

        # Crossover
        key, subkey = jax.random.split(key)
        cross_points = jax.random.uniform(subkey, (dim,)) < recombination
        key, subkey = jax.random.split(key)
        cross_points = cross_points.at[jax.random.randint(subkey, (), 0, dim)].set(True)
        trial = jnp.where(cross_points, mutant, pop[i])

        return trial, key

    if updating == "immediate":

        def evolve(
            nit: int, pop: jax.Array, fitness: jax.Array, key: jax.Array
        ) -> tuple[int, jax.Array, jax.Array, jax.Array]:
            if disp:
                jax.debug.print(
                    "differential_evolution step {nit}: f(x)={fmin}",
                    nit=nit,
                    fmin=jnp.min(fitness),
                )

            def evolve_one(
                i: int, carry: tuple[jax.Array, jax.Array, jax.Array]
            ) -> tuple[jax.Array, jax.Array, jax.Array]:
                pop, fitness, key = carry
                trial, key = make_trial(pop, fitness, i, key)

                # Selection
                f_trial = func(trial)
                better = f_trial < fitness[i]
                pop = pop.at[i].set(jnp.where(better, trial, pop[i]))
                fitness = fitness.at[i].set(jnp.where(better, f_trial, fitness[i]))

                return pop, fitness, key

            pop, fitness, key = jax.lax.fori_loop(
                0, popsize, evolve_one, (pop, fitness, key)
            )
            return nit + 1, pop, fitness, key

    elif updating == "deferred":

        def evolve(
            nit: int, pop: jax.Array, fitness: jax.Array, key: jax.Array
        ) -> tuple[int, jax.Array, jax.Array, jax.Array]:
            if disp:
                jax.debug.print(
                    "differential_evolution step {nit}: f(x)={fmin}",
                    nit=nit,
                    fmin=jnp.min(fitness),
                )

            keys = jax.random.split(key, popsize)
            trials, keys = jax.vmap(lambda i, k: make_trial(pop, fitness, i, k))(
                jnp.arange(popsize), keys
            )
            key = keys[-1]
            f_trials = vmapped_func(trials)
            better = f_trials < fitness
            pop = jnp.where(better[:, None], trials, pop)
            fitness = jnp.where(better, f_trials, fitness)
            return nit + 1, pop, fitness, key

    else:
        msg = "updating must be 'immediate' or 'deferred'"
        raise ValueError(msg)

    def converged(fitness: jax.Array) -> jax.Array:
        return jnp.all(jnp.isfinite(fitness)) & (
            jnp.std(fitness) <= atol + tol * jnp.abs(jnp.mean(fitness))
        )

    nit, pop, fitness, key = jax.lax.while_loop(
        lambda val: (val[0] <= maxiter) & (~converged(val[2])),
        lambda val: evolve(*val),
        (1, pop, fitness, key),
    )

    success = converged(fitness)
    best_idx = jnp.argmin(fitness)
    best = pop[best_idx]
    best_fitness = fitness[best_idx]

    if polish:
        if disp:
            jax.debug.print("Polishing solution with BFGS")

        result = jax.scipy.optimize.minimize(
            single_func,
            best,
            method="BFGS",
        )
        polished = (
            result.success
            & (result.fun < best_fitness)
            & jnp.all((result.x >= lower) & (result.x <= upper))
        )
        best = jnp.where(polished, result.x, best)
        best_fitness = jnp.where(polished, result.fun, best_fitness)

        return OptimizeResults(
            x=best,
            fun=best_fitness,
            success=success,
            status=(~success).astype(int),
            jac=jnp.where(polished, result.jac, jnp.full_like(result.jac, jnp.nan)),
            hess_inv=jnp.where(
                polished,
                result.hess_inv,
                jnp.full_like(result.hess_inv, jnp.nan),  # ty: ignore[invalid-argument-type]
            ),  # ty: ignore[invalid-argument-type]
            nfev=result.nfev + nit * popsize,
            njev=result.njev,
            nit=nit,
        )

    return OptimizeResults(
        x=best,
        fun=best_fitness,
        success=success,
        status=(~success).astype(int),
        jac=None,  # ty: ignore[invalid-argument-type]
        hess_inv=None,
        nfev=nit * popsize,
        njev=jnp.array(0),
        nit=nit,
    )

mutax.OptimizeResults = jax.scipy.optimize.OptimizeResults module-attribute

Object holding optimization results.

Attributes:

  • x: final solution.
  • success: whether the optimization succeeded.
  • status: integer solver specific return code. 0 means converged (nominal), 1=max number of iterations reached.
  • fun: final function value.
  • nfev: integer number of function calls used.
  • njev: integer number of Jacobian evaluations used (only if polish was set to True).
  • nit: integer number of iterations of the optimization algorithm.
  • jac: final Jacobian (only if the solution was polished).
  • hess_inv: inverse of the final Hessian (only if the solution was polished).